import scipy as SP
import scipy.stats as ST
from parts2014_gfpvar.tools.common import *
from parts2014_gfpvar.tools.util import *
from parts2014_gfpvar.tools.io import *
from fcs import read_fcs


# Names of the summary traits, log10-scale ones first, then the linear-scale ones.
TRAIT_NAMES = [
    "mean(log10(GFP))",
    "median(log10(GFP))",
    "var(log10(GFP))",
    "mean(GFP)",
    "median(GFP)",
    "var(GFP)",
]

# Deviation limit per trait: 0.25 by default, three times that for every trait
# whose name contains "var".
TRAIT_DEVIATION_LIMITS = {}
for t in TRAIT_NAMES:
    TRAIT_DEVIATION_LIMITS[t] = 0.25*(1 + 2.*t.count("var")) # allow more lenience for variance

# Status labels used when flagging measurements.
GOOD = "OK"
BAD = "Bad"
MISSING = "NA"
COUNT = "Count"


def get_overlaps(lists):
    """ Find the items shared by all given lists and where they sit in each list
    @param lists sequence of sequences of hashable items
    @return one index list per input list; the k-th entries of all index lists point at the same shared item.
            Shared items are reported in order of first appearance in lists[0]. If an item occurs more than once
            in a list, the index of its last occurrence is used. An empty 'lists' gives an empty result.
    """
    if len(lists) == 0:
        return []

    # item -> position lookup for every list (later duplicates overwrite earlier ones)
    positions = [dict((item, i) for i, item in enumerate(l)) for l in lists]

    shared = set(lists[0])
    for other in lists[1:]:
        shared &= set(other)

    # walk the first list so the output order is reproducible rather than set-iteration order
    ordered, seen = [], set()
    for item in lists[0]:
        if item in shared and item not in seen:
            seen.add(item)
            ordered.append(item)

    return [[pos[item] for item in ordered] for pos in positions]



""" Read tab-delimited cytometry output produced by quantify.py create_screen_matrix
@param filename path to tab-delimited file
@param includes_linear whether linear measures are included in input file
@param normalise whether to have all plates have same median values
@param return_header if True, return one extra list, the file headers
@param log_linear whether to log-transform linear measurements
@return meta,slice,parentals,orfs,plates - metadata per measurement, measurements in small FSC slice, all parentals, all orfs, all plates
"""
def read_cyto_tab(filename, includes_linear=True, includes_all=True, includes_flag=False, normalise=True, return_header=False, log_linear=True, return_all=False):
    """ Read tab-delimited cytometry output produced by quantify.py create_screen_matrix
    @param filename path to the input file; its first line is the (tab-separated) header
    @param includes_linear whether linear measures are included (7 instead of 4 columns per measurement set)
    @param includes_all if False, no columns are read for the 'all' measurement set
    @param includes_flag whether one extra metadata column precedes the measurements (shifts all measurement columns by one)
    @param normalise whether to shift every plate/parent group so its per-column median equals the overall median of the slice measurements
    @param return_header if True, additionally return the list of header fields
    @param log_linear whether to replace measurement columns 3-5 by log10(value + 0.1)
    @param return_all if True, additionally return the 'all' measurements (as the last element); they are only transformed/normalised when this is set
    @return meta, slice, parentals, orfs, plates[, header][, all] - metadata per measurement (strings), slice measurements (floats),
            sorted unique values of metadata columns 5 (parentals), 4 (orfs) and 2 (plates)
    """
    meta_range = list(range(0, 7 + includes_flag))
    if includes_linear:
        slice_range, all_range = includes_flag + SP.arange(7, 14), includes_flag + SP.arange(14, 21)
    else:
        slice_range, all_range = includes_flag + SP.arange(7, 11), includes_flag + SP.arange(11, 15)
    if not includes_all:
        all_range = []

    meta, slice_vals, all_vals = [], [], []
    with open(filename, 'r') as ifh: # context manager closes the file even if parsing fails
        header = next(ifh).strip().split("\t")
        for line in ifh:
            d = SP.array(line.strip().split())
            if d[2].count("AB") + d[2].count("CD") > 0: # plate from R1
                d[2] = d[2][0:-2] # drop the two-character suffix from the plate name
            meta.append(d[meta_range])
            slice_vals.append(d[slice_range])
            all_vals.append(d[all_range])
    meta, slice_vals, all_vals = SP.array(meta), SP.array(slice_vals, float), SP.array(all_vals, float)

    parentals, orfs, plates = sorted(SP.unique(meta[:,5])), sorted(SP.unique(meta[:,4])), sorted(SP.unique(meta[:,2]))
    if log_linear:
        slice_vals[:,3:6] = SP.log10(slice_vals[:,3:6] + 0.1) # avoid log 0
        if return_all:
            all_vals[:,3:6] = SP.log10(all_vals[:,3:6] + 0.1)

    if normalise: # normalise each plate and each parent to have same median value
        overall_median = SP.median(slice_vals, axis=0) # the slice medians are the common target for both measurement sets
        for pl in plates:
            on_plate = (meta[:,2] == pl)
            for p in parentals:
                Ip = SP.where(on_plate & (meta[:,5] == p))[0]
                if len(Ip) == 0: continue # this parent was not measured on this plate
                slice_vals[Ip] = overall_median + slice_vals[Ip] - SP.median(slice_vals[Ip], axis=0)
                if return_all:
                    all_vals[Ip] = overall_median + all_vals[Ip] - SP.median(all_vals[Ip], axis=0)

    # Variable-length return is kept for backward compatibility with existing callers
    result = (meta, slice_vals, parentals, orfs, plates)
    if return_header: result = result + (header,)
    if return_all: result = result + (all_vals,)
    return result
    

