import pdb

import numpy as NP
import scipy as SP

from gfpvar.tools.common import *

# Count threshold used by read_384_cyto: a well is flagged "good" only when
# its count column is strictly greater than this value.
MIN_CYTO_CELLS = 3000

#YJL092W	GFP(+)02	A3	A	3	1	1	6	HPR5	5190	nucleus
class cGene:
    """One entry of the GFP collection, parsed from a tab-separated line.

    Columns read from the line (0-based):
      0  -> id          1 -> plate96     2 -> coord96
      5  -> plate384    6, 7 -> coord384 ("r<6>_c<7>")
      8  -> name        9 -> gfp_ref     10 -> compartment

    The remaining attributes (mRNA, mRNA_h2, gfp_by_*, gfp_rm_*, gfp_h2) are
    initialised to None and are meant to be filled in by callers.
    """

    def __init__(self, line):
        """Parse `line`; raises IndexError if it has fewer than 11 columns."""
        d = line.strip().split('\t')
        self.id = d[0]
        self.name = d[8]
        # Names shorter than two characters are treated as "no name".
        if len(self.name) < 2:
            self.name = None
        self.mRNA = None
        self.mRNA_h2 = None
        # Column 9 is either a number or a textual flag; entries starting with
        # "not"/"low"/"tec" (presumably "not visualized", "low signal",
        # "technical problem" -- TODO confirm) carry no reference level.
        if len(d[9]) < 2 or d[9][0:3] in ("not", "low", "tec"):
            self.gfp_ref = None
        else:
            self.gfp_ref = float(d[9])
        self.gfp_by_mean = None
        self.gfp_by_var = None
        self.gfp_rm_mean = None
        self.gfp_rm_var = None
        self.gfp_h2 = None
        self.compartment = d[10]
        self.plate96 = d[1]
        self.plate384 = d[5]
        self.coord96 = d[2]
        self.coord384 = "r%s_c%s" % (d[6], d[7])

    def __str__(self):
        # The GFP slot reports gfp_ref; there is no plain `gfp` attribute.
        return "%s (%s) - 384-%s %s / 96-%s %s, mRNA=%s, GFP=%s (%s)" % (
            self.id, self.name, self.plate384, self.coord384,
            self.plate96, self.coord96, self.mRNA, self.gfp_ref,
            self.compartment)

    def h2str(self):
        """Format the mRNA summary as one tab-separated line.

        Requires self.mRNA (six values: mean/variance pairs printed as
        By, Rm, Seg; variances are shown as their square root) and
        self.mRNA_h2 (two values) to be set; fails with TypeError otherwise.
        """
        m = self.mRNA
        return "%s\tBy:%.2f(%.2f)\tRm:%.2f(%.2f)\tSeg:%.2f(%.2f)\th2:%.2f %.2f" % (
            self.id, m[0], m[1]**0.5, m[2], m[3]**0.5, m[4], m[5]**0.5,
            self.mRNA_h2[0], self.mRNA_h2[1])





def read_expr(filename='puma_segregants.tab'):
    """Read an expression table and summarise it per gene and strain group.

    File layout as consumed here:
      * line 1: '||'-separated header; every element after the first is a
        strain label.  Labels containing "RM" and "glu" form the RM group,
        "BY" and "glu" the BY group, "YLK" and "2% glu" the segregants.
      * line 2: skipped.
      * remaining lines: tab-separated; column 1 is a '||'-separated
        annotation whose 6th element is the gene id, columns 3.. are the
        expression values (position i belongs to strain label i).

    Empty value fields are replaced by the mean of the non-empty values of
    the same row.  Blank lines are ignored.

    Returns:
        genes: dict gene id -> row index into the returned arrays
        (mean_rm, var_rm), (mean_by, var_by), (mean_seg, var_seg):
            per-gene mean and variance over the strains of each group.
    """
    expr = []
    genes = {}
    with open(filename, 'r') as ifh:
        strains = next(ifh).strip().split('||')[1:]
        Irm = [x for x, s in enumerate(strains) if "RM" in s and "glu" in s]
        Iby = [x for x, s in enumerate(strains) if "BY" in s and "glu" in s]
        Iseg = [x for x, s in enumerate(strains) if "YLK" in s and "2% glu" in s]
        I = Irm + Iby + Iseg
        next(ifh)  # second header line carries no data
        for l in ifh:
            if not l.strip():
                continue
            # Only strip the line terminator: stripping all whitespace would
            # drop trailing empty (missing) columns and shift nothing else.
            fields = l.rstrip('\r\n').split('\t')
            raw = fields[3:]
            missing = NP.array([x == '' for x in raw], bool)
            # Impute on a float array; doing it on the string array would
            # truncate the mean to the width of the longest field.
            values = NP.array([0.0 if x == '' else float(x) for x in raw], float)
            values[missing] = values[~missing].mean()
            genes[fields[1].split('||')[5].strip()] = len(expr)
            expr.append(values[I])

    n_rm, n_by = len(Irm), len(Iby)
    # Explicit shape keeps the column slicing valid when no data rows exist.
    expr = NP.array(expr, float).reshape(len(expr), len(I))
    rm = expr[:, 0:n_rm]
    by = expr[:, n_rm:(n_rm + n_by)]
    seg = expr[:, (n_rm + n_by):]
    return (genes,
            (rm.mean(axis=1), rm.var(axis=1)),
            (by.mean(axis=1), by.var(axis=1)),
            (seg.mean(axis=1), seg.var(axis=1)))
    


