import scipy as SP
import pylab as PL
import os
import glob
import pdb
from parts2014_gfpvar.tools.common import *
from parts2014_gfpvar.tools.io import *
from parts2014_gfpvar.analysis.cyto.util import *
from parts2014_gfpvar.tools.util import get_glucose_smith_target_orfs

# Number of wells in one 1536-well Tecan read-out: 32 rows x 48 columns.
n_col = 48 * 32
# Number of wells in one 384-well plate: 16 rows x 24 columns.
n_wells = 24 * 16
# Parent strain names measured in each screen, keyed by screen name.
ALL_PARENTS = {
    "BYs": ["BY", "BY-HAP1", "BY-MKT1-IRA2", "BY-HAP1-MKT1-IRA2"],
    "RMs": ["RM", "RM-IRA2", "RM-MKT1", "RM-MKT1-IRA2"],
    "BYRM": ["BY", "BY-HAP1-MKT1-IRA2", "RM-MKT1-IRA2", "RM"],
}

def get_tecans(parents=ALL_PARENTS, reps=(1,2), rep_name="R"):
    """Read every screen in ``parents`` once per replicate in ``reps``.

    Returns a list with one dict per replicate, in the order of ``reps``.
    Each dict maps screen name -> ``read_tecan_screen(screen, rep,
    parents[screen], rep_name=rep_name)``.
    """
    return [
        {screen: read_tecan_screen(screen, rep, parents[screen], rep_name=rep_name)
         for screen in parents}
        for rep in reps
    ]


def combine_tecans(vals, thresh=0.25*SP.log2(10), bad_cutoff=0.5, calc_sds=True):
    """Merge replicate Tecan screens into one median value per well.

    Parameters
    ----------
    vals : list of dict
        One entry per replicate, each shaped
        ``{screen: {plate: {parent: 16x24 array}}}`` (as built by ``get_tecans``).
    thresh : float
        ``good_threshold`` passed to ``align_vals`` when a later replicate is
        aligned to the first observation of the same plate.
    bad_cutoff : float
        Wells whose first two replicate values differ by more than this are set
        to NaN in the combined output. With a single replicate nothing is
        discarded.
    calc_sds : bool
        If true, also return the per-well standard deviation across replicates.

    Returns
    -------
    ``res`` or ``(res, sds)``, both shaped ``{screen: {plate: {parent: 16x24 array}}}``.
    """
    n_reps = len(vals)
    res = {}
    sds = {}

    for v, val in enumerate(vals):
        for screen in val:
            if screen not in res:
                res[screen], sds[screen] = {}, {}
            for plate in val[screen]:
                if plate not in res[screen]:
                    # First time this plate is seen: store it unaligned in slot 0.
                    # It becomes the reference that later replicates are aligned to.
                    res[screen][plate], sds[screen][plate] = {}, {}
                    for p in val[screen][plate]:
                        res[screen][plate][p] = SP.zeros([n_reps, 16, 24])*SP.nan
                        res[screen][plate][p][0] = val[screen][plate][p]
                else:
                    for p in val[screen][plate]:
                        ref = res[screen][plate][p][0].reshape(384)
                        aligned = align_vals(val[screen][plate][p].reshape(384), ref, good_threshold=thresh)
                        res[screen][plate][p][v] = aligned.reshape([16, 24])

    # Every plate now holds its aligned replicate values; summarise them per well.
    for screen in res:
        for plate in res[screen]:
            for p in res[screen][plate]:
                reps = res[screen][plate][p]
                sds[screen][plate][p] = reps.var(axis=0)**0.5
                combined = SP.median(reps, axis=0)
                # Discard wells where the first two replicates disagree. This needs at
                # least two replicates; indexing reps[1] otherwise raised IndexError.
                if n_reps > 1:
                    combined[abs(reps[0] - reps[1]) > bad_cutoff] = SP.nan
                res[screen][plate][p] = combined

    if calc_sds:
        return res, sds
    return res