""" Given two vectors, see which points are off linear fit
@param v1 vector 1
@param v2 vector 2
@param good_threshold maximum error term in regression to be considered good
@return list of good point indexes, list of bad point indexes, list of all errors, float slope, float intercept (last two for transforming v1 to get v2)
"""
def is_good(v1, v2, good_threshold=0.25):
    """ Given two vectors, see which points are off the linear fit between them
    @param v1 vector 1 (array)
    @param v2 vector 2 (array of the same length)
    @param good_threshold maximum absolute error for a point to be considered good
    @return array of good point indexes, array of bad point indexes, array of all errors, slope, intercept
            (last two for transforming v1 to get v2). If no fit can be found, every point is reported as bad,
            and the errors, slope and intercept are NaN.
    """
    # Find out which data points are discordant between replicates. For this:
    fit = align_vals(v1, v2, good_threshold, True) # 1. transform v1 to be aligned with v2
    if not isinstance(fit, tuple):
        # align_vals signals "no fit" with a bare array instead of (values, slope, intercept);
        # unpacking that would raise, so report all points as bad instead
        n = len(v2)
        return SP.array([], int), SP.arange(n), SP.zeros(n)*SP.nan, SP.nan, SP.nan
    vt, slope, intercept = fit
    error = vt - v2 # 2. calculate error as expected - observed values
    Igood = (abs(error) < good_threshold) # 3. good points have small error (NaN errors count as bad)
    return SP.where(Igood)[0], SP.where(~Igood)[0], error, slope, intercept


""" Given two vectors, perform a linear transform on first to minimize MSE to second
@param v1 vector 1
@param v2 vector 2
@param good_threshold maximum error term in regression to be considered good. The points outside this range are not taken into account when optimizing fit
@param return_stats if True, return slope, intercept
@return a*v1 + b, such that ||a*v1 + b - v2|| is minimized. If return_stats, also return slope, intercept of transform as tuple's second and third terms.
"""
def align_vals(v1, v2, good_threshold=0.25, return_stats=False):
    """ Given two vectors, perform a linear transform on first to minimize MSE to second
    @param v1 vector 1 (array; NaN entries are ignored when fitting)
    @param v2 vector 2 (array of the same length; NaN entries are ignored when fitting)
    @param good_threshold maximum error term in the first regression for a point to be used in the final fit
    @param return_stats if True, return slope, intercept as well
    @return a*v1 + b, with a, b fitted on the points within good_threshold of a first regression line.
            If return_stats, a tuple (a*v1 + b, slope, intercept).
            If no fit is possible (fewer than two usable pairs, or no point within the threshold):
            an empty array when return_stats is False (legacy behaviour), or
            (all-NaN array of len(v1), NaN, NaN) when return_stats is True, so the tuple can always be unpacked.
    """
    I = SP.where(~SP.isnan(v1) & ~SP.isnan(v2))[0]
    Igood = []
    if len(I) >= 2: # a regression line needs at least two points
        slope, intercept, rv, pv, se = ST.linregress(v1[I],v2[I]) # 1. calculate regression line
        error = (intercept + v1*slope) - v2 # 2. calculate error as expected - observed values
        Igood = SP.where(abs(error) < good_threshold)[0] # 3. filter out grossly wrong datapoints (NaN never passes)
    if len(Igood) == 0:
        LOG.debug("Weird - *no* fit at all for v1 (len %d) and v2 (len %d) with cutoff %.2f"%(len(v1), len(v2), good_threshold))
        if return_stats: return SP.zeros(len(v1))*SP.nan, SP.nan, SP.nan
        return SP.array([])
    slope, intercept, rv, pv, se = ST.linregress(v1[Igood], v2[Igood]) # 4. Refit regression line on the good points...
    aligned = v1*slope+intercept
    if return_stats: return aligned, slope, intercept
    return aligned



def calculate_parental_counts(parentals, orfs, meta, slice, min_cells=250, trait=1, normalise=True, return_used_orfs=False, baseline_rank=10):
    """ Collect one trait value per (parent, ORF) pair and count how often each pair was measured
    @param parentals list of parent names (matched against metadata column 5)
    @param orfs list of ORF names (matched against metadata column 4)
    @param meta metadata array, one row per measurement
    @param slice measurement array, one row per measurement; column 6 is compared against min_cells
    @param min_cells a pair is skipped if the first matching measurement has fewer cells than this
    @param trait column of 'slice' to use as value
    @param normalise if True, shift the values of each parent so that its baseline value maps to the overall mean of the trait
    @param return_used_orfs if True, additionally return parent -> list of ORFs that had a usable value
    @param baseline_rank rank (0-based, ascending) of the value used as per-parent baseline when normalising;
           clamped to the largest available value if a parent has fewer usable values
    @return parental_orf_count (parent -> orf -> number of matching rows), parent_orf_vals (parent -> orf -> value or NaN),
            orf_parental_count (orf -> number of parents with a usable value)[, used_orfs]
    """
    parent_orf_vals, orf_parental_count = {}, {}
    parental_orf_count, used_orfs = {}, {}
    overall_mean = slice[:,trait].mean()

    for p in parentals:
        is_parent = (meta[:,5] == p)
        parental_orf_count[p], parent_orf_vals[p], used_orfs[p] = {}, {}, []
        for o in orfs:
            orf_parental_count.setdefault(o, 0)
            I = SP.where((meta[:,4] == o) & is_parent)[0] # all measurements of this ORF in this parent
            parental_orf_count[p][o] = len(I)
            parent_orf_vals[p][o] = SP.nan
            if len(I) == 0 or slice[I[0], 6] < min_cells: continue # absent, or too few cells - value stays NaN
            used_orfs[p].append(o)
            orf_parental_count[o] += 1
            parent_orf_vals[p][o] = slice[I[0],trait] # only the first matching measurement is used
        if normalise:
            all_vals = SP.array(list(parent_orf_vals[p].values()))
            usable = sorted(all_vals[~SP.isnan(all_vals)])
            if len(usable) > 0: # nothing to shift if every value is NaN
                # baseline from actual values; the rank is clamped so small inputs do not raise IndexError
                baseline_val = usable[min(baseline_rank, len(usable) - 1)]
                for o in orfs: parent_orf_vals[p][o] = overall_mean + parent_orf_vals[p][o] - baseline_val

    if return_used_orfs: return parental_orf_count, parent_orf_vals, orf_parental_count, used_orfs
    return parental_orf_count, parent_orf_vals, orf_parental_count