def read_plate(mrnadiff=1):
    """Summarise FILE_COLLECTION_PLATES per plate.

    Column 0 of each row is parsed as a 1-based plate number; a new summary
    entry is started every time that number changes from the previous row.
    Rows with fewer than 13 columns only contribute to that bookkeeping.

    For the remaining rows:
      * column 12 is counted as a compartment label when longer than three
        characters, otherwise the row is counted as untested;
      * column 9 is looked up in the expression data from read_expr() and
        the gene is counted when (RM mean - BY mean) is below -mrnadiff or
        above +mrnadiff.

    Returns:
        compartments: list of dicts (label -> count), one per plate
        untested:     list of ints, one per plate
        mrnas:        list of [n_below, n_above] pairs, one per plate
    """
    # Read everything up front so the handle is closed before further work.
    with open(FILE_COLLECTION_PLATES, "r") as ifh:
        rows = [x.strip().split('\t') for x in ifh]
    genes, (expr_rm, _), (expr_by, _), _ = read_expr()
    exprdiff = expr_rm - expr_by

    compartments = []
    mrnas = []
    untested = []
    last_plate = -1

    for row in rows:
        plate = int(row[0]) - 1
        if last_plate != plate:
            compartments.append({})
            mrnas.append([0, 0])
            untested.append(0)
        last_plate = plate
        if len(row) < 13:
            continue
        label = row[12]
        if len(label) > 3:
            compartments[-1][label] = compartments[-1].get(label, 0) + 1
        else:
            untested[-1] += 1
        if row[9] in genes:
            diff = exprdiff[genes[row[9]]]
            mrnas[-1][0] += (diff < -mrnadiff)
            mrnas[-1][1] += (diff > mrnadiff)
    return compartments, untested, mrnas


#YJL092W	GFP(+)02	A3	A	3	1	1	6	HPR5	5190	nucleus
def read_gfp(filename=None):
    """Parse the collection file into a list of cGene objects.

    Lines starting with "ORF" or "BLANK" are skipped; every other line is
    handed to cGene.

    Args:
        filename: file to read; defaults to FILE_COLLECTION_PLATES.
    """
    if filename is None:
        filename = FILE_COLLECTION_PLATES
    # The context manager closes the handle even if a line fails to parse.
    with open(filename, 'r') as ifh:
        return [cGene(l) for l in ifh if not l.startswith(("ORF", "BLANK"))]


