import scipy as SP
import pylab as PL
import glob
import os
import sys
from parts2014_gfpvar.tools.common import *
from parts2014_gfpvar.tools.io import *
from util import *
from parts2014_gfpvar.analysis.qtl.util import *
from parts2014_gfpvar.analysis.cyto.util import *
from parts2014_gfpvar.analysis.qtl.comparison import *

def singleplot():
    """ Overlay raw and smoothed allele frequencies of samples 0 and 2 in one wide figure.
    Raw estimates are only drawn where they lie within 0.45 of 0.5; the smoothed curve is drawn everywhere.
    NOTE(review): init, post, locs, cols and PLOT_PARAMS are read as globals and none of them is defined in
    this file - presumably they come from a star import or an interactive session; confirm before use. """
    def informative(j):
        # indices of loci whose raw estimate is not (almost) fixed at 0 or 1
        return SP.where(abs(init[j] - 0.5) < 0.45)[0]

    def overlay(j, keep):
        PL.plot(locs[j][keep], init[j][keep], ".", markersize=7, alpha=0.2, color=cols[j])
        PL.plot(locs[j], post[j], markersize=7, alpha=1, color=cols[j], lw=4)
        # horizontal guides at frequency 0 and 1 across the covered range
        for level in (0, 1):
            PL.plot([min(locs[j]), max(locs[j])], [level, level], 'k-')

    PL.rcParams.update(PLOT_PARAMS)
    first = informative(0)
    PL.figure(figsize=(15,4))
    overlay(0, first)
    overlay(2, informative(2))
    PL.ylim(-0.1,1.1)


def plot_chrm_af(locs, init, post, color, rm=True):
    """ Plot allele frequency for one chromosome on the current axes; figure and subplot setup is left to the caller.
    Dots show the max likelihood estimates, the thick line shows the smoothed version, and black horizontal
    guides are drawn at 0 and 1.
    @param locs length-L array of chromosomal coordinates
    @param init length-L array of max likelihood estimates of locus allele frequencies (a/(a+b))
    @param post length-L array of smoothed allele frequencies, drawn as a line
    @param color matplotlib color used for both the dots and the line
    @param rm if true, plot 1 - frequency instead of the frequency itself (needs array inputs) """
    if rm:
        init = 1. - init
        post = 1. - post
    PL.plot(locs, init, ".", markersize=1, alpha=1, color=color)
    PL.plot(locs, post, alpha=1, color=color, lw=4)
    PL.ylim(-0.1,1.1)
    left, right = min(locs), max(locs)
    for level in (0, 1):
        PL.plot([left, right], [level, level], 'k-')


def plot_posterior_ci(locs, mean, sd, color, alpha_multiplier=0.1, rm=True):
    """ Shade two nested uncertainty bands around a curve on the current axes.
    The band outline runs left to right along mean + sd and back right to left along mean - sd; a second
    band of twice that width is filled on top with twice the alpha.
    @param locs length-L sequence of x coordinates
    @param mean length-L sequence of curve values
    @param sd length-L array of band half-widths (must support unary minus, i.e. an array)
    @param color matplotlib color of the fill
    @param alpha_multiplier alpha of the narrow band; the wide band uses 2*alpha_multiplier
    @param rm if true, the band is centred on 1 - mean instead of mean """
    forward_x, forward_y = list(locs), list(mean)
    outline_x = SP.array(forward_x + forward_x[::-1])
    outline_y = SP.array(forward_y + forward_y[::-1])
    if rm:
        outline_y = 1. - outline_y
    # +sd on the way out, -sd on the way back, so the polygon closes around the curve
    offsets = SP.array(list(sd) + list(-sd)[::-1])
    for width in (1, 2):
        PL.fill(outline_x, outline_y + width*offsets, color, alpha=width*alpha_multiplier)


def plot_posterior_diff(locs, means, sds, color, plot_ci=True, alpha_multiplier=0.1, diffline=0.23, rm=True):
    """ Plot the difference between two allele frequency curves at the loci they share, on the current axes.
    @param locs pair of coordinate arrays, one per sample
    @param means pair of allele frequency arrays aligned with locs
    @param sds pair of standard deviation arrays aligned with locs
    @param color single-character matplotlib color (it is concatenated with '-' to form the line format)
    @param plot_ci if true, shade an uncertainty band around the difference
    @param alpha_multiplier alpha passed on to plot_posterior_ci
    @param diffline if not None, draw dashed guides at +diffline and -diffline
    @param rm if true, the sign of the difference is flipped
    Nothing is drawn when the two samples share no locus. """
    shared = set(locs[0]) & set(locs[1])
    keep = [[i for i, x in enumerate(locs[j]) if x in shared] for j in range(2)]
    if len(keep[0]) == 0:
        return  # no common loci; min()/max() below would raise on an empty array
    d = means[0][keep[0]] - means[1][keep[1]]
    l = locs[0][keep[0]]
    # two standard deviations of the difference (variances of the two samples add)
    sd = 2*(sds[0][keep[0]]**2 + sds[1][keep[1]]**2)**0.5
    if rm: d = -d
    span = [min(l), max(l)]
    PL.plot(l, d, color + '-', lw=4)
    PL.plot(span, [0,0], 'k-')
    if diffline is not None:
        PL.plot(span, [diffline,diffline], 'k--')
        PL.plot(span, [-diffline,-diffline], 'k--')
    if plot_ci:
        # d already carries the rm sign flip; rm=False keeps plot_posterior_ci from
        # re-centring the band on 1 - d, away from the line drawn above
        plot_posterior_ci(l, d, sd, color, alpha_multiplier, rm=False)