def flag_tecan_plates(plate, refs, cutoff=0.4, good_threshold=0.25*SP.log2(10), debug=False):
    """Flag wells of ``plate`` that deviate in the same direction from several references.

    Each reference is aligned onto ``plate`` with ``align_vals``. A well gets a
    "plus" vote from a reference when it exceeds the aligned reference by more
    than ``cutoff``. It gets a "minus" vote when it falls below the aligned
    reference by more than ``cutoff``. A well is flagged when it collects at
    least ``max(1, min(3, n - 1))`` votes of one kind, with ``n = len(refs)``.

    Parameters
    ----------
    plate : array with 384 elements (reshaped to 384; 16x24 layout assumed)
    refs : sequence of arrays with 384 elements
    cutoff : float
    good_threshold : float
        Passed through to ``align_vals``.
    debug : bool
        Unused; kept for call compatibility.

    Returns
    -------
    16x24 boolean array, True where the well is flagged.
    """
    v0 = plate.reshape(384)
    n = len(refs)
    flags_plus = SP.zeros([n, 16, 24], int)
    flags_minus = SP.zeros([n, 16, 24], int)

    for r, ref in enumerate(refs):
        v1 = align_vals(ref.reshape(384), v0, good_threshold=good_threshold)
        flags_plus[r] = (v0 - v1).reshape(16, 24) > cutoff
        flags_minus[r] = (v1 - v0).reshape(16, 24) > cutoff

    # Require agreement from up to three references. Never require fewer than one
    # vote: with n <= 1 the old threshold was <= 0, and "votes >= 0" flagged every well.
    num_cutoff = max(1, min(3, n - 1))
    return (flags_plus.sum(axis=0) >= num_cutoff) | (flags_minus.sum(axis=0) >= num_cutoff)


def clean_tecan_plates(vals, cutoff=0.4):
    """Set outlier wells of the "BYs" screen plates to NaN, in place.

    Each of the four BY-background parents in the "BYs" screen is compared plate
    by plate against five reference plates: the other three "BYs" parents plus
    "BY" and "BY-HAP1-MKT1-IRA2" from the "BYRM" screen. The comparison uses
    ``flag_tecan_plates``, and flagged wells are overwritten with NaN. All flags
    are computed from the unmodified data before anything is overwritten, so the
    outcome does not depend on the order in which parents are visited.

    Parameters
    ----------
    vals : dict
        ``{screen: {plate: {parent: 16x24 array}}}``; modified in place.
    cutoff : float
        Passed to ``flag_tecan_plates``.

    Returns None. Logs the total NaN count per (screen, parent) at debug level;
    pre-existing NaNs are included in the count.
    """
    b0, b1, b2, b3 = "BY", "BY-HAP1", "BY-MKT1-IRA2", "BY-HAP1-MKT1-IRA2"
    sb, sc = "BYs", "BYRM"
    refs = {(sb, b0): [(sb, b1), (sb, b2), (sb, b3), (sc, b0), (sc, b3)],
            (sb, b1): [(sb, b0), (sb, b2), (sb, b3), (sc, b0), (sc, b3)],
            (sb, b2): [(sb, b1), (sb, b0), (sb, b3), (sc, b0), (sc, b3)],
            (sb, b3): [(sb, b1), (sb, b2), (sb, b0), (sc, b0), (sc, b3)]}

    # Pass 1: flag against untouched data. Overwriting while flagging would let the
    # NaNs written for one parent change the reference values seen by the next one.
    flagged = []
    for screen, parent in refs:
        for plate in vals[screen]:
            r = [vals[s][plate][p] for (s, p) in refs[(screen, parent)]]
            mask = flag_tecan_plates(vals[screen][plate][parent], r, cutoff, debug=False)
            flagged.append((screen, plate, parent, mask))

    # Pass 2: apply the flags and tally NaNs per (screen, parent).
    nans = {}
    for screen, plate, parent, mask in flagged:
        vals[screen][plate][parent][mask] = SP.nan
        nans[(screen, parent)] = nans.get((screen, parent), 0) + SP.isnan(vals[screen][plate][parent]).sum()
    for s, p in nans:
        # NOTE(review): "tecan/util.py" in this message looks like a mangled "parent" label — confirm.
        LOG.debug("Screen %s, tecan/util.py %s - %d NaNs"%(s, p, nans[(s,p)]))