def calculate_parental_counts_across_plates(parentals, orfs, plates, meta, slice, min_cells=200, normalise=True, trait=3, all_plates_good_only=True):
    """ Collect one trait value per (parent, ORF, plate) and count measurements per parent, ORF and plate
    @param parentals list of parent names (matched against metadata column 5)
    @param orfs list of ORF names (matched against metadata column 4)
    @param plates list of plate names (matched against metadata column 2); values are stored in sorted plate order
    @param meta metadata array, one row per measurement
    @param slice measurement array, one row per measurement; column 6 is compared against min_cells
    @param min_cells a measurement is used only if it has more cells than this
    @param normalise if True, shift the values of each parent so that every plate's median maps to the overall mean of the trait
    @param trait column of 'slice' to use as value
    @param all_plates_good_only if True, only ORFs with a value on every plate contribute to the per-plate medians
    @return parent_orf_count (parent -> orf -> number of matching rows), parent_orf_vals (parent -> orf -> vector of values, NaN where missing),
            orf_parent_count (orf -> parent -> number of plates with a usable value), plate_orf_count (plate -> orf -> number of matching rows)
    """
    parent_orf_count, parent_orf_vals, orf_parent_count, plate_orf_count = {}, {}, {}, {}
    overall_mean = slice[:,trait].mean()
    sorted_plates = sorted(plates)
    # One slot per plate. The vector used to be hard-coded to length 2, which raised IndexError for a third plate;
    # the minimum of 2 keeps the historical shape for callers passing fewer plates.
    n_slots = max(2, len(sorted_plates))

    for p in parentals:
        is_parent = (meta[:,5] == p)
        parent_orf_count[p] = {}
        parent_orf_vals[p] = {}
        for o in orfs:
            orf_parent_count.setdefault(o, {}).setdefault(p, 0)
            is_pair = (meta[:,4] == o) & is_parent
            parent_orf_count[p][o] = int(is_pair.sum())
            parent_orf_vals[p][o] = SP.zeros(n_slots)*SP.nan

            for pl,plate in enumerate(sorted_plates):
                I = SP.where(is_pair & (meta[:,2] == plate))[0] # all wells that match the orf, parent and plate
                if len(I) > 0 and slice[I[0],6] > min_cells: # and have plenty of observed cells
                    orf_parent_count[o][p] += 1
                    parent_orf_vals[p][o][pl] = slice[I[0],trait]
        if normalise:
            all_vals = SP.array(list(parent_orf_vals[p].values())) # array of orfs x plates
            Igood = list(range(len(all_vals)))
            if all_plates_good_only: Igood = SP.where(~(SP.isnan(all_vals).any(axis=1)))[0] # all plates ok
            plate_medians = ST.nanmedian(all_vals[Igood], axis=0) # computed once per parent, not once per ORF
            for o in orfs: parent_orf_vals[p][o] = overall_mean + parent_orf_vals[p][o] - plate_medians  # normalise each liquid plate to have same median for each parent

    for pl in plates:
        on_plate = (meta[:,2] == pl)
        plate_orf_count[pl] = {}
        for o in orfs:
            plate_orf_count[pl][o] = int(((meta[:,4] == o) & on_plate).sum())

    return parent_orf_count, parent_orf_vals, orf_parent_count, plate_orf_count