def plot_chrm_coverage(locs, coverage, color):
    """ Plot per-locus coverage for one chromosome on the current axes; figure and subplot setup is left to the caller.
    Translucent dots show the coverage values, a black horizontal line marks their median.
    @param locs array of chromosomal coordinates
    @param coverage array of coverage values aligned with locs
    @param color matplotlib color of the dots """
    PL.plot(locs, coverage, ".", markersize=7, alpha=0.2, color=color)
    span = [min(locs), max(locs)]
    typical = SP.median(coverage)
    PL.plot(span, [typical, typical], 'k-')



def plot_qtl(set_name, sample_high, sample_low, chrm, peakloc, diff, sd, startloc, endloc, add=False, show=True, colors="rb", rec_cutoff=None, diff_ylim=None, diffline=None, title=True, filter=False, separate_plots=True, rm=True):
    """ Plot the allele frequencies of a high/low sample pair around one QTL, plus their difference.
    Per-chromosome sample data is indexed positionally: [0] raw estimates, [1] smoothed means, [2] a two-column
    array whose row product and row sum give the variance, [3] boolean "bad locus" flags, [4] coordinates.
    @param set_name label used in the title
    @param sample_high, sample_low sample names passed to get_sample_seq_data
    @param chrm chromosome key into the sample data and into get_seq_chr_names()
    @param peakloc coordinate marked with a dashed vertical line; None is only handled when
           separate_plots is false, where it is replaced by the locus of largest absolute difference
    @param diff, sd numbers that are only used in the title text
    @param startloc, endloc plotted coordinate window (inclusive)
    @param add if true, draw into the current figure instead of opening a new one (also suppresses the title)
    @param show if true, call PL.show() at the end
    @param colors two colors, for the high and the low sample
    @param rec_cutoff passed through to get_sample_seq_data
    @param diff_ylim y limits applied last; defaults to (-0.5, 0.5) for the separate difference panel and
           to (-0.1, 1.1) when everything is drawn on one axes
    @param diffline passed to plot_posterior_diff (separate_plots only)
    @param filter if true, loci flagged as bad are left out of the per-sample curves
    @param separate_plots if true, frequencies and difference get their own panels; otherwise the
           difference is overlaid as a black line
    @param rm passed to the plotting helpers; also flips the sign of the overlaid difference """
    data = [get_sample_seq_data(sample, rec_cutoff)[chrm] for sample in (sample_high, sample_low)]
    in_window = [SP.where((d[4] >= startloc) & (d[4] <= endloc))[0] for d in data]
    locs, init, mean, post, bad = [[d[field][w] for d, w in zip(data, in_window)] for field in (4, 0, 1, 2, 3)]
    variances = [p.prod(axis=1)/((p.sum(axis=1)**2+0.1)*(p.sum(axis=1) + 1.1)) for p in post]
    chrnames = get_seq_chr_names()

    if not add: PL.figure(figsize=(11, 4 + 2.5*separate_plots))
    if separate_plots:
        PL.subplot(211)
        PL.plot([peakloc, peakloc], [0,1], "k--") # highlight the peak
    for i in range(2):
        if len(in_window[i]) < 2: continue
        keep = SP.where(~bad[i])[0] if filter else SP.arange(len(locs[i]))
        plot_chrm_af(locs[i][keep], init[i][keep], mean[i][keep], colors[i], rm=rm)
        plot_posterior_ci(locs[i][keep], mean[i][keep], 2.*(variances[i][keep]**0.5), colors[i], rm=rm)
    PL.xlim(startloc, endloc)
    if title and (not add):  PL.title("%s QTL at chrm %s %d-%d (peak at %d, change=%.2f; %d sd)"%(set_name, chrnames[chrm], startloc, endloc, peakloc, diff, sd))
    PL.ylabel("RM allele frequency")
    if separate_plots:
        PL.subplot(212)
        plot_posterior_diff(locs, mean, [v**0.5 for v in variances], colors[0], diffline=diffline, rm=rm)
        PL.plot([peakloc, peakloc], [-2,2], "k--") # highlight the peak
        PL.ylabel("Difference in BY allele frequency")
    else:
        # overlay the high-minus-low difference at the loci present in both samples
        shared = set(locs[0]) & set(locs[1])
        common = [[i for i, x in enumerate(locs[j]) if x in shared] for j in range(2)]
        d = mean[0][common[0]] - mean[1][common[1]]
        if rm: d = -d
        PL.plot(locs[0][common[0]], d, "k-", lw=2, alpha=0.7)
        if peakloc is None:
            peakloc = locs[0][common[0]][SP.argmax(abs(d))]
        PL.plot([peakloc, peakloc], [0,1], "k--") # highlight the peak
    if diff_ylim is None:
        # previously only the separate-panel branch set a default, so the combined
        # layout crashed on PL.ylim(*None); (-0.1, 1.1) matches plot_chrm_af
        diff_ylim = (-0.5, 0.5) if separate_plots else (-0.1, 1.1)
    PL.xlim(startloc, endloc)
    PL.ylim(*diff_ylim)
    PL.xlabel("Chr %s"%(chrnames[chrm]))
    if show:
        PL.show()