def output_tecans(outfilename, vals, sds, parents):
    """Write per-well means and SDs of the "BYs" screen to a tab-separated file.

    ``vals["BYs"]`` and ``sds["BYs"]`` are indexed by parent position (matching
    ``parents["BYs"]``). Each entry is a flat per-well sequence laid out as
    consecutive 384-well plates (16 rows x 24 columns, row-major), starting at
    plate 1. This is the layout ``get_val_lists`` builds; callers presumably pass
    its output, so confirm at the call site. ORF and gene names come from
    ``get_array_data(skip_nan=False)``, offset by one in row and column.
    """
    orfnames = get_array_data(skip_nan=False)[:, :, :, 0:2]
    bys_parents = parents["BYs"]
    with open(outfilename, 'w') as ofh:  # closed even if a write fails
        ofh.write("# ORF\tGene\tPlate\tWell")
        for p in bys_parents:
            ofh.write("\tmean(%s)\tsd(%s)"%(p, p))
        ofh.write("\n")
        for i in range(len(vals["BYs"][0])):
            # Integer arithmetic maps the flat index to (plate, row, column); plates are 1-based.
            plate = i//384 + 1
            row, col = divmod(i % 384, 24)
            ofh.write("%s\t%s"%tuple(orfnames[plate, row+1, col+1]))
            ofh.write("\t%d\t%s%d"%(plate, chr(ord('A') + row), col+1))
            for p in range(len(bys_parents)):
                ofh.write("\t%.2f\t%.2f"%(vals["BYs"][p][i], sds["BYs"][p][i]))
            ofh.write("\n")


def get_tecanlike_cyto():
    """Arrange normalised "haploid-R3" cytometry values into 16x24 plate grids.

    Reads the annotated pilot-screen table under ``DATA_DIR``. Only rows whose
    meta column 1 is "haploid-R3" and meta column 7 is "OK" are kept. The
    plate number is parsed from meta column 2 (first character dropped), the
    parent from column 5 and the well name (e.g. "C7") from column 3.

    Returns ``{plate_number: {parent: 16x24 array}}`` holding the first value
    column; wells without a kept row stay NaN.
    """
    meta, cyto_slice, _parentals, _orfs, _plates, _header = read_cyto_tab("%s/summary/cyto/Pilot_screen_BYxRM/paper_haploid_annotated.tab"%DATA_DIR, normalise=True, return_header=True, includes_all=False, includes_flag=True)
    cyto_vals = {}
    for i in range(len(cyto_slice)):
        if not (meta[i,1] == "haploid-R3" and meta[i,7] == "OK"):
            continue
        plate, parent, well = int(meta[i,2][1:]), meta[i,5], meta[i,3]
        grid = cyto_vals.setdefault(plate, {})
        if parent not in grid:
            grid[parent] = SP.zeros([16,24])*SP.nan
        row = ord(well[0]) - ord('A')
        col = int(well[1:]) - 1
        grid[parent][row, col] = cyto_slice[i,0]
    return cyto_vals


def read_tecan_screen(screen, replicate, parents, rep_name="R"):
    """Read one replicate of a Tecan screen and merge repeated reads of each plate.

    For plates 1-11, every file matching
    ``DATA_DIR/tecan/<screen>/<rep_name><replicate>/P<plate>-?.txt`` is parsed
    with ``read_tecan``. For each parent, all later reads are aligned to the
    first read with ``align_vals``, and the per-well median is kept.

    Returns ``{plate: {parent: 16x24 array}}``. A plate with no matching files
    maps to an empty dict.
    """
    result = {}

    for plate in range(1, 12):
        result[plate] = {}
        # Sorted so the reference (first) read is deterministic; glob order is arbitrary.
        filenames = sorted(glob.glob("%s/tecan/%s/%s%d/P%d-?.txt"%(DATA_DIR, screen, rep_name, replicate, plate)))
        d = [read_tecan(f, parents) for f in filenames]
        if len(d) == 0:
            continue
        for p in parents:
            original_vals = [reading[p].reshape(384) for reading in d]
            vals = SP.zeros([len(original_vals), 384])
            vals[0] = original_vals[0]
            for i in range(1, len(original_vals)):
                # Align read i itself (the old code aligned read 1 every time) to the first read.
                vals[i] = align_vals(original_vals[i], vals[0], 0.25*SP.log2(10))
            result[plate][p] = SP.median(vals, axis=0).reshape([16, 24])
    return result