def read_384_cyto(filename="gfp_cyto_initial.tab", min_cells=None, blank_cells=1000):
    """Read per-well cytometry values of one 384-well plate.

    The first line of the file is a header and is skipped.  Data line i
    (0-based) is mapped onto two 96-well plates:
      plate_in_384 = (i // 24) % 2   (rows alternate plate 1, plate 2)
      row_96       = i // 48
      col_96       = (i % 24) // 2
      is_rm        = i % 2           (samples alternate by / rm in columns)

    The first tab-separated field of each line is dropped; the rest are the
    values.  For file names containing "AB" or "3." the fifth value is copied
    over the third and only the first four values are kept (the original
    comment calls that value the count).  The second-to-last value is the
    count used for the quality flags.

    Args:
        filename:    file to read.
        min_cells:   a well is good only if its count is strictly greater;
                     defaults to MIN_CYTO_CELLS.
        blank_cells: a well is blank if its count is strictly smaller.

    Returns:
        data:  float array [2, 2, 8, 12, 6] indexed
               [plate_in_384, is_rm, row_96, col_96, value]
        good:  bool array [2, 2, 8, 12]; a failing well clears the flag of
               both samples (by and rm) at that position
        blank: bool array [2, 2, 8, 12]
    """
    if min_cells is None:
        min_cells = MIN_CYTO_CELLS
    data = NP.zeros([2, 2, 8, 12, 6])
    good = NP.zeros([2, 2, 8, 12], bool)
    blank = NP.zeros([2, 2, 8, 12], bool)
    # Depends on the file name only, so decide once instead of per line.
    short_format = filename.count("AB") > 0 or filename.count("3.") > 0

    with open(filename, 'r') as ifh:
        next(ifh)  # header line
        for i, l in enumerate(ifh):
            # Floor division keeps these usable as indices on Python 3 too.
            plate_in_384 = (i // 24) % 2
            row_96 = i // 48
            col_96 = (i % 24) // 2
            is_rm = i % 2

            d = l.strip().split("\t")[1:]
            if short_format:
                d[2] = d[4]  # count
                d = d[0:4]

            n_cells = int(d[-2])
            if n_cells < blank_cells:
                blank[plate_in_384, is_rm, row_96, col_96] = True
            if n_cells > min_cells:
                good[plate_in_384, is_rm, row_96, col_96] = True
            else:
                good[plate_in_384, :, row_96, col_96] = False

            d = NP.array(d, float)
            data[plate_in_384, is_rm, row_96, col_96, 0:len(d)] = d

    return data, good, blank


# The plate index of the result is 1-based (taken verbatim from the plate name).
def get_plate_genes(filename=None):
    """Map every 96-well position of the collection to its ORF id.

    Each line contributes its first five tab-separated columns
    (id, plate name, unused, row letter, column number).  Only lines whose
    plate name contains "GFP" are used; the plate number is the integer that
    follows the first six characters of the name (e.g. "GFP(+)02" -> 2).

    Args:
        filename: file to read; defaults to FILE_COLLECTION_PLATES.

    Returns:
        object array [50, 8, 12] indexed [plate, row, column] (row 'A' -> 0,
        column 1 -> 0); positions never mentioned in the file stay 0.
    """
    if filename is None:
        filename = FILE_COLLECTION_PLATES
    result = NP.zeros([50, 8, 12], object)

    with open(filename, 'r') as ifh:
        for l in ifh:
            sgd, plate, tmp, row, col = l.strip().split("\t")[0:5]
            if plate.count("GFP") > 0:
                plate96 = int(plate[6:])
                row = ord(row[0]) - ord('A')
                col = int(col) - 1
                result[plate96, row, col] = sgd
    return result


# The plate index of the result is 1-based (taken verbatim from the plate name).
def get_blanks(filename=None):
    """Flag the 96-well positions whose id column reads "BLANK".

    Uses the same columns and plate-name convention as the collection file
    lines: id, plate name (must contain "GFP"; the number follows the first
    six characters), unused, row letter, column number.

    Args:
        filename: file to read; defaults to FILE_COLLECTION_PLATES.

    Returns:
        object array [50, 8, 12] indexed [plate, row, column]; listed
        positions hold True/False, positions never mentioned stay 0.
    """
    if filename is None:
        filename = FILE_COLLECTION_PLATES
    result = NP.zeros([50, 8, 12], object)

    with open(filename, 'r') as ifh:
        for l in ifh:
            sgd, plate, tmp, row, col = l.strip().split("\t")[0:5]
            if plate.count("GFP") > 0:
                plate96 = int(plate[6:])
                row = ord(row[0]) - ord('A')
                col = int(col) - 1
                result[plate96, row, col] = (sgd == "BLANK")
    return result



def temp_get_mrnas(genes, data):
    """Look up mRNA level and heritability for a 2-D grid of gene ids.

    Args:
        genes: 2-D array of gene ids (anything usable as a key of `data`).
        data:  mapping gene id -> object with `mRNA` and `mRNA_h2` attributes.

    Returns:
        (mrna, h2): float arrays shaped like `genes`.  For ids present in
        `data` with a non-None `mRNA`, they hold `mRNA[4]` and `mRNA_h2[0]`;
        every other position is -1.
    """
    # -1 marks "no value"; NumPy is used directly because the scipy-level
    # aliases of NumPy functions are deprecated.
    mrna = NP.zeros(genes.shape, float) - 1
    h2 = NP.zeros(genes.shape, float) - 1
    for j in range(genes.shape[0]):
        for k in range(genes.shape[1]):
            gene = genes[j, k]
            if gene in data and data[gene].mRNA is not None:
                mrna[j, k] = data[gene].mRNA[4]
                h2[j, k] = data[gene].mRNA_h2[0]
    return mrna, h2


def temp_test_plate(filename="gfp_cyto_P1AB.tab", plot_h2=False, plates=None):
    """Compare cytometry readings of one 384-well file with array GFP and mRNA.

    For every entry `plate` of `plates`, the two 96-well plates `plate` and
    `plate + 1` are matched with the two plates stored in the cytometry file
    (index i = 0, 1).  Wells are used when the first sample's `good` flag is
    set, the gene is known, and it has a positive reference GFP level.

    Collected per half-plate:
      rna  - segregant mRNA mean, prot - log10 reference GFP,
      cyt  - log10 of cytometry value 0 of the first sample,
      mh2 / ph2 - mRNA and cytometry-based heritability estimates, only for
                  genes with log10 reference GFP above 3.5.

    Args:
        filename: cytometry file passed to read_384_cyto.
        plot_h2:  draw four scatter plots per entry of `plates` and then drop
                  into the debugger.
        plates:   first 96-well plate number of each pair; defaults to [1, 3].

    NOTE(review): the lists are reset for every half-plate and the plots are
    drawn after both halves were processed, so only the second half-plate
    (plate + 1) is plotted -- confirm whether that is intended.
    """
    data = get_data()
    cyto, good, _cyto_blank = read_384_cyto(filename)
    genes = get_plate_genes()
    if plates is None:
        plates = [1, 3]

    for plate in plates:
        for i in range(2):
            plate_genes = genes[plate + i]
            # Look the mRNA values up for the same 96-well plate the genes
            # come from; using `plate` for both halves pairs the second
            # half-plate's GFP values with the wrong genes' mRNA.
            mrnas, h2s = temp_get_mrnas(plate_genes, data)
            rna, prot, cyt, ph2, mh2 = [], [], [], [], []

            for j in range(8):
                for k in range(12):
                    gene = plate_genes[j, k]
                    if not good[i, 0, j, k] or gene not in data:
                        continue
                    gfp_ref = data[gene].gfp_ref
                    # gfp_ref is None for genes without a reference level;
                    # comparing None with 0 raises on Python 3.
                    if gfp_ref is None or not gfp_ref > 0:
                        continue
                    rna.append(mrnas[j, k])
                    prot.append(NP.log10(gfp_ref))
                    cyt.append(NP.log10(cyto[i, 0, j, k, 0]))
                    if prot[-1] > 3.5:
                        mh2.append(h2s[j, k])
                        s1 = cyto[i, 0, j, k, 1]
                        # NOTE(review): value 5 of the first sample is used as
                        # the parental mean; read_384_cyto leaves it at 0 for
                        # "AB"/"3." files -- confirm.
                        m_highseg, m_parent = cyto[i, 1, j, k, 0], cyto[i, 0, j, k, 5]
                        ph2.append((0.5 * (m_highseg - m_parent)) ** 2 / (s1 ** 2))

        if plot_h2:
            import pylab as PL
            PL.figure()
            PL.subplot(221)
            PL.plot(ph2, mh2, ".", markersize=6)
            PL.xlabel("gfp h2")
            PL.ylabel("mRNA h2")
            PL.title("Plate %d - cor %.2f" % (plate, NP.corrcoef(ph2, mh2)[0, 1]))
            PL.subplot(222)
            PL.plot(cyt, prot, ".", markersize=6)
            PL.xlabel("cyto")
            PL.ylabel("gfp array")
            PL.title("Plate %d - cor %.2f" % (plate, NP.corrcoef(prot, cyt)[0, 1]))
            PL.subplot(223)
            PL.plot(cyt, rna, ".", markersize=6)
            PL.xlabel("cyto")
            PL.ylabel("mRNA")
            PL.title("Plate %d - cor %.2f" % (plate, NP.corrcoef(rna, cyt)[0, 1]))
            PL.subplot(224)
            PL.plot(prot, rna, ".", markersize=6)
            PL.xlabel("gfp array")
            PL.ylabel("mRNA")
            PL.title("Plate %d - cor %.2f" % (plate, NP.corrcoef(rna, prot)[0, 1]))
            # Interactive stop so the figure and lists can be inspected.
            pdb.set_trace()



def get_data(debug=False, plot=False):
    """Load the GFP collection and attach mRNA statistics to each gene.

    Every gene from read_gfp() whose id occurs in the expression data from
    read_expr() gets
      g.mRNA    = (BY mean, BY var, RM mean, RM var, segregant mean,
                   segregant var)
      g.mRNA_h2 = h2_estimate(*g.mRNA)
    Genes without expression data keep mRNA = mRNA_h2 = None.

    Args:
        debug: print one h2str() line per annotated gene.
        plot:  show a 40-bin histogram of the first heritability estimate
               (negative values clipped to 0).

    Returns:
        dict gene id -> cGene; for duplicated ids the last one wins.
    """
    gfp = read_gfp()
    rna_index, (rm_expr, rm_var), (by_expr, by_var), (segregant_expr, segregant_var) = read_expr()
    h2s = []
    for g in gfp:
        if g.id not in rna_index:
            continue
        i = rna_index[g.id]
        g.mRNA = (by_expr[i], by_var[i], rm_expr[i], rm_var[i], segregant_expr[i], segregant_var[i])
        g.mRNA_h2 = h2_estimate(*g.mRNA)
        if debug:
            # Single-argument call form prints identically on Python 2 and 3.
            print(g.h2str())
        h2s.append(max(0, g.mRNA_h2[0]))
    if plot:
        import pylab as PL
        PL.hist(h2s, bins=40)
        PL.show()
    return {g.id: g for g in gfp}


def h2_estimate(m_p1, v_p1, m_p2, v_p2, m_seg, v_seg):
    """Estimate heritability of a trait from parental and segregant statistics.

    Args:
        m_p1, v_p1: mean and variance of the first parent.
        m_p2, v_p2: mean and variance of the second parent; pass m_p2=None
            when the second parent is unknown.
        m_seg, v_seg: mean and variance of the segregants.

    Returns:
        If m_p2 is None, a single value (v_seg - v_p1) / v_seg: parental
        variance is taken as private variance, so this is the proportion of
        non-private variance of the total variance.
        Otherwise a pair (that same estimate, s2g / v_seg), where s2g is the
        squared half-difference of the two parental means.
    """
    if m_p2 is None:
        return (v_seg - v_p1) / v_seg
    # Half the parental difference, squared; both means are halved.
    s2g = (m_p1 / 2. - m_p2 / 2.) ** 2
    return (v_seg - v_p1) / v_seg, s2g / v_seg


def print_plate(given_plate=None):
    """Print one tab-separated line per gene of the collection with RM/BY mRNA.

    Rows of FILE_COLLECTION_PLATES are printed when they have at least 13
    columns and their column 9 is a gene present in the expression data.
    Printed fields: columns 4, 5, 6, 9, 10, 11, RM mean, BY mean, column 12.

    Args:
        given_plate: 0-based plate number (column 0 minus one) to restrict
            the output to; None prints all plates.
    """
    with open(FILE_COLLECTION_PLATES, "r") as ifh:
        rows = [x.strip().split('\t') for x in ifh]
    # read_expr() returns four items: the index plus (mean, var) pairs for
    # RM, BY and the segregants; only the RM and BY means are needed here.
    genes, (expr_rm, _), (expr_by, _), _ = read_expr()

    for row in rows:
        plate = int(row[0]) - 1
        if given_plate is not None and plate != given_plate:
            continue  # not the plate we want
        if len(row) < 13:
            continue  # not enough data
        if row[9] in genes:
            idx = genes[row[9]]
            print("%s\t%s\t%s\t%s\t%s\t%s\t%.1f\t%.1f\t%s" % (row[4], row[5], row[6], row[9], row[10], row[11], expr_rm[idx], expr_by[idx], row[12]))



def main():
    """Run the plate comparison for every matching cytometry file.

    Each file matching "gfp_cyto_P1A?.tab" in the current directory is
    passed to temp_test_plate with plots enabled; the plate number is 1 per
    "AB" and 3 per "CD" occurrence in the file name.  get_data() is then
    called once more and the function returns.
    """
    import glob
    for f in glob.glob("gfp_cyto_P1A?.tab"):
        print(f)
        temp_test_plate(f, plot_h2=True, plates=[f.count("AB") + 3 * f.count("CD")])
    get_data()
    return
    # NOTE(review): everything below is unreachable because of the return
    # above; it is kept as the parked per-plate summary report.
    c, d, e = read_plate(1)
    for i in range(len(d)):
        print("Plate %s" % (i + 1,))
        for k in c[i]:
            print("\t%s:%d" % (k, c[i][k]))
        for k in ['nucleus', 'cytoplasm', 'mitochondrion']:
            print("\t%s:%d" % (k, c[i][k]))
        print("\tuntested: %s" % (d[i],))
        print("\tupreg: %d\n\tdownreg: %d" % (e[i][0], e[i][1]))
    print("%s %s" % (NP.mean([x[0] for x in e]), NP.mean([x[1] for x in e])))
    print_plate(9)

# Run main() only when the file is executed as a script, not when imported.
if __name__ == '__main__': main()