def plot_genes(chrm, start, end, y, h):
    """ Draw a translucent grey box on the current axes for every gene lying strictly inside (start, end).
    Gene records come from read_gene_locs(), whose values are unpacked as (chromosome, start, end).
    @param chrm chromosome identifier, compared for equality with the first field of each record
    @param start, end coordinate window; a gene is kept when its start > start and its end < end
    @param y vertical centre of the boxes
    @param h half-height of the boxes """
    top, bottom = y + h, y - h
    inside = [[s, e] for (c, s, e) in read_gene_locs().values()
              if (chrm == c) and (s > start) and (e < end)]
    for (left, right) in inside:
        PL.fill([left, right, right, left], [top, top, bottom, bottom], lw=0.5, fill=True, color="k", alpha=0.3)


def plot_chr11_qtl():
    """ Show the 'ref|NC_001143|' region within 40 kb of position 280000 for three sample pairs, one after another.
    Each pair is drawn by plot_qtl into the current figure (add=True) as a single combined panel with bad
    loci filtered out, and PL.show() is called after every pair. """
    peakloc, window = 280000, 40000
    options = dict(add=True, show=False, colors='rb', rec_cutoff=0.9, diff_ylim=(-0.1,1.1),
                   diffline=None, filter=True, separate_plots=False, rm=True)
    sample_pairs = (
        ("A19_sR2", "Sample4_73", "Sample4_61"),
        ("B7_sR2", "Sample4_86", "Sample4_74"),
        ("B6_sR2", "Sample4_67", "Sample4_55"),
    )
    for sam, high, low in sample_pairs:
        plot_qtl(sam, high, low, 'ref|NC_001143|', peakloc, 0.2, 2, peakloc-window, peakloc+window, **options)
        PL.show()


def plot_paper_qtl1():
    """ Draw the QTL on 'ref|NC_001136|' around position 215000 (+-90 kb) for set X1_sR2, annotate RGT2, save and show.
    Genes of chromosome 4 in the window are drawn as grey boxes below the curves; the 213000-216000 box is
    drawn solid black and labelled "RGT2". The figure is written to RGT2.svg in the working directory. """
    peakloc, window = 215000, 90000
    lo, hi = peakloc - window, peakloc + window
    # The X1_sR1 replicate (Sample4_32 / Sample4_44) used to be plotted with the same settings.
    plot_qtl("X1_sR2", "Sample4_93", "Sample4_81", "ref|NC_001136|", peakloc, 0.2, 2, lo, hi,
             add=True, show=False, colors='rb', rec_cutoff=0.9, diff_ylim=(-0.1,1.1), diffline=None,
             title=False, filter=True, separate_plots=False, rm=True)

    # gene track just below the frequency curves
    y, h = -0.05, 0.03
    plot_genes(4, lo, hi, y=y, h=h)

    # highlighted gene, a short leader line, and its label
    x1, x2 = 213000, 216000
    top, bottom = y + h, y - h
    PL.fill([x1,x2,x2,x1], [top, top, bottom, bottom], lw=0.5, fill=True, color="k")
    PL.plot([x2, x2+3000], [top, 0.05], 'k-')
    PL.text(x2+4000, top+0.07, "RGT2")
    PL.savefig("RGT2.svg")
    PL.show()


def plot_paper_qtl2():
    """ Draw the QTL on 'ref|NC_001139|' around position 474000 (+-100 kb) for set X3_sR2, annotate PDR1, save and show.
    plot_qtl opens a new figure (add=False). Genes in the window are drawn as grey boxes below the curves;
    the 469000-473000 box is drawn solid black and labelled "PDR1". The figure is written to PDR1.svg in the
    working directory. """
    peakloc = 474000
    window = 100000
    colors = 'rb'
    # The X3_sR1 replicate (Sample4_58 / Sample4_70) used to be plotted with the same settings.
    plot_qtl("X3_sR2", "Sample4_11", "Sample4_23", "ref|NC_001139|", peakloc, 0.2, 2, peakloc-window, peakloc+window, add=False, show=False, colors=colors, rec_cutoff=0.9, diff_ylim=(-0.1,1.1), diffline=None, title=False, filter=True, separate_plots=False)
    y,h = -0.05, 0.03
    # The gene track must use the chromosome of the QTL. plot_paper_qtl1 pairs accession
    # NC_001136 with gene chromosome 4, and NC_001139 is three accessions further, i.e.
    # chromosome 7; the previous value 4 (copied from plot_paper_qtl1) drew the wrong genes.
    plot_genes(7, peakloc-window, peakloc+window, y=y, h=h)
    x1,x2 = 469000, 473000
    PL.fill([x1,x2,x2,x1], [y+h, y+h, y-h, y-h], lw=0.5, fill=True, color="k")
    PL.plot([x2, x2+3000],[y+h, 0.05], 'k-')
    PL.text(x2+4000, y+h+0.07, "PDR1", fontsize=16)
    PL.savefig("PDR1.svg")
    PL.show()