def read_tecan(filename, parents=("A","B","C","D")):
    """Parse one Tecan plate-reader export into four 16x24 log2 grids.

    The first 31 lines are skipped. The next 32*48 = 1536 lines are parsed as
    ``<well> <value>`` pairs. A well name is a row letter ("A".."Z") or two
    letters ("AA", "AB", ..., 0-based rows 26, 27, ...) followed by a 1-based
    column number. Row and column parity select which of the four interleaved
    plates a well belongs to: ``parents[2*(row % 2) + (col % 2)]``. The value
    is stored at ``[row // 2, col // 2]`` of that plate's grid.

    Parameters
    ----------
    filename : str
    parents : sequence of 4 str
        Names of the four interleaved plates.

    Returns
    -------
    dict mapping each name in ``parents`` to a 16x24 float array of
    ``log2(value)``. Wells absent from the file stay 0.
    """
    # numpy directly: scipy's top-level numpy aliases (SP.zeros, SP.log2) are deprecated.
    import numpy as np

    result = {}
    for p in parents:
        result[p] = np.zeros([16, 24])
    n_col = 32*48
    n_head = 31

    with open(filename, 'r') as fh:
        lines = fh.readlines()[n_head:(n_head + n_col)]

    for l in lines:
        well, value = l.strip().split()
        if 'A' <= well[1] <= 'Z':  # two-letter row name, e.g. "AF12"
            r = 26 + ord(well[1]) - ord('A')
            c = int(well[2:]) - 1
        else:
            r = ord(well[0]) - ord('A')
            c = int(well[1:]) - 1
        p = 2*(r % 2) + (c % 2)
        # Floor division keeps indices integral under both Python 2 and 3.
        result[parents[p]][r//2, c//2] = np.log2(int(value))
    return result
        

def get_screen_target_orfs(eqtl_lod_cutoff=5, parents=()):
    """Collect eQTL target ORFs per screened gene and per combined parent strain.

    The target ORFs of HAP1, MKT1, IRA2 and IRA1 are fetched with
    ``get_glucose_smith_target_orfs(gene, eqtl_lod_cutoff)``. Each name in
    ``parents`` that contains "-" (e.g. "BY-MKT1-IRA2") also gets the union of
    the target sets of the genes after its first "-"-separated field. Names
    without "-" are skipped.

    Returns a dict mapping gene or parent name -> set of ORFs.
    """
    orfs = {}
    for p in ["HAP1", "MKT1", "IRA2", "IRA1"]:
        orfs[p] = set(get_glucose_smith_target_orfs(p, eqtl_lod_cutoff))
        LOG.debug("LOD cutoff: %.1f\tParent: %s - %d eQTLs"%(eqtl_lod_cutoff, p, len(orfs[p])))

    for parent in parents:
        if "-" not in parent:
            continue
        # Always a fresh set, so a single-gene parent does not alias that gene's set.
        orfs[parent] = set().union(*[orfs[g] for g in parent.split("-")[1:]])

    return orfs


def get_val_lists(vals, array_data, parents=None, correct=SP.log2(10), nulls=None, sds=None):
    """Flatten per-plate values into one aligned, rescaled array per parent.

    For plates 1-11, each parent's 16x24 values are aligned with ``align_vals``
    to the same plate of its null strain ``nulls[parent]``, using
    ``good_threshold = 0.25*correct``. They are then divided by ``correct`` and
    appended to that parent's list. ORF identifiers are collected from
    ``array_data[plate, 1:, 1:, 0]`` in the same row-major order.

    Parameters
    ----------
    vals : dict ``{plate: {parent: 16x24 array}}``
    array_data : array indexed ``[plate, row, col, field]``; row/col index 0 is skipped
    parents : list of str, optional
        Defaults to ``["BY", "BY-HAP1", "BY-MKT1-IRA2", "BY-HAP1-MKT1-IRA2"]``.
    correct : float
        Divisor for aligned values. Dividing by log2(10) turns log2-scale values
        (as ``read_tecan`` produces) into log10.
    nulls : dict parent -> reference parent, optional
        Defaults to "BY" for every parent.
    sds : dict ``{plate: {parent: 16x24 array}}``, optional
        If given, flattened the same way (not rescaled) and returned too.

    Returns
    -------
    ``(val_arrays, plate_orfs, parents)``, or, when ``sds`` is given,
    ``(val_arrays, plate_orfs, parents, sd_arrays)``. ``val_arrays`` and
    ``sd_arrays`` hold one array per parent, in ``parents`` order.
    """
    if parents is None:
        # Fresh list per call: the list is returned, so a shared default could be mutated.
        parents = ["BY", "BY-HAP1", "BY-MKT1-IRA2", "BY-HAP1-MKT1-IRA2"]
    if nulls is None:
        nulls = {p: "BY" for p in parents}
    # One list per parent, so any number of parents works (previously fixed at 4).
    val_lists = [[] for _ in parents]
    sd_lists = [[] for _ in parents]
    plate_orfs = []

    for plate in range(1, 12):
        for p, parent in enumerate(parents):
            v0 = vals[plate][nulls[parent]].reshape(384)
            val_lists[p].extend(align_vals(vals[plate][parent].reshape(384), v0, good_threshold=0.25*correct)/correct)
            if sds is not None:
                sd_lists[p].extend(sds[plate][parent].reshape(384))
        plate_orfs.extend(array_data[plate, 1:, 1:, 0].reshape(384))

    # Lists, not map objects, so the result is indexable under Python 3 as well.
    val_arrays = [SP.array(v) for v in val_lists]
    if sds is None:
        return val_arrays, SP.array(plate_orfs), parents
    return val_arrays, SP.array(plate_orfs), parents, [SP.array(s) for s in sd_lists]


def get_combined_val_lists(vals, sds, parents, threshold=0.25*SP.log2(10)):
    """Run ``get_val_lists`` for every screen with the appropriate null strain.

    Parents are aligned against "BY". The exceptions are every parent of the
    "RMs" screen and parents whose name contains "RM" in the "BYRM" screen;
    those are aligned against "RM".

    Parameters
    ----------
    vals, sds : dict ``{screen: {plate: {parent: 16x24 array}}}``
    parents : dict ``{screen: list of parent names}``
    threshold : float
        Currently unused; kept for call compatibility.

    Returns
    -------
    ``(val_lists, sd_lists, plate_orfs)``. The first two are keyed by screen.
    ``plate_orfs`` is derived only from the shared array data, so it is the
    same for every screen.
    """
    # Loop-invariant: fetch the array layout once rather than once per screen.
    array_data = get_array_data(skip_nan=False)
    val_lists, sd_lists = {}, {}
    plate_orfs = None
    for screen in parents:
        nulls = {}
        for p in parents[screen]:
            use_rm = (screen == "RMs") or (screen == "BYRM" and "RM" in p)
            nulls[p] = "RM" if use_rm else "BY"
        val_lists[screen], plate_orfs, _, sd_lists[screen] = get_val_lists(vals[screen], array_data, parents[screen], nulls=nulls, sds=sds[screen])

    return val_lists, sd_lists, plate_orfs
    

# Read cytometry gap
def get_cyto_gap(filename, trait=0):
    """Per-ORF gap between the fourth parental's cytometry value and an additive expectation.

    Parentals from ``read_cyto_tab`` are reordered as ``parentals[[0, 1, 3, 2]]``
    and called v0..v3 here. For every plate the expectation ``v1 + v2 - v0`` is
    formed. Observed values and expectations are then averaged over plates,
    ignoring NaN. The gap is ``mean(v3) - mean(expectation)``, which the
    original code labels "observed - expected".

    Parameters
    ----------
    filename : str
        Cytometry table passed to ``read_cyto_tab``.
    trait : int
        Passed through to ``calculate_parental_counts_across_plates``.

    Returns
    -------
    ``(gap_cyto, orfs)``, where ``gap_cyto[i]`` belongs to ``orfs[i]``. Assumes
    exactly four parentals (five value slots are allocated).
    """
    # Read data and per-parental, per-ORF values across plates.
    meta, cyto_slice, parentals, orfs, plates = read_cyto_tab(filename, normalise=True)
    parentals = SP.array(parentals)[[0,1,3,2]]
    _, parent_orf_vals, _, _ = calculate_parental_counts_across_plates(parentals, orfs, plates, meta, cyto_slice, min_cells=200, trait=trait, all_plates_good_only=False)

    # vals[orf, plate, k]: k = 0..3 parentals, k = 4 the per-plate additive expectation.
    vals = SP.zeros([len(orfs), len(plates), 5])*SP.nan
    for i in range(len(orfs)):
        for p, parent in enumerate(parentals):
            vals[i,:,p] = parent_orf_vals[parent][orfs[i]]
        vals[i,:,4] = vals[i,:,0:3].sum(axis=1) - 2*(vals[i,:,0])  # v1 + v2 - v0
    # Mean over plates, ignoring NaN entries.
    v = SP.nansum(vals, axis=1)
    v = v/((~SP.isnan(vals)).sum(axis=1))
    gap_cyto = v[:,3] - v[:,4]  # observed - expected
    return gap_cyto, orfs

 
