import pylab as PL
import scipy as SP
import scipy.stats as ST
from parts2014_gfpvar.analysis.tecan.util import *
from parts2014_gfpvar.analysis.cyto.util import align_vals
from parts2014_gfpvar.analysis.qtl.util import get_mrna_signals

def plot_separate_scatter(parent1=None, parent2=None, set1=None, set2=None, reshape=True, align=True, threshold=0.25*SP.log2(10)):
    """Scatter set1 against set2 values, one subplot per plate.

    Per-plate counterpart of ``plot_scatter``. For every plate key in
    ``set1``, ``set1[plate][parent1]`` (x axis) is plotted against
    ``set2[plate][parent2]`` (y axis). Each panel title reports the Pearson r
    over wells where neither value is NaN.

    Args:
        parent1: key selecting the series inside each ``set1[plate]``.
        parent2: key selecting the series inside each ``set2[plate]``.
        set1: mapping plate -> mapping parent -> per-well value array.
        set2: same structure as ``set1``; must contain every plate of ``set1``.
        reshape: if true, reshape both arrays to ``n_wells`` values.
        align: if true, replace the set2 values with
            ``align_vals(v2, v1, good_threshold=threshold)``.
        threshold: ``good_threshold`` forwarded to ``align_vals``.

    Raises:
        ValueError: if parent1, parent2, set1 or set2 is not supplied.
    """
    # These used to be free names with no visible definition in this module,
    # so the old zero-argument version could not run; they are now explicit.
    if parent1 is None or parent2 is None or set1 is None or set2 is None:
        raise ValueError("plot_separate_scatter needs parent1, parent2, set1 and set2")
    n_cols = 3
    # Keep the original 4x3 grid, but add rows instead of overflowing past 12 plates.
    n_rows = max(4, (len(set1) + n_cols - 1) // n_cols)
    PL.figure(figsize=(14, 20))
    for pi, plate in enumerate(set1, 1):
        v1, v2 = set1[plate][parent1], set2[plate][parent2]
        if reshape:
            v1, v2 = v1.reshape(n_wells), v2.reshape(n_wells)
        if align:
            v2 = align_vals(v2, v1, good_threshold=threshold)
        PL.subplot(n_rows, n_cols, pi)
        PL.plot(v1, v2, ".", markersize=14, alpha=0.1)
        PL.xlabel("Set1 (%s)" % parent1)
        PL.ylabel("Set2 (%s)" % parent2)
        # Correlate only wells measured in both sets.
        I = SP.where((~SP.isnan(v1)) & (~SP.isnan(v2)))[0]
        r = SP.corrcoef(v1[I], v2[I])[0, 1]
        PL.title("%s/%s plate %d, r=%.2f" % (parent1, parent2, plate, r))
    PL.show()

    
def plot_scatter(parent1, parent2, set1, set2, reshape=True, align=True, threshold=0.25*SP.log2(10), lim=None, hist=False, histlim=(-0.5,0.5), histlog=False, highlight=None):
    """Scatter pooled set1 values against set2 values across all plates.

    The values ``set1[plate][parent1]`` (x) and ``set2[plate][parent2]`` (y)
    of every plate key in ``set1`` are concatenated into two flat arrays. They
    are drawn in a single panel whose title is the squared Pearson
    correlation over positions where neither value is NaN.

    Args:
        parent1: key selecting the series inside each ``set1[plate]``.
        parent2: key selecting the series inside each ``set2[plate]``.
        set1, set2: mapping plate -> mapping parent -> per-well value array.
        reshape: reshape each per-plate array to ``n_wells`` values first.
        align: replace set2 values with
            ``align_vals(v2, v1, good_threshold=threshold)``.
        threshold: ``good_threshold`` forwarded to ``align_vals``.
        lim: optional axis limits applied to both x and y.
        hist: if true, add a second panel with a histogram of ``x - y``.
        histlim: x-range of the histogram panel.
        histlog: use a log-scaled count axis for the histogram.
        highlight: if given, overplot in red the points with
            ``|x - y| > highlight``.
    """
    PL.figure(figsize=(7 + 7*hist, 6))
    xs = []
    ys = []
    for plate in set1:
        a = set1[plate][parent1]
        b = set2[plate][parent2]
        if reshape:
            a, b = a.reshape(n_wells), b.reshape(n_wells)
        # x is collected before alignment, exactly as the values came in.
        xs.extend(a)
        if align:
            b = align_vals(b, a, good_threshold=threshold)
        ys.extend(b)

    x = SP.array(xs)
    y = SP.array(ys)
    n_panels = 1 + hist
    PL.subplot(1, n_panels, 1)
    PL.plot(x, y, ".", markersize=14, alpha=0.1)
    if highlight is not None:
        far = SP.where(abs(x - y) > highlight)[0]
        PL.plot(x[far], y[far], "r.", markersize=2)
    PL.xlabel("Set1 (%s)" % parent1)
    PL.ylabel("Set2 (%s)" % parent2)
    both_finite = SP.where(~(SP.isnan(x) | SP.isnan(y)))[0]
    r = SP.corrcoef(x[both_finite], y[both_finite])[0, 1]
    PL.title("%s/%s r^2=%.2f" % (parent1, parent2, r**2))
    if lim is not None:
        PL.xlim(lim)
        PL.ylim(lim)
    if hist:
        PL.subplot(122)
        PL.hist(x - y, bins=SP.arange(-10, 10, 0.025), alpha=0.8, log=histlog)
        PL.xlim(*histlim)
        PL.xlabel("%s - %s" % (parent1, parent2))
    PL.show()


def plot_diff_scatter(nc, nr, pi, d1, d2, xlabel, ylabel, title, xlim=(-0.15,0.15), ylim=(-0.15,0.15)):
    """Draw one scatter panel of ``d2`` against ``d1``; return the next panel index.

    Args:
        nc: passed as the first argument (nrows) of ``PL.subplot``.
        nr: passed as the second argument (ncols) of ``PL.subplot``.
        pi: 1-based subplot index for this panel.
        d1, d2: equal-length arrays (x and y values); NaNs are allowed.
        xlabel, ylabel: axis labels.
        title: prefix of the panel title.
        xlim, ylim: axis limits, or None to leave autoscaled.

    Returns:
        ``pi + 1``, the index for the caller's next panel.

    The title reports Spearman rho^2 and Pearson r^2, both computed on the
    positions where neither ``d1`` nor ``d2`` is NaN.
    """
    PL.subplot(nc, nr, pi)
    PL.plot(d1, d2, ".", markersize=14, alpha=0.1)
    PL.xlabel(xlabel)
    PL.ylabel(ylabel)
    if xlim is not None: PL.xlim(xlim)
    if ylim is not None: PL.ylim(ylim)
    I = SP.where((~SP.isnan(d1)) & (~SP.isnan(d2)))[0]
    # Use the same finite subset for both statistics. spearmanr propagates NaN
    # by default, so on the raw arrays rho^2 would be NaN whenever any value is.
    if len(I) > 1:
        rho = ST.spearmanr(d1[I], d2[I])[0]
        r = SP.corrcoef(d1[I], d2[I])[0, 1]
    else:
        rho = r = float("nan")
    PL.title("%s, rho^2=%.2f, r^2=%.2f" % (title, rho**2, r**2))
    return pi + 1

    
def plot_contributions(vs, tests, alt=None):
    """Draw one row of difference scatter panels per plate.

    For each plate, and with ``row = vs[plate]``, the row shows:

    * every pair ``1 <= i < j`` of ``row[i] - row[0]`` against
      ``row[j] - row[0]``;
    * the observed ``row[3] - row[0]`` against the additive prediction
      ``row[1] + row[2] - 2*row[0]``;
    * if ``alt`` is given, ``row[3] - row[0]`` against ``alt[plate] - row[0]``.

    Args:
        vs: mapping plate -> sequence of value arrays (index 0 is the
            reference). Plate keys are formatted with ``%d``.
        tests: axis labels, parallel to each ``vs[plate]``.
        alt: optional mapping plate -> value array.

    The column count (4, or 5 with ``alt``) assumes ``len(vs[plate]) == 4``,
    i.e. three pairs plus the prediction panel.
    """
    has_alt = alt is not None
    PL.figure(figsize=(17 + 4*has_alt, 40))
    # One grid row per plate; keep the original 11 rows as a minimum so the
    # layout is unchanged for small inputs, but never run out of subplots.
    nc, nr, pi = max(11, len(vs)), 4 + has_alt, 1

    for plate in vs:
        row = vs[plate]
        label = "Plate %d" % plate
        for i in range(1, len(row) - 1):
            for j in range(i + 1, len(row)):
                # Scatter plot of differences from the reference for the two tests.
                d1, d2 = row[i] - row[0], row[j] - row[0]
                pi = plot_diff_scatter(nc, nr, pi, d1, d2, tests[i], tests[j], label)

        # Observation vs. additive prediction from the two singles.
        d1, d2 = row[3] - row[0], row[1] + row[2] - 2*row[0]
        pi = plot_diff_scatter(nc, nr, pi, d1, d2, tests[3], "Predicted", label)
        if has_alt:
            # Add-on: observed effect vs. the alternative (noise) measurement.
            d1, d2 = row[3] - row[0], alt[plate] - row[0]
            pi = plot_diff_scatter(nc, nr, pi, d1, d2, tests[3], "Experimental noise", label)
    PL.show()


    
def plot_combined_contributions(alls, tests, plot_raw=False, alls_alt=None, rep_data=None):
    """Draw pooled scatter panels for the allelic series in ``alls``.

    Row 1 compares differences from the reference ``alls[0]``. Row 2, drawn
    only when ``plot_raw`` is true, compares raw values against ``alls[3]``
    without fixed axis limits. The figure is shown in both cases.

    Args:
        alls: sequence of value arrays. Index 0 is the reference, and index 3
            is compared with the additive prediction built from indices 1 and 2.
        tests: axis labels, parallel to ``alls``.
        plot_raw: also draw the second row of raw-value panels.
        alls_alt: optional alternative series; only ``alls_alt[3]`` is used.
        rep_data: optional pair of replicate arrays. Row 1 uses
            ``rep_data[0] - rep_data[1]`` and row 2 uses ``rep_data[0]``.

    The column count ``4 + ...`` assumes ``len(alls) == 4``: three pairwise
    panels plus the prediction panel per row.
    """
    nc = 1 + plot_raw  # number of grid rows (nrows for PL.subplot)
    nr = 4 + (alls_alt is not None) + (rep_data is not None)  # number of columns
    PL.figure(figsize=(5*nr, 5*nc)); pi = 1

    # First row - differences from alls[0]
    # Alternative null
    if alls_alt is not None:
        pi = plot_diff_scatter(nc, nr, pi, alls[3] - alls[0], alls_alt[3] - alls[0], tests[3] + " (alt)", "Predicted (alt)", "")
    # The allelic series (two singles and double)
    for i in range(1, len(alls) - 1):
        for j in range(i + 1, len(alls)):
            pi = plot_diff_scatter(nc, nr, pi, alls[i] - alls[0], alls[j] - alls[0], tests[i], tests[j], "")
    # Prediction: observed double vs. sum of the two single effects
    pi = plot_diff_scatter(nc, nr, pi, alls[3] - alls[0], alls[1] - alls[0] + alls[2] - alls[0], tests[3], "Predicted", "")
    # Replication of raw data (to contrast to prediction)
    if rep_data is not None:
        pi = plot_diff_scatter(nc, nr, pi, alls[3] - alls[0], rep_data[0] - rep_data[1], tests[3], tests[3] + " (R2)", "")

    # Second row - raw values (axis limits disabled with None, None)
    if plot_raw:
        # Alternative null
        if alls_alt is not None:
            pi = plot_diff_scatter(nc, nr, pi, alls[3], alls_alt[3], tests[3] + " (alt)", "Predicted (alt)", "", None, None)
        # The allelic series (two singles and double)
        for i in range(3):
            pi = plot_diff_scatter(nc, nr, pi, alls[3], alls[i], tests[3], tests[i], "", None, None)
        # Prediction
        pi = plot_diff_scatter(nc, nr, pi, alls[3], alls[1] + alls[2] - alls[0], tests[3], "Predicted", "", None, None)
        # Replication of raw data (to contrast to prediction)
        if rep_data is not None:
            pi = plot_diff_scatter(nc, nr, pi, alls[3], rep_data[0], tests[3], tests[3] + " (R2)", "", None, None)

    # Show in both modes; an early return used to skip this when plot_raw was False.
    PL.show()



def plot_eqtl_histogram(orfset, plate_orfs, v1, v0, title, thresholds=(0.1, 0.18)):
    """Compare the ``v0 - v1`` distribution of genes in ``orfset`` vs. the rest.

    For each threshold, prints how many genes in each group have
    ``|v0 - v1| > t``, the fraction of each group, and the fold enrichment of
    the in-set fraction over the out-of-set fraction. It then overlays both
    groups' normalised, log-scaled histograms (red = in set, blue = others).

    Args:
        orfset: container of ORF identifiers, tested with ``in``.
        plate_orfs: indexable sequence; ``plate_orfs[k]`` is the ORF at
            position ``k`` of ``v0``/``v1`` (presumably; confirm with callers).
        v1, v0: numeric arrays indexed by position.
        title: prefix for the printed lines and the plot title.
        thresholds: effect-size cutoffs to report (default matches the
            previously hard-coded values).

    Undefined fractions (an empty group) and undefined enrichments print as
    ``nan``. An enrichment with no out-of-set hits but some in-set hits prints
    as ``inf``. The old code raised ZeroDivisionError or warned in these cases.
    """
    I = SP.array([plate_orfs[x] in orfset for x in range(len(plate_orfs))], bool)
    I1 = SP.where(I)[0]
    I0 = SP.where(~I)[0]
    d1 = v0[I1] - v1[I1]  # v0 - v1 for genes in orfset
    d0 = v0[I0] - v1[I0]  # v0 - v1 for all other genes
    n1, n0 = len(I1), len(I0)
    nan = float("nan")
    for t in thresholds:
        h1, h0 = sum(abs(d1) > t), sum(abs(d0) > t)
        f1 = 1.*h1/n1 if n1 else nan
        f0 = 1.*h0/n0 if n0 else nan
        if f0:
            fold = f1/f0
        else:
            fold = float("inf") if f1 > 0 else nan
        print("%s - Over %.2f: + %d/%d (%.3f)  - %d/%d (%.3f). %.1f-fold enriched" % (title, t, h1, n1, f1, h0, n0, f0, fold))
    b = SP.arange(-0.5, 0.5, 0.025)
    # NOTE(review): `normed` was removed in matplotlib 3.1 (use density=True there);
    # it is kept here for the older matplotlib this Python-2-era module targets.
    PL.hist(d1, color="r", bins=b, alpha=0.5, normed=True, log=True)
    PL.hist(d0, color="b", bins=b, alpha=0.5, normed=True, log=True)
    PL.xlim(-0.3, 0.3)
    PL.xlabel("Tecan effect size")
    PL.ylabel("Gene frequency (log10 scale)")
    PL.title(title)



def plot_mrna_signal(peaks, tecan_delta, andor=0, mrna_data=None, plate_orfs=None, peak_orfs=None, correct_qtls=True, xlim=0.3, ylim=None, screen_name="Tecan", title=False):
    """Scatter per-gene screen effects against mRNA signals for a set of peaks.

    Genes whose ORF is in ``peak_orfs[peak]`` for any ``peak`` in ``peaks``
    are drawn in red and all other genes in blue. Only positions where both
    values are finite are plotted. The Spearman rho (plus the count and
    Pearson r when ``title`` is true) is computed on the red genes before any
    clipping.

    Args:
        peaks: peak identifiers. They are joined with spaces in the title, so
            they are presumably strings; they are iterated twice, so pass a
            sequence rather than a generator.
        tecan_delta: per-gene effect array aligned with ``plate_orfs``. It is
            no longer modified.
        andor, mrna_data, correct_qtls: forwarded to ``get_mrna_signals``.
        plate_orfs: sequence; ``plate_orfs[k]`` is the ORF at position ``k``.
        peak_orfs: mapping peak -> set of ORFs.
        xlim, ylim: if not None, clip x / y values to ``[-lim, lim]`` for
            display and set the axis range accordingly.
        screen_name: used in the x-axis label.
        title: if true, set a title with statistics; otherwise write rho in
            the upper-left corner of the axes.
    """
    from functools import reduce
    import operator
    # Work on float copies: the clipping below assigns into these arrays and
    # used to overwrite the caller's tecan_delta in place.
    d1 = SP.array(tecan_delta, dtype=float)
    d2 = SP.array(get_mrna_signals(peaks, mrna_data, andor, plate_orfs, correct_qtls=correct_qtls), dtype=float)
    # Union of the ORF sets of all peaks; the empty-set seed makes empty `peaks` valid.
    orfset = reduce(operator.or_, [peak_orfs[peak] for peak in peaks], set())
    I = SP.array([plate_orfs[x] in orfset for x in range(len(plate_orfs))], bool)
    finite = ~SP.isnan(d1 + d2)
    I1 = SP.where(I & finite)[0]
    I0 = SP.where(~I & finite)[0]
    if title: PL.title("%s - %d total, r=%.2f rho=%.2f"%(" ".join(peaks), len(I1), SP.corrcoef(d1[I1],d2[I1])[0,1], ST.spearmanr(d1[I1],d2[I1])[0]))
    else: PL.text(0.05, 0.9, "$\\rho=%.2f$"%(ST.spearmanr(d1[I1],d2[I1])[0]), transform=PL.gca().transAxes)
    if xlim is not None:
        # Clip for display so outliers sit on the plot border.
        d1[SP.where(d1 < -xlim)[0]] = -xlim
        d1[SP.where(d1 > xlim)[0]] = xlim
        PL.plot([-10,10], [0,0], 'k--')
    if ylim is not None:
        d2[SP.where(d2 < -ylim)[0]] = -ylim
        d2[SP.where(d2 > ylim)[0]] = ylim
        PL.plot([0,0], [-10,10], 'k--')
    PL.plot(d1[I0], d2[I0], "b.", markersize=14, alpha=0.05)
    PL.plot(d1[I1], d2[I1], "r.", markersize=10, alpha=0.3)
    if xlim is not None: PL.xlim(-xlim-0.01, xlim+0.01)
    if ylim is not None: PL.ylim(-ylim-0.05, ylim+0.05)
    PL.xlabel("     effect on %s"%(screen_name))
    add_names = False  # manual toggle: label the axes with the peak names
    if add_names:
        PL.text(-0.1,-0.2, " ".join(peaks), style="italic")
        PL.text(-0.1,-0.2, " ".join(peaks), style="italic", rotation=90)
    PL.ylabel("     effect on mRNA level", )
 