def plot_density(signal_cutoff=0.2, smoother=25000, highlight_peaks=False, debug=False):
    """ Plot genome-wide pQTL counts (upwards) against eQTL counts (mirrored downwards) on the current axes.
    Counts come from get_peak_comparison and are plotted against their array index; chromosome boundaries
    (cumulative lengths divided by smoother) are marked with dashed vertical lines and the x ticks carry
    roman chromosome names.
    @param signal_cutoff, smoother, debug passed through to get_peak_comparison
    @param highlight_peaks if true, mark with red dots the bins whose eQTL count exceeds log2(30) """
    ROMAN_NAMES = "I II III IV V VI VII VIII IX X XI XII XIII XIV XV XVI".split()
    lenvals = sorted(get_cumulative_chrm_lengths().values())
    # inner boundaries become separators; midpoints of neighbouring boundaries carry the tick labels
    chrseps = SP.array(SP.array(lenvals[1:-1])/smoother, int)
    chrmids = [0.5*(lo + hi)/smoother for lo, hi in zip(lenvals[:-1], lenvals[1:])]
    eqtl_cutoff = SP.log2(30)
    pqtl_count, eqtl_count, pqtl_peak, eqtl_peak = get_peak_comparison(debug=debug, signal_cutoff=signal_cutoff, smoother=smoother, peak_cutoff_pqtl=4, peak_cutoff_eqtl=eqtl_cutoff)
    linewidth = 2

    for sep in chrseps:
        PL.plot([sep,sep], [-11,pqtl_count.max() + 1], 'k--') # separator between chromosomes
    if highlight_peaks:
        hot = SP.where(eqtl_count > eqtl_cutoff)[0]
        PL.plot(SP.arange(len(pqtl_count))[hot], pqtl_count[hot], "ro")

    PL.plot(pqtl_count, "-", lw=linewidth, alpha=1)
    PL.plot(-eqtl_count, "-", lw=linewidth, alpha=1)
    PL.xticks(chrmids, ROMAN_NAMES)
    PL.xlabel("Genomic coordinates")
    PL.xlim(0,len(pqtl_count))
    PL.ylabel("log2(eQTL count)    pQTL count")
    PL.ylim(-10, pqtl_count.max() + 1)

    

def plot_four(chrm, sample_data, samples, plot_coverage=False):
    """ Plot one chromosome for a list of samples, four stacked panels per figure.
    Samples without data for chrm are skipped; when the number of remaining samples is not four, the set
    of skipped sample names is printed. Per panel, loci flagged as bad or with coverage of at least 2.5
    times the sample median are left out.
    @param chrm chromosome key into each sample's data and into get_seq_chr_names()
    @param sample_data list of per-sample dicts; each chromosome entry is indexed positionally as
           [0] raw estimates, [1] smoothed estimates, [3] boolean bad-locus flags, [4] coordinates, [5] coverage
    @param samples list of sample names aligned with sample_data
    @param plot_coverage if true, plot coverage instead of allele frequency """
    cols = 'cgmcbgrcbgrcbgrc'*10
    # Select names and data with one shared index list. Filtering the names first and then
    # iterating the data over the shortened name list dropped or misaligned samples.
    present = [i for i in range(len(samples)) if chrm in sample_data[i]]
    if len(present) != 4: print(set(samples) - set(samples[i] for i in present))
    samples = [samples[i] for i in present]
    sample_data = [sample_data[i] for i in present]
    init = [d[chrm][0] for d in sample_data]
    post = [d[chrm][1] for d in sample_data]
    good = [~d[chrm][3] for d in sample_data]
    locs = [d[chrm][4] for d in sample_data]
    coverage = [d[chrm][5] for d in sample_data]

    n_samples = len(init)
    for first in range(0, n_samples, 4):
        PL.figure(figsize=(15,10))
        for k, j in enumerate(range(first, min(first + 4, n_samples))):
            # panels are numbered 1..4; the old (j+1)%4 asked for subplot 0 on the fourth one
            PL.subplot(4,1,k+1)
            I = SP.where((coverage[j] < 2.5*SP.median(coverage[j])) & good[j])[0]
            if plot_coverage: plot_chrm_coverage(locs[j][I], coverage[j][I], cols[j])
            else: plot_chrm_af(locs[j][I], init[j][I], post[j][I], cols[j])
            PL.title("%s %d"%(samples[j], coverage[j].mean()))
        # title every figure, not only the last one created
        PL.suptitle(get_seq_chr_names()[chrm])