""" Get the correct parent for this screen and cluster from the long cluster name
@param clnm string cluster name in form '[WELL PARENTS] | [BUDDING] | [RFP] | [FSC-A]', e.g. 'BY/RM | Unbudded | RFP low | FSCA in 0.00 100000000.00'
@param screen name of the screen (string)
@param plate plate in the screen for old bug fixing
@return None, if for this screen and this cluster, the parent does not make sense (e.g. PARENT is BY alone, and RFP is low). Otherwise, appropriate parent from WELL PARENTS """
def get_parent(clnm, screen, single_slice_cluster=False, plate=None):
    """ Get the correct parent for this screen and cluster from the long cluster name
    @param clnm string cluster name in form '[WELL PARENTS] | [BUDDING] | [RFP] | [FSC-A]', e.g. 'BY/RM | Unbudded | RFP low | FSCA in 0.00 100000000.00'
    @param screen name of the screen (string)
    @param single_slice_cluster for the 'targets'-like screens, whether a well with a single parent should be accepted
    @param plate plate in the screen; currently unused, kept so existing callers keep working
    @return None, if for this screen and this cluster the parent does not make sense (e.g. PARENT is BY alone, and RFP is low).
            Otherwise, appropriate parent from WELL PARENTS
    """
    n_low, n_high = clnm.count("RFP low"), clnm.count("RFP high")
    # is it high or low RFP? ought to be only "high" for BY and "low" for RM.
    # Explicit floor division: the result is used as a list index and must stay an integer.
    rfp = ["high", "low"][(1 + n_low - n_high) // 2]

    if (screen[0:9] == "haploid-R") and (clnm.count("combined") == 0):
        if (n_low + clnm.count("BY") > 1) or (n_high + clnm.count("RM") > 1):
            return None

    if screen in ["targets", "swaps", "haploid-R5_RMswaps", "all-h2", "tst", "Hap1"]:
        parents = clnm.split(" ")[0].split("/")
        if screen == "Hap1":
            return parents[0]
        elif (screen == "haploid-R5_RMswaps") and (parents[0] == "RM"):
            if rfp == "low": return "RM"
            # high RFP: no decision here, fall through to the generic matching at the end
        elif len(parents) == 1 and single_slice_cluster: # if single well, just match high gfp to BY and low to RM; other clusters are not to be processed
            return parents[0]
        elif len(parents) > 1:
            return parents[rfp == "low"] # second listed parent for low RFP, first one otherwise
        else:
            return None

    if screen in ["gfpswaps"]:
        p = clnm.split(" ")[0]
        if p[0:2] == "BY" and rfp == "high": return p
        elif p[0:2] == "RM" and rfp == "low": return p
        return None

    # identify the cluster from its name - is it one parent separately, or mixed
    mixing = ['BY', 'RM', "BYRM"][clnm.count("RM") + clnm.count("BY_RM") + clnm.count("BY/RM")]
    if ((rfp == "high") and (mixing in ["BY","BYRM"])) or ((rfp == "low") and (mixing in ["RM", "BYRM"])): # if RFP and mixing type match
        return ['BY','RM'][rfp == "low"] # get which parent - RM if RFP is low
    return None




#meta, slice, parentals, orfs, plates, header = read_cyto_tab("%s/summary/cyto/Pilot_screen_BYxRM/paper_haploid.tab"%DATA_DIR, normalise=True, return_header=True)
#meta[SP.where(meta[:,2] == "P4CD")[0], 2] = "Bad plate"
#for i in range(len(meta)):
#    if meta[i,2].count("AB") + meta[i,2].count("CD") > 0:
#        meta[i,2] = meta[i,2][0:-2]


def get_figure_1_data(min_cells=500, trait=3, screen="haploid-R1", align=False):
    """ Collect paired BY/RM trait values per plate for one screen
    NOTE: reads the module-level variables 'meta', 'slice' and 'orfs'; they must have been loaded
    (e.g. with read_cyto_tab) before calling this function.
    @param min_cells ORFs with fewer cells than this in either parent are skipped
    @param trait measurement column to use (3 = mean GFP; 5 = var(gfp))
    @param screen screen name (matched against metadata column 1)
    @param align if True, linearly transform the BY values (first column) to best match the RM values (second column)
    @return values (plate -> array of [BY, RM] value pairs, shifted so both columns have the same median),
            used_orfs (plate -> array of the ORF names behind each row of values)
    """
    values, used_orfs = {}, {}
    in_screen = (meta[:,1] == screen)

    for p in ["P%d"%i for i in range(1,12)]: # for each plate
        on_plate = in_screen & (meta[:,2] == p)
        I1 = SP.where(on_plate & (meta[:,5] == "BY"))[0] # Filter the screen to the values of each parent on this plate
        I2 = SP.where(on_plate & (meta[:,5] == "RM"))[0]
        if len(I1)*len(I2) == 0: continue

        vals, uorfs = [], []
        for o in orfs:
            i1 = SP.where(meta[I1,4] == o)[0]
            i2 = SP.where(meta[I2,4] == o)[0]
            if len(i1)*len(i2) != 1: continue # If not present in exactly one copy per parent, skip ORF.
            i1, i2 = I1[i1[0]], I2[i2[0]] # row numbers of the two measurements of this ORF

            if slice[i1,6] < min_cells or slice[i2,6] < min_cells: continue # Do not consider ORFs that don't have enough cells observed
            vals.append([slice[i1, trait], slice[i2, trait]]) # Trait 3 = mean GFP; trait 5 = var(gfp)
            uorfs.append(o)
        values[p], used_orfs[p] = SP.array(vals), SP.array(uorfs)
        values[p] = values[p] - SP.median(values[p], axis=0) + values[p].mean()
        if align and len(vals) > 0:
            # align_vals takes the two vectors separately (it used to be called with the whole matrix, a TypeError)
            aligned = align_vals(values[p][:,0], values[p][:,1])
            if len(aligned) == len(values[p]): # an empty result means no fit was found; leave the values as they are
                values[p][:,0] = aligned
    return values, used_orfs



def get_figure_S2_data(min_cells=500, trait=3, parent="RM"):
    """ Collect paired replicate trait values per plate for one parent
    NOTE: reads the module-level variables 'meta', 'slice' and 'orfs'; they must have been loaded
    (e.g. with read_cyto_tab) before calling this function.
    @param min_cells ORFs with fewer cells than this in either replicate are skipped
    @param trait measurement column to use (3 = mean GFP; 5 = var(gfp))
    @param parent parent name (matched against metadata column 5)
    @return values (plate -> array of [haploid-R1 or haploid-R6, haploid-R3] value pairs),
            used_orfs (plate -> array of the ORF names behind each row of values)
    """
    values, used_orfs = {}, {}
    is_parent = (meta[:,5] == parent)
    in_first = ((meta[:,1] == "haploid-R1") | (meta[:,1] == "haploid-R6")) & is_parent
    in_second = (meta[:,1] == "haploid-R3") & is_parent

    for p in ["P%d"%i for i in range(1,12)]: # for each plate
        on_plate = (meta[:,2] == p)
        I1 = SP.where(in_first & on_plate)[0] # Filter the two replicate screens to values from this plate
        I2 = SP.where(in_second & on_plate)[0]
        if len(I1)*len(I2) == 0:
            LOG.debug("Skipping plate %s. R1/R6 %d total; R3 %d total"%(p, len(I1), len(I2)))
            continue

        vals, uorfs = [], []
        for o in orfs:
            i1 = SP.where(meta[I1,4] == o)[0]
            i2 = SP.where(meta[I2,4] == o)[0]
            if len(i1)*len(i2) != 1: continue # If not present in exactly one copy per replicate, skip ORF.
            i1, i2 = I1[i1[0]], I2[i2[0]] # row numbers of the replicate measurements of this ORF

            if slice[i1,6] < min_cells or slice[i2,6] < min_cells: continue # Do not consider ORFs that don't have enough cells observed
            vals.append([slice[i1, trait], slice[i2, trait]]) # Trait 3 = mean GFP; trait 5 = var(gfp)
            uorfs.append(o)
        values[p], used_orfs[p] = SP.array(vals), SP.array(uorfs)

    return values, used_orfs


def get_cyto_well_orf(screen):
    """ Map wells to ORFs for one screen, from the clean haploid summary table
    @param screen screen name (matched against metadata column 1)
    @return dict of well (metadata column 3) -> ORF (metadata column 4); if a well occurs several times, the last row wins
    """
    # read_cyto_tab returns (meta, slice, parentals, orfs, plates) by default; only the metadata is needed.
    # (Unpacking it into four names, as was done before, raises ValueError.)
    meta = read_cyto_tab("%s/summary/cyto/Pilot_screen_BYxRM/clean_haploid.tab"%DATA_DIR)[0]
    result = {}
    for row in meta:
        if row[1] == screen:
            result[row[3]] = row[4]
    return result



def get_gfpswap_annot():
    """ Read the GFP swap annotation table into an 8x12 layout of ORF names
    @return 8x12 object array; entry [(row-1)//2, (column-1)//2] holds the value of annotation column 4 for
            the well given by annotation columns 0 (row) and 1 (column). Positions without annotation stay 0.
    """
    ann = SP.loadtxt("%s/cytometry/Pilot_screen_BYxRM/annotation/gfpswaps/GFPswap_haploids_R2.tab"%DATA_DIR, delimiter="\t", dtype=object)
    orf = SP.zeros([8,12], object)
    for entry in ann:
        r,c,o = entry[[0,1,4]]
        # explicit floor division - the results are array indices and must stay integers
        orf[(int(r) - 1)//2, (int(c) - 1)//2] = o
    return orf
        

def get_gfpswap_data(debug=True, reps=(2,3), recalc=False, cutoff=400):
    """ Get the median signal per well for the GFP swap strains, from cache or from the raw .fcs files
    @param debug if True, log one line per processed file
    @param reps replicate numbers; each is matched against directories ending in that number, and stored at index rep - min(reps)
    @param recalc if True, recompute from the .fcs files even if a cached pickle exists
    @param cutoff minimum number of gated cells for a well's median to be stored
    @return data (strain -> [replicate, 8, 12] array of median values of channel 0 (log10), NaN where not measured or too few cells),
            counts (strain -> [replicate, 8, 12] array of gated cell counts), orfs (8x12 ORF layout)
    NOTE: the cached pickle is returned as is; use recalc=True to refresh results computed with older versions of this function.
    """
    orfs = get_gfpswap_annot()
    outfilename = "%s/cytometry/Pilot_screen_BYxRM/gfpswaps/R%s.pickle"%(DATA_DIR, "".join(map(str, reps)))
    if os.path.exists(outfilename) and (not recalc):
        data,counts = cl(outfilename)
        return data, counts, orfs

    parents = [["BY-BYgfp", "RM-BYgfp"],["BY-RMgfp","RM-RMgfp"]]
    data, counts = {}, {}
    for pair in parents:
        for name in pair:
            data[name] = SP.zeros([len(reps),8,12])*SP.nan # one layer per replicate, 96-well
            counts[name] = SP.zeros([len(reps),8,12])*SP.nan

    first_rep = min(reps)
    for rep in reps:
        for f in sorted(glob.glob("%s/cytometry/Pilot_screen_BYxRM/gfpswaps/*%d/*.fcs"%(DATA_DIR,rep))):
            fname = f.split("/")[-1]
            well = fname.split("_")[2]
            r,c = ord(well[0]) - ord('A'), int(well[1:]) - 1
            fcs = SP.log10(read_fcs(f)[1])
            # well without an annotated ORF - report it (single-argument print behaves the same in Python 2 and 3)
            if len(str(orfs[r//2,c//2])) < 2: print("%s %s %s %s" % (fname, r, c, len(fcs)))
            I = SP.where((abs(fcs[:,3] - 4.83) < 0.03) & (abs(fcs[:,2] - 4.45) < 0.25) & ((abs(fcs[:,1] - 3.2) < 0.8)))[0] # gate on channels 3, 2 and 1
            pi,pj = r%2, c%2
            if rep == 2: pi = 1-pi # Replicate 2 plate was swapped
            strain = parents[pi][pj]
            slot = (rep - first_rep, r//2, c//2)
            counts[strain][slot] = len(I)
            # Store the median only if enough cells passed the gate. The comparison used to be inverted,
            # keeping wells *below* the cutoff and discarding the well-populated ones.
            if len(I) >= cutoff:
                data[strain][slot] = SP.median(fcs[I,0])
            else:
                data[strain][slot] = SP.nan
            if debug: LOG.debug("%s\t%s\t%d\t%d\t%s\t%d\t%d\t%.2f"%(f, well, r, c, strain, len(fcs), len(I), SP.median(fcs[I,0])))

    cdm((data,counts), outfilename)
    return data, counts, orfs


def get_cyto_gfpswaps(screen="haploid-R3"):
    """ Lay out values from the haploid summary table on the GFP swap plate layout
    @param screen only table rows whose column 1 equals this screen are used
    @return dict with keys "BY-BYgfp" and "RM-BYgfp", each an 8x12 array holding table column 8 (as float)
            at the position(s) of the row's ORF (column 4) in the swap layout; NaN elsewhere
    """
    orfs = get_gfpswap_annot()
    all_cis_orfs = SP.unique(orfs.ravel().tolist()) # every distinct entry of the layout
    cyto = SP.loadtxt("%s/summary/cyto/Pilot_screen_BYxRM/paper_haploid.tab"%DATA_DIR, delimiter="\t", dtype=object)
    cytoval = {}
    for key in ("BY-BYgfp", "RM-BYgfp"):
        cytoval[key] = SP.zeros([8,12])*SP.nan
    for row in cyto:
        if row[1] != screen or row[4] not in all_cis_orfs:
            continue
        where_orf = SP.where(orfs == row[4])
        cytoval["%s-BYgfp"%(row[5])][where_orf] = float(row[8]) # table column 5 selects the strain
    return cytoval


def get_cis_mrna_signals():
    """ Lay out values from the thinned cis linkage table on the GFP swap plate layout
    @return 8x12 array holding table column 3 (as float) at the position(s) of the row's ORF (column 0)
            in the swap layout; NaN where the layout entry does not occur in the table
    """
    orfs = get_gfpswap_annot()
    all_cis_orfs = SP.unique(orfs.ravel().tolist()) # every distinct entry of the layout
    res = SP.zeros([8,12])*SP.nan
    d = SP.loadtxt("%s/mrna/smith_linkage_results_thinned_cis.tab"%DATA_DIR, delimiter="\t", dtype=object)
    for i in range(len(d)):
        if d[i,0] in all_cis_orfs:
            o = SP.where(orfs == d[i,0])
            #res[o] = float(d[i,7]) - float(d[i,6])
            res[o] = float(d[i,3])
    return res




def get_reproducibility_flags(trait=1, cutoff=0.2, plates=None, meta=None, parentals=None, slice=None, debug=False):
    """ Call, per plate, parent and ORF, whether the trait value agrees between the two copies of the plate
    @param trait measurement column to compare
    @param cutoff maximum absolute deviation from the fitted line for a value pair to be called good
    @param plates list of plate names (required); plates with "fillin" in their name are skipped
    @param meta metadata array, one row per measurement (required)
    @param parentals list of parent names (required)
    @param slice measurement array, one row per measurement (required)
    @param debug if True, print the number of ORFs per call for every plate and parent
    @return plate -> parent -> ORF -> GOOD, BAD, or MISSING (ORF not compared, e.g. measured in one copy only)
    """
    plate_orf_parent_ok = {}

    for plate in plates:
        if plate.count("fillin") > 0: continue # no need to align these guys
        plate_orf_parent_ok[plate] = {}
        # first copy always R1 (acquisition R3), second the other
        on_plate = (meta[:,2] == plate)
        I1 = (meta[:,1] == "haploid-R1") & on_plate
        I2 = (meta[:,1] != "haploid-R1") & on_plate
        for parent in parentals:
            calls = plate_orf_parent_ok[plate][parent] = {}
            Ip = (meta[:,5] == parent)
            t = [{},{}] # map of ORF to trait value for both plates
            for i,I in enumerate([I1, I2]): # for both plates
                for j in SP.where(Ip & I)[0]: # and each entry
                    t[i][meta[j,4]] = slice[j,trait] # store the value for the ORF in entry
                    calls[meta[j,4]] = MISSING # init goodness calls for ORF

            shared_orfs = list(set(t[0]) & set(t[1])) # ORFs measured in both copies
            if len(shared_orfs) >= 2: # a line cannot be fitted through fewer points; such ORFs stay MISSING
                v1 = SP.array([t[0][o] for o in shared_orfs])
                v2 = SP.array([t[1][o] for o in shared_orfs])
                # fit first copy to second, and get the good (well aligned) and bad sets
                good,bad,error,slope,intercept = is_good(v1, v2, cutoff)
                for g in good: calls[shared_orfs[g]] = GOOD
                for b in bad: calls[shared_orfs[b]] = BAD
            if debug:
                v = SP.array(list(calls.values()))
                # single-argument print behaves the same in Python 2 and 3
                print("%s %s %s" % (trait, plate, parent))
                for v0 in SP.unique(v):
                    print("\t%s %s" % (v0, sum(v == v0)))

    return plate_orf_parent_ok



""" Assume standard layout of input file, log scale all features
@return replicate->parent->plate->[16,24,14] array (data), replicate->parent->plate->{(row,col): first 8 columns of the input line} (meta)
"""
def get_plate_screen_parent_vals(infilename="%s/summary/paper/table_S2-singles.tab"%DATA_DIR):
    """ Arrange the per-measurement table by replicate, parent and plate. Assumes standard layout of input file:
    one header line, then per line 8 metadata columns (column 1 screen, 2 plate, 3 well, 4 ORF, 5 parent, 6 ploidy, 7 quality)
    followed by 14 numeric columns.
    @param infilename path to the tab-delimited input table
    @return replicate->parent->plate->[16,24,14] array of the numeric columns (NaN where nothing was measured),
            replicate->parent->plate->{(row,col): the 8 metadata columns}.
            Replicate is "Rep2" for screen "haploid-R1" and "Rep1" for all others; plate is the plate number (int).
            Every ORF is placed at its "haploid-R1" plate and well, whichever screen the line comes from.
    """
    data = SP.loadtxt(infilename, delimiter="\t", dtype=object)
    # 0. Keep track of where each ORF is (plate, and well)
    orf_platewell = {}
    for row in data[1:]:
        screen, plate, well, orf, parent = row[1:6]
        if screen == "haploid-R1": orf_platewell[orf] = (int(plate[1:]), ord(well[0]) - ord('A'), int(well[1:])-1)  # "R1" assumes "R3" has been renamed to "R1"
    orf_platewell["YMR288W"], orf_platewell["YGL201C"] = (11,1,5), (11,1,7)

    # 1. Create map of replicate->parent->plate->data.
    res_data, res_meta = {}, {}
    for row in data[1:]:
        screen, plate, well, orf, parent, ploidy, quality = row[1:8] # parse metadata
        if (screen == "fillin"): continue # for combined file, do not use fillin values
        rep = "Rep2" if screen == "haploid-R1" else "Rep1" # again, assumes "R3" has been renamed to "R1"
        if (orf in ["YMR288W", "YGL201C"]) and (screen == "haploid-R1") and (plate != "P11"): continue # inter-machine compatibility for special case
        p,r,c = orf_platewell[orf] # get location of the ORF in the array
        plate_data = res_data.setdefault(rep, {}).setdefault(parent, {}) # initialise vars if needed
        plate_meta = res_meta.setdefault(rep, {}).setdefault(parent, {})
        if p not in plate_data: plate_data[p], plate_meta[p] = SP.zeros([16,24,14], float)*SP.nan, {}
        plate_data[p][r,c,:] = [float(x) for x in row[8:]] # store data (a list, so the assignment does not depend on map() returning one)
        plate_meta[p][(r,c)] = row[0:8] # and metadata
    return res_data, res_meta


	
"""
@param trait 0 for linear mean, 2 for linear median
@return array with one [column 9+trait, column 35+trait] pair of floats per data line of the input file
"""
def get_combined_cyto_vals(infilename="%s/paper/table_S2-combined.tab"%DATA_DIR, trait=0):
    """ Read one pair of value columns from the combined cytometry table
    @param infilename path to the tab-delimited combined table (first line is skipped as header)
    @param trait offset added to both column numbers; 0 for linear mean, 2 for linear median
    @return array with one row per data line: [float of column 9 + trait, float of column 35 + trait]
    """
    LOG.debug("Getting cytometry values: %s"%infilename)
    table = SP.loadtxt(infilename, delimiter="\t", dtype=object)
    first_col, second_col = 9 + trait, 35 + trait
    pairs = [[float(row[first_col]), float(row[second_col])] for row in table[1:]]
    return SP.array(pairs)


""" Trait=4 is linear median
@param traits
@param cutoffs
@param names
@param min_cells
"""
def flag_replicates(infilename="%s/summary/paper/table_S2-singles.tab"%DATA_DIR, traits=[4,5], cutoffs=[0.2,1], names=["mean","var"], min_cells=200, debug=False, outfilename="%s/summary/cyto/Pilot_screen_BYxRM/paper_haploid_annotated.tab"%DATA_DIR):
    """ Annotate every measurement with a reproducibility flag and write the annotated table
    @param infilename path to the input table (read with read_cyto_tab, without normalisation)
    @param traits measurement columns to check for reproducibility (trait=4 is linear median)
    @param cutoffs maximum deviation allowed for each of the traits
    @param names label for each of the traits, used as suffix of the flag ("Bad-<name>")
    @param min_cells measurements with fewer cells (measurement column 6) start out flagged as "Bad-count"
    @param debug passed on to get_reproducibility_flags
    @param outfilename path of the output table: the 7 metadata columns, the flag, then the slice and 'all' measurements
    """
    meta, slice_vals, parentals, orfs, plates, header, all_vals = read_cyto_tab(infilename, includes_linear=True, normalise=False, return_header=True, log_linear=True, return_all=True)
    flags = [get_reproducibility_flags(trait, cutoff, plates, meta, parentals, slice_vals, debug) for trait, cutoff in zip(traits, cutoffs)]
    bad_prefix = BAD[0:3]

    with open(outfilename, "w") as ofh: # closed even if a row fails to format
        ofh.write("\t".join(header[0:7]) + "\tFlag\t" + "\t".join(header[7:]) + "\n")
        for i in range(len(meta)):
            flag = GOOD if slice_vals[i,6] >= min_cells else BAD + "-count"
            orf, plate = meta[i,4], meta[i,2]

            # present in both plates and parents with ok?
            for f, status in enumerate(flags):
                for p in parentals:
                    if plate not in status or orf not in status[plate][p]:
                        flag = MISSING
                    elif status[plate][p][orf][0:3] == bad_prefix:
                        if flag[0:3] == bad_prefix:
                            if flag.count(names[f]) == 0: flag = flag + "-" + names[f] # if already bad, add this cause
                        else: flag = BAD + "-" + names[f] # else just say bad
                    elif status[plate][p][orf] == MISSING:
                        # keep an existing "Bad-..." flag. The test used to be 'flag != BAD', which is always true
                        # because bad flags always carry a "-cause" suffix, so they were overwritten.
                        if flag[0:3] != bad_prefix: flag = MISSING
            ofh.write("\t".join(meta[i]) + "\t%s\t"%flag + "\t".join(["%.4f"%x for x in slice_vals[i][0:-1]]) + "\t%d\t"%(slice_vals[i,-1]))
            ofh.write("\t".join(["%.4f"%(x) for x in all_vals[i][0:-1]]) + "\t%d\n"%(all_vals[i,-1]))
 


def get_combined_cytometry_plate_values(crosses, reps, plate, data):
    """ Align the values of one plate, for every cross and replicate, to the Rep1 BY copy of the plate
    @param crosses list of cross names "<query>x<other>"; the part before "x" selects the parent in 'data'
    @param reps list of replicate names (keys of 'data')
    @param plate plate key
    @param data replicate->parent->plate->[16,24,14] array, as returned by get_plate_screen_parent_vals
    @return array indexed [cross, replicate, row, column, trait]. Traits 0-5 and 7-12 are aligned with align_vals;
            of those, the three 'linear' ones of each set (3-5 and 10-12) are returned as 10**value.
            The counts (traits 6 and 13) are copied unchanged.
    """
    import sys # local import: only needed for the progress output below

    vals = [[] for cross in crosses] # one list per cross (used to be hard-coded to two)
    for p, cross in enumerate(crosses):
        query = cross.split("x")[0]
        for rep in reps: # align to Rep1-BY
            measured = data[rep][query][plate]
            reference = data["Rep1"]["BY"][plate]
            v = SP.zeros([384,14])
            # progress: plate, parent, replicate and number of wells without a value, all on one line
            sys.stdout.write("%s %s %s %s " % (plate, query, rep, sum(SP.isnan(measured[:,:,0].reshape(384)))))
            for sl in range(2): # 'slice' measurements first, then 'all'
                for s,scale in enumerate(['log','linear']):
                    for t in range(3):
                        trait = 7*sl + 3*s + t
                        v[:,trait] = align_vals(measured[:,:,trait].reshape(384), reference[:,:,trait].reshape(384))
                        if scale == "linear": v[:,trait] = 10**v[:,trait]
            v[:,6] = measured[:,:,6].reshape(384) # counts for slice are separate - not aligned
            v[:,13] = measured[:,:,13].reshape(384) # counts for all are separate - not aligned
            vals[p].append(v.reshape([16,24,14]))
    return SP.array(vals) # Now have the values aligned to the Rep1 BY plate


""" Combine replicates from screens, align them, output all summaries. """
def create_combined_cytometry_output_table(infilename="%s/summary/paper/table_S2-singles.tab"%DATA_DIR, outfilename="%s/summary/cyto/table_S2-combined.tab"%DATA_DIR):
    """ Combine replicates from screens, align them, output all summaries.
    @param infilename path to the per-measurement table (read with get_plate_screen_parent_vals)
    @param outfilename path of the tab-delimited output table. Per well: ORF, plate, well; then for each cross and
           each measurement set the mean and standard deviation across replicates of 6 traits, plus the two replicate
           cell counts (-1 where missing); then mean and standard deviation of the difference between the two crosses.
           Only wells present in the Rep1 BY metadata with quality "OK" are written.
    """
    LOG.debug("Creating combined cytometry output from %s to %s"%(infilename, outfilename))
    data, meta = get_plate_screen_parent_vals(infilename)
    crosses = ["BYxBY","RMxBY"]
    reps = ["Rep1", "Rep2"]
    slices = ["10pct", "all"]
    sets  = ["log10", "linear"]
    traits = ["mean", "median", "var"]
    summaries = ["average", "MSE"]

    # 0. Build header - already a lot of columns :)
    header = ["ORF", "Plate", "Well"]
    for cross in crosses:
        for slice_name in slices:
            for setname in sets:
                for trait in traits:
                    for val in summaries:
                        header.append("%s(%s_%s_%s_%s)"%(val, cross, slice_name, setname, trait))
            # one count column per cross and slice; it carries the name of the last set, as in previous outputs
            header.append("%s_%s_%s_number_of_cells"%(cross, slice_name, sets[-1]))
    for slice_name in slices:
        for setname in sets:
            for trait in traits:
                for val in summaries:
                    header.append("%s(Difference_in_%s_%s_%s)"%(val, setname, slice_name, trait))

    total = 0
    total_processed = 0
    with open(outfilename, 'w') as ofh: # closed even if a plate fails
        ofh.write("\t".join(header) + "\n")
        for plate in range(1,12):
            vals = get_combined_cytometry_plate_values(crosses, reps, plate, data)
            LOG.debug("Processing plate %s, data shape=%s"%(plate, str(vals.shape)))
            total += 16*24 # wells on this plate (the total used to add up array elements, not wells)
            n_skipped = 0
            n_processed = 0
            plate_meta = meta["Rep1"]["BY"][plate]
            for r in range(16):
                for c in range(24):
                    if (r,c) not in plate_meta:
                        n_skipped += 1
                        continue
                    m = plate_meta[(r,c)]
                    if m[1] == "fillin": continue
                    if m[-1] != "OK":
                        n_skipped += 1
                        continue
                    n_processed += 1
                    row, col = chr(ord('A') + r), c + 1
                    fields = ["%s\t%s\t%s%d"%(m[4],m[2], row, col)] # ORF, plate, well, m(BY) sd(BY), m(RM) sd(RM) etc
                    for p in range(2):
                        for start in [0,7]: # the 'slice' measurements start at 0, the 'all' ones at 7
                            for i in range(6): # log and linear in a row!
                                across_reps = vals[p,:,r,c,start+i]
                                fields.append("%.4f\t%.4f"%(across_reps.mean(), across_reps.var()**0.5))
                            n1, n2 = vals[p,0,r,c,start+6], vals[p,1,r,c,start+6]
                            if SP.isnan(n1): n1 = -1 # NaN cannot be formatted as integer
                            if SP.isnan(n2): n2 = -1
                            fields.append("%d,%d"%(n1, n2))
                    for start in [0,7]:
                        for i in range(6):
                            fields.append("%.4f"%(vals[1,:,r,c,start+i].mean() - vals[0,:,r,c,start+i].mean()))
                            fields.append("%.4f"%((vals[1,:,r,c,start+i] - vals[0,:,r,c,start+i]).var()**0.5))
                    ofh.write("\t".join(fields) + "\n")
            total_processed += n_processed
            LOG.debug("Skipped %d wells, processed %d"%(n_skipped, n_processed))
    LOG.debug("Total %d wells, processed %d"%(total, total_processed))


def get_confirmation_signal(debug=False, recalc=False):
    """ Get the median signal of the confirmation experiment, from cache or from the raw .fcs files
    @param debug if True (and recalculating), print total and gated number of events and the name of every file
    @param recalc if True, recompute from the .fcs files and overwrite the cached pickle; otherwise just load the pickle
    @return gene -> allele -> [2,8,4] array of median values of channel 0 (log10) of the gated events,
            indexed by clone, technical replicate and concentration as parsed from the file name
    """
    pickle_path = "%s/cytometry/Pilot_screen_BYxRM/conf/RH2/toplot.pickle"%DATA_DIR
    if not recalc:
        return cl(pickle_path)

    toplot = {}
    for f in sorted(glob.glob("%s/cytometry/Pilot_screen_BYxRM/conf/RH2/*.fcs"%DATA_DIR)):
        fname = f.split("/")[-1]
        gene, allele, clone, techrep, conc = fname[0:-4].split("_") # file name without the ".fcs" extension
        if gene not in toplot:
            toplot[gene] = {}
        if allele not in toplot[gene]:
            toplot[gene][allele] = SP.zeros([2,8,4])
        d = SP.log10(read_fcs(f)[1])
        in_gate = (abs(d[:,3] - 4.84) < 0.04) & (abs(d[:,2] - 4.65) < 0.25) & ((abs(d[:,1] - 2.3) < 0.5))
        I = SP.where(in_gate)[0]
        if debug:
            print("%s %s %s" % (len(d), len(I), fname))
        where = (int(clone.split("-")[1]) - 1, int(techrep[1:]) - 1, int(conc[5:]))
        toplot[gene][allele][where] = SP.median(d[I,0])
    cdm(toplot, pickle_path)
    return toplot
    
 
def main():
    """ Script entry point; intentionally does nothing.
    The loop that used to follow the early return could never run, and called a function
    (figure_S2a) that is not defined in this module, so it has been removed.
    """
    return


if __name__ == '__main__':
    main()