def plot_eight(chrms, sample_data, sample):
    """ Plot the allele frequencies of one sample for up to eight chromosomes as stacked panels in a new figure.
    A panel is created for every entry of chrms, but stays empty when the sample has no data for it. The
    figure title shows the sample name and the median of entry [5] of the second-to-last chromosome in
    chrms (which must therefore be present in sample_data).
    @param chrms list of chromosome keys, at most eight
    @param sample_data dict from chromosome key to positional data ([0] raw, [1] smoothed, [4] coordinates)
    @param sample sample name for the figure title """
    PL.figure(figsize=(15,10))
    for panel, chrm in enumerate(chrms):
        PL.subplot(8,1,panel+1)
        if chrm in sample_data:
            entry = sample_data[chrm]
            plot_chrm_af(entry[4], entry[0], entry[1], "b")
            PL.title("%s"%(get_seq_chr_names()[chrm]))
    reference = sample_data[chrms[-2]]
    PL.suptitle("%s %d"%(sample, SP.median(reference[5])))


def plot_chrs():
    """ Load the allele frequency pickles of a hand-picked set of sample directories and plot 'ref|NC_001146|'.
    Files are searched under DATA_DIR/seq/<sample dir>/3_af/*_af.pickle, loaded with cl() and handed to
    plot_four together with the sample directory names. The list of files is printed. Nothing is plotted
    when no file matches. """
    # Sample directory patterns this function has been run with. Only the last set is active;
    # the earlier ones were previously globbed and immediately discarded on every call.
    sample_dir_sets = [
        ["Sa*A14", "Sa*Plate_F1", "Sample_A2", "Sa*SSA4"],
        ["*Plate_B2", "*Plate_E2", "*Plate_A2", "*Plate_F1"],
        ["Sample2*Plate_A1", "Sample2*Plate_C1", "Sample2*Plate_D1", "Sample2*Plate_F1"],
        ["Sample2*Plate_A1", "Sample2*Plate_B1", "Sample2*Plate_D1", "Sample2*Plate_A2"],
        ["Sample2*Plate_H1", "Sample_*Plate_H1", "Sample2*Plate_D2", "Sample_*Plate_D2"],
        ["Sample2*Plate_E1", "Sample2*A18", "Sample2*Plate_C1", "Sample2_*A24"],
        ["Sample2*_F1", "Sample2_A14", "Sample4_2", "Sample4_14"],
    ]
    files = []
    for sample_dir in sample_dir_sets[-1]:
        files = files + glob.glob("%s/seq/%s/3_af/*_af.pickle"%(DATA_DIR, sample_dir))
    samples = [f.split("/")[-3] for f in files]
    all_data = [cl(f) for f in files]
    print(files)
    if not all_data:
        return  # nothing matched; all_data[0] below would raise IndexError

    chrm = "ref|NC_001146|"
    if chrm in all_data[0]:
        plot_four(chrm, all_data, samples, plot_coverage=False)

    PL.show()


def plot_chr4s():
    """ Plot the chromosome named "XV" for every Sample2* allele frequency pickle, four samples per figure.
    Files are searched under DATA_DIR/seq/Sample2*/3_af/*.pickle and loaded with cl(); the chromosome keys
    are taken from the first file. PL.show() is called after every group of four. """
    files = glob.glob("%s/seq/Sample2*/3_af/*.pickle"%DATA_DIR)
    samples = [f.split("/")[-3] for f in files]
    all_data = [cl(f) for f in files]

    for chrm in all_data[0].keys():
        if get_seq_chr_names()[chrm] == "XV":
            # consecutive groups of four samples; the last group may be shorter
            for start in range(0, len(samples), 4):
                end = start + 4
                plot_four(chrm, all_data[start:end], samples[start:end], plot_coverage=False)
                PL.show()


def plot_8chrs():
    """ Show all sixteen chromosomes of selected samples as two figures of eight panels each.
    Files are searched under DATA_DIR/seq/Sample?_*/3_af/*3.pickle (sorted) and loaded with cl(). The
    chromosome keys are those of the first file whose name is one of the roman numerals I..XVI, sorted by
    key. Only samples listed in `wanted` are plotted. """
    files = sorted(glob.glob("%s/seq/Sample?_*/3_af/*3.pickle"%DATA_DIR))
    samples = [f.split("/")[-3] for f in files]
    all_data = [cl(f) for f in files]
    valid_chrm = "I II III IV V VI VII VIII IX X XI XII XIII XIV XV XVI".split()
    samplechrms = sorted([k for k in all_data[0].keys() if get_seq_chr_names()[k] in valid_chrm])

    # Earlier runs also used Sample4_15, Sample4_28, Sample4_40, Sample2_B24p, Sample2_B24m,
    # Sample4_59 and Sample4_71.
    wanted = ["Sample4_48"]
    for name, data in zip(samples, all_data):
        if name in wanted:
            for chunk in (samplechrms[0:8], samplechrms[8:16]):
                plot_eight(chunk, data, name)
                PL.show()


def plot_eqtl_pqtl_comparison(nearby=0, debug=False, plot_guides=False):
    """ Scatter eQTL against pQTL effect sizes on the current axes, save the figure as eqtl_pqtl.svg and show it.
    The array returned by get_eqtl_pqtl_overlap is used column-wise: column 0 is the x value; rows whose
    column 2 is NaN are drawn in red with column 1 as y, all other rows in blue with column 2 as y.
    Values beyond the axis limits (|x| > 1.0, |y| > 0.6) are pulled in to 0.05 inside the limit so that
    they stay visible at the border. NOTE: this clipping modifies the returned array in place.
    @param nearby, debug passed through to get_eqtl_pqtl_overlap
    @param plot_guides if true, draw dashed horizontal guides at y = +0.1 and -0.1 """
    vals = get_eqtl_pqtl_overlap(nearby=nearby, debug=debug)
    missing = SP.isnan(vals[:,2])
    I1 = SP.where(~missing)[0]
    I0 = SP.where(missing)[0]
    x, y, xl, yl = 1.5, 0.1, 1.0, 0.6
    for col, limit in ((0, xl), (1, yl), (2, yl)):
        outside = SP.where(abs(vals[:,col]) > limit)[0]
        vals[outside, col] = SP.sign(vals[outside, col])*(limit - 0.05)
    PL.plot(vals[I0,0], vals[I0,1], "r.", alpha=0.4, markersize=16)
    PL.plot(vals[I1,0], vals[I1,2], "b.", alpha=0.4, markersize=16)
    if plot_guides:
        PL.plot([-x,x],[y,y], 'k--', alpha=0.5)
        PL.plot([-x,x],[-y,-y], 'k--', alpha=0.5)
    # axes through the origin
    PL.plot([-x,x],[0,0], 'k-', alpha=0.5)
    PL.plot([0,0],[-yl,yl], 'k-', alpha=0.5)
    PL.xlim(-xl,xl); PL.ylim(-yl,yl)
    PL.xlabel("eQTL effect size")
    PL.ylabel("pQTL effect size")
    PL.savefig("eqtl_pqtl.svg")
    PL.show()


    
def plot_pqtl_eqtl_comparison(p_cutoff=0.1, val_cutoff=0.2, plot_thresholded=False):
    """ Scatter column 0 against column 1 of get_pqtl_eqtl_signal(by=False) on the current axes.
    Rows containing NaN or with column 2 >= p_cutoff are dropped. Rows with |column 0| >= val_cutoff are
    drawn in red, the rest in faint blue. The x label reports Pearson and Spearman correlations, each as
    "large-effect rows / all rows". Solid and dashed horizontal guides are drawn at 0 and +-0.1.
    @param p_cutoff upper bound (exclusive) on column 2
    @param val_cutoff threshold on |column 0| separating red from blue points (was hard-coded to 0.2,
           which is still the default)
    @param plot_thresholded currently unused """
    vals = get_pqtl_eqtl_signal(by=False)
    vals = vals[SP.where((~SP.isnan(vals)).all(axis=1) & (vals[:,2] < p_cutoff))[0]]
    large = abs(vals[:,0]) >= val_cutoff
    Ilarge = SP.where(large)[0]
    Ismall = SP.where(~large)[0]
    xspan = [vals[:,0].min()-0.1, vals[:,0].max()+0.1]
    PL.plot(xspan, [0,0], 'k-', alpha=0.5)
    PL.plot(xspan, [0.1,0.1], 'k--', alpha=0.5)
    PL.plot(xspan, [-0.1,-0.1], 'k--', alpha=0.5)
    PL.plot(vals[Ismall,0], vals[Ismall,1], "b.", markersize=15, alpha=0.1)
    PL.plot(vals[Ilarge,0], vals[Ilarge,1], "r.", markersize=15, alpha=0.2)
    PL.xlabel("pQTL effect size, r=%.2f/%.2f, rho=%.2f/%.2f"%(ST.pearsonr(vals[Ilarge,0], vals[Ilarge,1])[0], ST.pearsonr(vals[:,0], vals[:,1])[0], ST.spearmanr(vals[Ilarge,0], vals[Ilarge,1])[0], ST.spearmanr(vals[:,0], vals[:,1])[0] ))
    PL.ylabel("eQTL allele effect on mRNA level")


def get_tecan_protein_signals(parental="BY-HAP1"):
    """ Read DATA_DIR/summary/tecan/paper_by_screens.tab and return, per row, mean(parental) - mean(BY).
    The file is tab-delimited with a header row; the columns are located by the header texts
    "mean(<parental>)" and "mean(BY)", and the first column of each data row is used as the key.
    @param parental strain name inserted into the header text that is looked up
    @return dict from first-column value to the difference of the two columns as float """
    table = SP.loadtxt("%s/summary/tecan/paper_by_screens.tab"%DATA_DIR, delimiter="\t", dtype=object, comments='?')
    header, rows = table[0], table[1:]
    i_parental = SP.where(header == "mean(%s)"%parental)[0][0]
    i_by = SP.where(header == "mean(BY)")[0][0]
    res = {}
    for row in rows:
        res[row[0]] = float(row[i_parental]) - float(row[i_by])
    return res


def get_all_cyto_signals():
    """ Return, for every ORF measured in both parents, the RM minus BY value from the annotated haploid screen.
    Only rows of paper_haploid_annotated.tab whose meta column 1 is "haploid-R3" and meta column 7 is "OK"
    are used. The ORF of a row is looked up in get_array_data() by plate number (meta column 2 without its
    first character), row letter and column number of the well (meta column 3); the parent is meta column 5
    and the value is column 1 of the measurement table. A debug line is printed for ORF YER053C-A.
    @return dict from ORF to vals["RM"][orf] - vals["BY"][orf] """
    meta, traits, parentals, orfs, plates, header = read_cyto_tab("%s/summary/cyto/Pilot_screen_BYxRM/paper_haploid_annotated.tab"%DATA_DIR, normalise=False, return_header=True, includes_all=False, includes_flag=True)
    array_data = get_array_data(skip_nan=False)
    vals = {"BY":{}, "RM":{}}
    for i in range(len(traits)):
        if (meta[i,1] != "haploid-R3") or (meta[i,7] != "OK"):
            continue
        plate, parent, well = int(meta[i,2][1:]), meta[i,5], meta[i,3]
        o = array_data[plate,ord(well[0])-ord('A')+1, int(well[1:]),0]
        vals[parent][o] = traits[i,1]
        if o == "YER053C-A":
            debug_fields = (plate, parent, well, plate, ord(well[0])-ord('A')+1, int(well[1:]), o, vals[parent][o], meta[i], traits[i])
            print(" ".join(str(field) for field in debug_fields))
    result = {}
    for o, by_value in vals["BY"].items():
        if o in vals["RM"]:
            result[o] = vals["RM"][o] - by_value
    return result




def get_protein_signals(parental="BY-HAP1", setname="targets", replicate_threshold=0.3):
    """ Return per-ORF protein level differences between a parental strain and BY, averaged over two plates.
    Values come from calculate_parental_counts_across_plates (trait 4, at least 200 cells) on the
    normalised table DATA_DIR/summary/cyto/Pilot_screen_BYxRM/<setname>.tab. An ORF is discarded when, for
    any parental, its between-plate difference deviates from the median between-plate difference by more
    than replicate_threshold, or when its parental-minus-BY difference is NaN on either plate.
    Exactly two plates are expected (the sorted plate list is unpacked into two names).
    @param parental name of the strain compared against "BY"
    @param setname base name of the .tab file to read
    @param replicate_threshold largest tolerated deviation of the between-plate difference from its median
    @return dict from ORF to the mean over both plates of (parental - BY) """
    meta, traits, parentals, orfs, plates = read_cyto_tab("%s/summary/cyto/Pilot_screen_BYxRM/%s.tab"%(DATA_DIR, setname), normalise=True)
    orfs = sorted(orfs)
    parent_orf_count, parent_orf_vals, orf_parent_count, plate_orf_count = calculate_parental_counts_across_plates(parentals, orfs, plates, meta, traits, min_cells=200, trait=4)
    sorted_plates = sort(plates)
    # per parental and plate: values of all ORFs, in sorted ORF order
    parent_plate_vals = {}
    for p in parentals:
        parent_plate_vals[p] = {}
        for pl, plate in enumerate(sorted_plates):
            parent_plate_vals[p][plate] = SP.array([parent_orf_vals[p][orf][pl] for orf in orfs])

    p1, p2 = sorted_plates
    all_plates_ok = dict((o, True) for o in orfs)
    for p in parentals:
        diff = parent_plate_vals[p][p1] - parent_plate_vals[p][p2]
        # the median is computed once per parental instead of once per ORF
        center = SP.median(diff)
        badorf = SP.array([abs(v - center) > replicate_threshold for v in diff], bool)
        for o in SP.where(badorf)[0]: all_plates_ok[orfs[o]] = False

    array_orf_ok = SP.array([all_plates_ok[orf] for orf in orfs], bool)
    # use the same sorted plate names as the dictionary keys above
    diffs = SP.array([parent_plate_vals[parental][plate] - parent_plate_vals["BY"][plate] for plate in (p1, p2)])
    Igood = SP.where((~SP.isnan(diffs).any(axis=0)) & array_orf_ok)[0]
    result = {}
    for o in Igood: result[orfs[o]] = 0.5*(diffs[0][o] + diffs[1][o])
    return result


def plot_targetprotein_eqtl_comparison(loci=[(12,680000)], parental="BY-HAP1", debug=False, multipeak=False, tecan=False):
    """ Scatter mRNA against protein effect sizes for all ORFs that have both, in a new figure saved as <parental>.svg.
    Only the first entry of loci is used. A debug line is printed for ORF YER053C-A when parental is "BY-HAP1".
    @param loci list of (chromosome number, position) pairs; the default list is never modified
    @param parental strain name passed to the protein signal reader; also used in title and file name
    @param debug currently unused
    @param multipeak if true, the mRNA signal of every ORF in the expression map is replaced by the result
           of get_mrna_effect_signal_multipeak at the same locus
    @param tecan if true, protein signals come from get_tecan_protein_signals, otherwise from get_protein_signals """
    ROMAN_NAMES = "0 I II III IV V VI VII VIII IX X XI XII XIII XIV XV XVI".split()
    chrom, position = loci[0]

    mrna_signals = get_all_mrna_effect_signal(chrom, position, mrna_data=None)
    if multipeak:
        snps, snploc, expr_map = read_mrna_data()
        for o in expr_map:
            mrna_signals[o] = get_mrna_effect_signal_multipeak(o, [(ROMAN_NAMES[chrom], position)], mrna_data=(snps,snploc,expr_map))
    # read only the protein data set that is plotted; the cyto signals used to be
    # computed and then thrown away whenever tecan was requested
    if tecan: protein_signals = get_tecan_protein_signals(parental)
    else: protein_signals = get_protein_signals(parental)
    PL.figure(figsize=(6,5))
    vals = []
    for o in protein_signals:
        if o in mrna_signals:
            if o == "YER053C-A" and parental == "BY-HAP1":
                print(" ".join(str(v) for v in (o, mrna_signals[o], protein_signals[o], multipeak, tecan)))
            vals.append([mrna_signals[o], protein_signals[o]])
    vals = SP.array(vals)
    # an empty array has no columns to slice, so skip the scatter when no ORF is shared
    if len(vals) > 0:
        PL.plot(vals[:,0], vals[:,1], "b.", alpha=0.4, markersize=17)
    PL.xlabel("mRNA effect size")
    PL.ylabel("protein effect size")
    PL.title("%d %d - %s"%(chrom, position, parental))
    PL.savefig("%s.svg"%parental)



def tmp_plot_targetprotein_comparison(loci=[(12,680000)], parental="BY-HAP1", debug=False):
    """ Scatter cyto against tecan protein effect sizes for all ORFs present in both, in a new figure.
    Prints the two values of ORF YER053C-A first (which must exist in both data sets). The figure is saved
    as <parental>.svg - the same file name plot_targetprotein_eqtl_comparison writes to.
    @param loci only the first (chromosome number, position) pair is used, and only in the title
    @param parental strain name passed to both signal readers
    @param debug currently unused """
    cyto = get_protein_signals(parental)
    tecan = get_tecan_protein_signals(parental)
    probe = "YER053C-A"
    print(" ".join(str(v) for v in (probe, cyto[probe], tecan[probe])))
    PL.figure(figsize=(6,5))
    pairs = SP.array([[cyto[o], tecan[o]] for o in cyto if o in tecan])
    PL.plot(pairs[:,0], pairs[:,1], "b.", alpha=0.4, markersize=17)
    PL.xlabel("cyto effect size")
    PL.ylabel("tecan effect size")
    PL.title("%d %d - %s"%(loci[0][0], loci[0][1], parental))
    PL.savefig("%s.svg"%parental)



def _figure_pqtl_eqtl():
    """ Small square figure of the pQTL/eQTL signal comparison. """
    PL.figure(None, [3.5,3.5])
    plot_pqtl_eqtl_comparison(p_cutoff=2, val_cutoff=0.2)
    PL.show()


def _figure_density():
    """ Genome-wide QTL density with signal cutoff 0.15. """
    plot_density(0.15)
    PL.show()


def _figure_eqtl_pqtl():
    """ eQTL/pQTL effect size scatter for exactly overlapping loci. """
    plot_eqtl_pqtl_comparison(nearby=0)


def _figure_density_smooth():
    """ Genome-wide QTL density with signal cutoff 0.15 and smoother 30000 (drawn but not shown). """
    plot_density(0.15, smoother=30000)


def _figure_paper_qtls():
    """ The two single-QTL paper figures with enlarged fonts. """
    # NOTE(review): 'text.fontsize' and 'text.size' may not be accepted by the installed
    # matplotlib version - confirm before running.
    paper_params = {'text.fontsize':40, 'xtick.labelsize':24, 'ytick.labelsize':24, 'text.size':40, 'axes.titlesize':30, 'axes.labelsize':26, 'figure.figsize':(16,10), 'legend.fontsize':24}
    PL.rcParams.update(paper_params)
    plot_paper_qtl1()
    plot_paper_qtl2()


def _figure_target_proteins():
    """ mRNA/protein and cyto/tecan comparisons at the two target loci. """
    targets = [(12,660000, "BY-HAP1"), (15,172000, "BY-MKT1-IRA2")] # (14,461000, "BY-MKT1-IRA2") was also used
    for tecan in [True, False]:
        for chrom, loc, strain in targets:
            plot_targetprotein_eqtl_comparison([(chrom,loc)], strain, multipeak=False, tecan=tecan)
    PL.show()
    for chrom, loc, strain in targets:
        tmp_plot_targetprotein_comparison([(chrom,loc)], strain)
    PL.show()


# Figure recipes that used to sit unreachable behind early returns in main().
_FIGURES = {
    "chr11_qtl": plot_chr11_qtl,
    "pqtl_eqtl": _figure_pqtl_eqtl,
    "density": _figure_density,
    "eqtl_pqtl": _figure_eqtl_pqtl,
    "density_smooth": _figure_density_smooth,
    "paper_qtls": _figure_paper_qtls,
    "target_proteins": _figure_target_proteins,
}


def main(argv=None):
    """ Produce the figures named in argv, in order; with no names, produce the chromosome XI QTL figure.
    @param argv list of figure names (keys of _FIGURES); None or empty selects "chr11_qtl", which is what
           main() has always run
    Exits with an error message listing the valid names when an unknown name is given. """
    names = list(argv) if argv else ["chr11_qtl"]
    unknown = sorted(set(names) - set(_FIGURES))
    if unknown:
        raise SystemExit("unknown figure(s): %s; choose from: %s"%(", ".join(unknown), ", ".join(sorted(_FIGURES))))
    for name in names:
        _FIGURES[name]()

if __name__ == '__main__':
    main(sys.argv[1:])
