import pylab as PL
from matplotlib.backends.backend_pdf import PdfPages
from parts2014_gfpvar.tools.common import *
from parts2014_gfpvar.analysis.cyto.fcs import *
from quantify import *
from util import *
from cluster import read_clustered_well_data


def plot_candidate_replicas(orf, screens=("haploid-R1", "haploid-R3", "interesting_candidate", "reciprocal_candidate"), ploidies=("Haploid",), outfilename="%s/plots/candidate/test.pdf"%DATA_DIR, threshold=0.8, debug=False):
    """Create a multipage PDF for a candidate ORF, one page per technical replicate well.

    For every screen in `screens` whose results quantified `orf`, and for every
    well of every replicate whose ploidy is listed in `ploidies`, the clustered
    well data are read and a 2x2 page is drawn:
      * top-left: RFP-SSC scatter coloured by cluster,
      * top-right: GFP-FSC scatter coloured by cluster,
      * bottom-left: RFP histogram per unbudded cluster,
      * bottom-right: unbudded GFP histogram per unbudded cluster.

    Parameters
    ----------
    orf : ORF identifier to plot.
    screens : screen names passed to read_screen_results.
    ploidies : ploidy labels to keep; wells with any other ploidy are skipped.
    outfilename : path of the PDF file to write.
    threshold : cluster-assignment threshold passed to read_clustered_well_data.
    debug : if True, print per-well ploidy decisions and cluster sizes.

    The PDF is closed even if reading or plotting raises. Each page's figure is
    closed after it is saved, so figures do not pile up in memory.
    """
    pdf = PdfPages(outfilename)
    try:
        for screen in screens:
            data = read_screen_results(screen)
            if orf not in data.orf_stats:  # this screen did not quantify the ORF
                continue

            for stat in data.orf_stats[orf]:  # each replicate of the ORF (different plates in the screen)
                for w, well in enumerate(stat.wells):
                    ploidy = stat.ploidies[w]
                    if debug:
                        print("%s %s %s %s %s" % (stat.screen, stat.plate, well, ploidy, ploidy not in ploidies))
                    if ploidy not in ploidies:
                        continue
                    I, well_data, colnames = read_clustered_well_data(experiment="Pilot_screen_BYxRM", screen=screen, plate=stat.plate, well=well, threshold=threshold)
                    if debug:
                        total = well_data.shape[0]
                        for cluster in I:
                            n_cluster = sum(I[cluster])
                            print("\t%s\t%d\t%.2f" % (cluster, n_cluster, 1.*n_cluster/total))

                    fig = PL.figure()
                    PL.subplot(221)
                    cols = [1, 3]  # RFP-SSC
                    plot_indexed_scatter(well_data[:, cols], I, colnames[cols])
                    PL.subplot(222)
                    cols = [0, 2]  # GFP-FSC
                    plot_indexed_scatter(well_data[:, cols], I, colnames[cols])
                    PL.subplot(223)
                    plot_gfprfp_hist(well_data[:, 1], I, "RFP")
                    PL.subplot(224)
                    plot_gfprfp_hist(well_data[:, 0], I, "Unbudded GFP")
                    PL.suptitle("%s (%s) - Screen %s plate %s well %s" % (orf, stat.common_name, screen, stat.plate, well))
                    PL.savefig(pdf, format="pdf")
                    PL.close(fig)  # one figure per well; release it once saved
    finally:
        pdf.close()  # finalise the PDF even if an error interrupted plotting

def plot_indexed_scatter(data, I, colnames):
    """Scatter column 1 against column 0 of `data`, one colour per cluster.

    Parameters
    ----------
    data : array whose first two columns are plotted (x = column 0, y = column 1).
    I : mapping from cluster name to an index/mask into the rows of `data`;
        must contain "BY unbudded", "BY budded", "RM unbudded" and "RM budded".
    colnames : sequence whose first two entries label the x and y axes.

    A line with the cluster index, name, member count and colour is printed for
    each cluster; empty clusters are printed but not drawn.
    """
    groups = ("BY unbudded", "BY budded", "RM unbudded", "RM budded")
    palette = 'brgc'
    for idx, (name, colour) in enumerate(zip(groups, palette)):
        members = I[name]
        n_members = sum(members)
        print("%s %s %s %s" % (idx, name, n_members, colour))
        if not n_members > 0:
            continue
        PL.plot(data[members, 0], data[members, 1], colour + ".", alpha=0.02)
    PL.xlabel(colnames[0])
    PL.ylabel(colnames[1])


def plot_single_hist(filename="%s/cytometry/Pilot_screen_BYxRM/haploid-R3/P9/Specimen_001_H6_H06.fcs"%DATA_DIR, normed=True, orf="YER053C-A", legend=True, yticks=None, xlim=None):
    """Overlay GFP histograms of the two RFP-defined unbudded subpopulations of one well.

    The FCS file is read (log-transformed) and filtered. Next, the stored budding and
    RFP Gaussian-mixture clustering for the file's plate is applied. Finally, column 0
    of the data (labelled as log10 GFP on the x-axis) is histogrammed for unbudded cells
    assigned with probability > 0.8 to the first or second RFP cluster.

    Parameters
    ----------
    filename : path to the .fcs file. Its parent directory name is expected to
        look like "P<plate>" (e.g. ".../P9/..."); the part after "P" selects
        the clustering pickle for that plate.
    normed : passed to PL.hist as `normed`; also selects the y-axis label.
    orf : ORF name, used only in the x-axis label.
    legend : if True, draw a legend.
    yticks : optional y tick positions passed to PL.yticks.
    xlim : not used; the x-range is fixed to (2.5, 4.5).
    """
    cols, data = read_fcs(filename, colnames_tostore=DEFAULT_COLNAMES, log=True)
    data = filter_fcs(data, cols, DEFAULT_FILTERS)
    # NOTE(review): `cl` is defined elsewhere; presumably it loads the pickled plate
    # clustering (a dict with 'budding_GMM', 'rfp_GMMs', 'features', 'means', 'sds').
    # The screen (haploid-R3) and clustering parameters are hard-coded in the path.
    clustering = cl("%s/cytometry/Pilot_screen_BYxRM/haploid-R3/stats_fscasclicesize-0.1_rfpmodes-2_rndfiles-48_restarts-20_prior-data_points-100000_threshold-0.8/cluster-gmm_plate-P%s.pickle"%(DATA_DIR, filename.split("/")[-2][1:]))
    budding_gmm, rfp_gmms = clustering['budding_GMM'], clustering['rfp_GMMs']

    # Select the clustering features and standardise them with the stored means/sds.
    Icol = [cols.index(c) for c in clustering['features']]
    data_std = (data[:,Icol] - clustering['means'])/clustering['sds']
    # 1. cluster data, after standardising it in identical way to the learned clusters
    Icluster = [clustering['features'].index(feature) for feature in DEFAULT_BUDDING_CLUSTER_FEATURES]
    p_bud = budding_gmm.predict_proba(data_std[:,Icluster]) 
    # 2. cluster unbudded cells further based on RFP
    # Unbudded cluster: of the first two budding components, the one with the smaller mean in
    # feature column 1 (originally described as SSC-W — TODO confirm against DEFAULT_BUDDING_CLUSTER_FEATURES).
    unbudded_cluster = budding_gmm._means[0:2,1].argmin()
    p_rfp = rfp_gmms[unbudded_cluster].predict_proba(data_std[:, [clustering['features'].index('mCherry-A')]]) # cluster cells based on RFP
    # If the first two RFP components are in descending mean order, swap their posterior columns.
    # NOTE(review): only applied when the RFP mixture has more than 2 components (the swap assumes
    # exactly 3 columns); with 2 components no reordering happens — confirm this is intended.
    if (rfp_gmms[unbudded_cluster].n_components > 2) and (rfp_gmms[unbudded_cluster]._means[0] > rfp_gmms[unbudded_cluster]._means[1]): # if clusters in wrong order, reorder assignments
        p_rfp = p_rfp[:,[1,0,2]] # switch first two clusters 
    # 3. calculate well statistics based on the clusterings
    # NOTE(review): `wellstat` is not used below — TODO confirm it can be dropped.
    wellstat = cWellStat(data, p_bud[:,unbudded_cluster], p_rfp, 0.8, DEFAULT_COLNAMES, DEFAULT_FSCA_SLICE)
    # Despite the name, I_bud marks cells confidently in the *unbudded* cluster.
    I_bud = p_bud[:,unbudded_cluster] > 0.8
    Iby = SP.where(I_bud & (p_rfp[:,0] > 0.8))[0]  # unbudded, RFP cluster 0
    Irm = SP.where(I_bud & (p_rfp[:,1] > 0.8))[0]  # unbudded, RFP cluster 1
    #PL.rcParams.update(PLOT_PARAMS)
    # NOTE(review): data[Irm] is drawn first and data[Iby] second, while the legend lists
    # "BYxBY" first — confirm the label-to-cluster mapping.
    PL.hist(data[Irm, 0], bins=SP.arange(2.5,4.5, 0.1), alpha=0.5, normed=normed)
    PL.hist(data[Iby, 0], bins=SP.arange(2.5,4.5, 0.1), alpha=0.5, normed=normed)
    PL.xlim(2.5,4.5)
    if legend: PL.legend(["BYxBY", "RMxBY"], loc="upper right")
    if yticks is not None: PL.yticks(yticks)
    PL.xlabel("log10(%s-GFP level)"%orf)
    if normed: PL.ylabel("Density")
    else: PL.ylabel("Number of cells")



def plot_gfprfp_hist(data, I, feature="GFP"):
    """Overlay 40-bin histograms of `data` for the BY and RM unbudded clusters.

    Parameters
    ----------
    data : 1-D array of values, one per event.
    I : mapping from cluster name to an index/mask into `data`; must contain
        "BY unbudded" (drawn blue) and "RM unbudded" (drawn green).
    feature : x-axis label.

    A line with the cluster index, name, member count and colour is printed for
    each cluster; empty clusters are printed but not drawn.
    """
    palette = "bg"
    for idx, name in enumerate(["BY unbudded", "RM unbudded"]):
        colour = palette[idx]
        n_members = sum(I[name])
        print("%s %s %s %s" % (idx, name, n_members, colour))
        if n_members > 0:
            PL.hist(data[I[name]], color=colour, alpha=0.5, bins=40)
    PL.xlabel(feature)


def plot_bar(indices, orfs, parentals, plates, parent_orf_vals, all_plates_ok, bar_width=0.09, linear=False, colors='gmybr', label_sig=True):
    """Grouped bar plot of parental, expected and observed values per ORF, labelled by significance.

    Parameters
    ----------
    indices : positions into `orfs` of the ORFs to plot (one bar group each).
    orfs : sequence of ORF names.
    parentals : keys into `parent_orf_vals`; parental p fills value slot p.
    plates : only its length is used (number of plates/replicates per ORF).
    parent_orf_vals : parent_orf_vals[parent][orf] gives the per-plate values.
    all_plates_ok : all_plates_ok[orf] is falsy for ORFs to leave empty (NaN).
    bar_width : width of a single bar.
    linear : compute the expected value and differences on 10**values instead
        of the (presumably log10) values themselves; bars are then offset by 0.
    colors : one colour per value slot, in plotting order.
    label_sig : include the z-like score and a "*" marker in the tick labels.

    Returns
    -------
    (sigs, meandiffs): the entries of `indices` whose observed-minus-expected
    difference is significant (|diff/sd| > 1.5 and |diff| > 0.05), and the list
    of mean observed-minus-expected differences, one per entry of `indices`.
    """
    PL.figure(figsize=(17,8))
    # vals[i, r, s]: ORF i, plate r, value slot s. Slots 0..len(parentals)-1 hold parental
    # values and slot 4 the expected value. Per the bar-order comment below, slot 3 is the
    # observed value (so four parentals are presumably expected).
    vals = SP.zeros([len(indices),len(plates),5])*SP.nan
    for i, idx in enumerate(indices): # for each ORF
        orf = orfs[idx]
        if not all_plates_ok[orf]: continue  # leave this ORF's row as NaN
        for p, parent in enumerate(parentals):
            vals[i,:,p] = parent_orf_vals[parent][orf]
        # expected value: slots 0-2 summed, minus twice slot 0 (i.e. s1 + s2 - s0)
        vals[i,:,4] = vals[i,:,0:3].sum(axis=1) - 2*(vals[i,:,0])
        if linear:
            # same combination on 10**vals, mapped back through a sign-preserving log10
            ev = (10**vals[i,:,0:3]).sum(axis=1) - 2*(10**vals[i,:,0])
            for e in range(len(ev)):
                if ev[e] < 0: ev[e] = -SP.log10(-ev[e])
                else: ev[e] = SP.log10(ev[e]+1)
            vals[i,:,4] = ev
    minval = 0 if linear else SP.nanmin(vals)  # bars are drawn relative to this baseline

    ticks = SP.arange(vals.shape[0])
    for r in range(vals.shape[1]):
        # plot slots 0, 1, 2, then expected (4) before observed (3)
        for b, slot in enumerate([0,1,2,4,3]):
            PL.bar(ticks + bar_width*2*b + r*bar_width, vals[:,r,slot] - minval, width=bar_width, color=colors[b], alpha=0.6)

    # annotate significance of observed - expected
    if linear:
        obsexp_diff = 10**vals[:,:,3] - 10**vals[:,:,4]
        mean_plate_sqdiff = calc_mean_plate_sqdiff(10**vals)
    else:
        obsexp_diff = vals[:,:,3] - vals[:,:,4]
        mean_plate_sqdiff = calc_mean_plate_sqdiff(vals)
    # mean difference over plates, ignoring NaN plates
    meandiff = SP.nansum(obsexp_diff,axis=1)/((~SP.isnan(obsexp_diff)).sum(axis=1))
    # variance = sum of variances of observed; sd = sqrt(sum). Add 0.01 so very close replicates do not blow up the ratio
    sddiff = SP.nansum(mean_plate_sqdiff[:,0:4], axis=1)**0.5 + 0.01
    labels, sigs = [], []
    for i, index in enumerate(indices):
        d = meandiff[i]
        sdd = d/sddiff[i]
        # bool() so the flag is a plain Python bool, not a numpy bool scalar
        sig = bool((abs(sdd) > 1.5) and (abs(d) > 0.05))
        if sig: sigs.append(index)
        if label_sig: labels.append("%s\n%.2f\n%.1f\n%s"%(orfs[index], d, sdd, "*" if sig else ""))
        else: labels.append("%s\n%.2f"%(orfs[index], d))
    PL.xticks(ticks+0.6, labels)

    return sigs, list(meandiff)


def plot_gfp_bar(all_vals, indices, labels, parentals, bar_width=0.18, colors='gybr', normalise_cis=False):
    """Grouped bar plot of one value per parental for each selected entry.

    Originally described as barplots of concordance between additivity and observation.

    Parameters
    ----------
    all_vals : all_vals[parent][idx] gives the value for parental `parent`, entry `idx`.
    indices : entries to plot, one bar group each.
    labels : x tick labels, one per entry.
    parentals : keys into `all_vals`, one bar (and legend entry) each.
    bar_width : width of a single bar.
    colors : one colour per parental.
    normalise_cis : if True, shift columns 2 and 3 of each row down by
        (column 2 - column 1), so that column 2 equals column 1.

    All bars are drawn relative to the minimum non-NaN value.
    """
    PL.figure(figsize=(17,8))
    n_rows = len(indices)
    n_par = len(parentals)
    vals = SP.zeros([n_rows, n_par])*SP.nan
    for row, idx in enumerate(indices):
        for col, parent in enumerate(parentals):
            vals[row, col] = (all_vals[parent][idx])
        if normalise_cis:
            shift = vals[row, 2] - vals[row, 1]
            vals[row, 2:4] -= shift
    baseline = SP.nanmin(vals)

    positions = SP.arange(n_rows)
    for col in range(n_par):
        heights = vals[:, col] - baseline
        PL.bar(positions + bar_width*col, heights, width=bar_width, color=colors[col], alpha=0.6)

    PL.xticks(positions+0.6, labels)
    PL.legend(parentals, loc="upper left")



def calc_mean_plate_sqdiff(vals):
    """Average squared difference between every pair of plates (axis 1 of `vals`).

    For each pair of positions i < j along axis 1, (vals[:,i] - vals[:,j])**2 is
    computed. These squared differences are then averaged over the pairs,
    ignoring NaN entries.

    Parameters
    ----------
    vals : array with at least two dimensions; axis 1 indexes plates/replicates.

    Returns
    -------
    Array of shape vals.shape[:1] + vals.shape[2:]. An entry is NaN when every
    pairwise difference for it is NaN. With fewer than two plates there are no
    pairs, and an all-NaN array of that shape is returned. (The previous version
    returned a scalar here, which broke callers that slice the result.)
    """
    import itertools
    out_shape = vals.shape[:1] + vals.shape[2:]
    n_plates = vals.shape[1]
    if n_plates < 2:
        return SP.zeros(out_shape)*SP.nan
    diffs = SP.array([(vals[:,i] - vals[:,j])**2 for i, j in itertools.combinations(range(n_plates), 2)])
    n_valid = (~SP.isnan(diffs)).sum(axis=0)  # number of non-NaN pairs per entry
    return SP.nansum(diffs, axis=0)/n_valid


# Old plotting routine, used for debugging GFP/RFP mixture fits.
def debug_gfp_stats(filename, mixture_threshold=0.8, rfp_min=10**2, normed=False, hist=True, gfp_rfp_scatter=False, gfp_fsc_scatter=False):
    """Diagnostic plots and printout for the GFP/RFP mixture fit of one FCS file.

    Runs file_gfp_stats on `filename` and assigns events to the RFP mixture
    components. It then optionally draws:
      * hist: RFP (column 1) and GFP (column 0) histograms per component, with
        histogram-mode and mean markers and mean/SD/N text annotations;
      * gfp_rfp_scatter: GFP vs RFP scatter, coloured by component;
      * gfp_fsc_scatter: GFP vs column 2 (labelled FSC-A), coloured by component.
    Finally it prints sanity checks of the RFP component means/SDs against
    MEAN_RFP_BY / MEAN_RFP_RM / MEAN_TOLERANCE / SD_LIMIT, prints selected
    entries of the returned stats, and calls PL.show().

    Returns None. It returns early without plotting if file_gfp_stats returns
    None, if either GFP mixture entry equals 0, or if fewer than MIN_EVENTS
    events remain.

    Parameters
    ----------
    filename : FCS file path; files whose path contains "haploid-R1" are read
        with multifcs=True.
    mixture_threshold : minimum posterior probability for an event to count
        as assigned to an RFP component.
    rfp_min : not used in this function.
    normed : passed to PL.hist.
    hist, gfp_rfp_scatter, gfp_fsc_scatter : select which figures to draw.
    """
    SP.random.seed(RND_SEED)
    well = get_well96_from_filename(filename)  # NOTE(review): not used below
    r = file_gfp_stats(filename, return_all=True,multifcs=(filename.count("haploid-R1") > 0)) # d is array of events x gfp,rfp,size; skipping too big and too complex cells
    if r is not None:
        (stats, gfp_gmms), d, rfp_gmm = r
    else:
        return
    # NOTE(review): presumably file_gfp_stats puts 0 in place of a GMM when a fit failed — confirm.
    if gfp_gmms[0] == 0 or gfp_gmms[1] == 0: return
    n_samples = 2 #- filename.count("AB") - filename.count("CD")
    if d.shape[0] < MIN_EVENTS: return # skip empty wells
    p = rfp_gmm.predict_proba(d[:, 1:2]) # datapoint assignments
    I = [SP.where(p[:,j] > mixture_threshold)[0] for j in range(n_samples)] # assignment indicators
    # events with RFP strictly between the RM and BY reference means
    Ibetween = SP.where((d[:,1] > MEAN_RFP_RM) & (d[:,1] < MEAN_RFP_BY))[0]
    frac_unaccounted = 0.
    if len(Ibetween) > 0:
        frac_unaccounted = 1.*len(SP.where(p[Ibetween].max(axis=1) < mixture_threshold)[0])/len(Ibetween) # if fraction of unassigneds between clusters is high, clusters are not well defined
    colors = 'rbk'

    if hist:
        I_clust_mode = []  # per component: positions (within I[j]) of events near the RFP histogram mode
        PL.figure(figsize=(12,6))
        PL.subplot(1,2,1)
        PL.hist(d[:,1], color="g", bins=SP.arange(2,4,SMOOTH_WINDOW_SIZE), alpha=0.1, normed=normed)
        for j in range(len(I)):
            print len(I[j]), rfp_gmm.means[j,0], rfp_gmm.covars[j][0,0]
            # colour index is the boolean (component mean < 3): 'b' for low-RFP components, 'r' otherwise
            h = PL.hist(d[I[j],1], color=colors[rfp_gmm.means[j,0] < 3], bins=SP.arange(2,4,SMOOTH_WINDOW_SIZE), alpha=0.3, normed=normed)
            #PL.plot([d[I[j],1].mean()],[h[0].max()], colors[j] + ".", markersize=10, alpha=0.3)
            #PL.plot([gmm.means[j]],[h[0].max()], colors[j] + ".", markersize=15, alpha=0.3)
            # mode estimate: left edge of the tallest bin plus half a bin width (the bin centre)
            modehat = h[1][h[0].argmax()] + 0.5*SMOOTH_WINDOW_SIZE
            modeval = h[0].max()
            I_clust_mode.append(SP.where(abs(d[I[j],1] - modehat) < 0.15)[0])
            PL.plot([modehat], [modeval], "b.", markersize=10)
            PL.plot([d[I[j][I_clust_mode[j]],1].mean()], [modeval], "g.", markersize=10)
            PL.text(d[I[j],1].mean(),h[0].max(), "Mean: %.2f\nSD: %.2f\nN: %d"%(rfp_gmm.means[j,0], rfp_gmm.covars[j][0,0]**0.5, len(I[j])))
        PL.title(filename[-40:])
                    
        PL.subplot(1,2,2)
        # GFP histograms; markers use the near-RFP-mode events and the per-component GFP mixture mean
        PL.hist(d[:,0], color="g", bins=SP.arange(2,6,SMOOTH_WINDOW_SIZE), alpha=0.1, normed=normed)
        for j in range(len(I)):
            h = PL.hist(d[I[j],0], color=colors[rfp_gmm.means[j,0] < 3], bins=SP.arange(2,6,SMOOTH_WINDOW_SIZE), alpha=0.3, normed=normed)
            modehat = h[1][h[0].argmax()] + 0.5*SMOOTH_WINDOW_SIZE
            modeval = h[0].max()
            PL.plot([modehat], [modeval], "b.", markersize=10)
            PL.plot([d[I[j][I_clust_mode[j]],0].mean()], [modeval], "g.", markersize=10)
            PL.plot([gfp_gmms[j].means[0,0]], [modeval], "r.", markersize=10)
            PL.text(d[I[j],0].mean(),h[0].max(), "Mean: %.2f\nSD: %.2f\nN: %d"%(gfp_gmms[j].means[0,0], gfp_gmms[j].covars[0][0,0]**0.5, len(I[j])))
        PL.title("%d sample%s"%(n_samples,["","s"][n_samples-1]))
    if gfp_rfp_scatter:
        PL.figure()
        PL.plot(d[:,0],d[:,1], "k.", alpha=0.02, markersize=10)
        for j in range(len(I)):   PL.plot(d[I[j],0],d[I[j],1], colors[j] + ".", alpha=0.1, markersize=10)
        PL.xlabel("GFP")
        PL.ylabel("RFP")
    if gfp_fsc_scatter:
        PL.figure()
        PL.plot(d[:,0],d[:,2], "k.", alpha=0.05, markersize=10)
        for j in range(len(I)): PL.plot(d[I[j],0],d[I[j],2], colors[j] + ".", alpha=0.1, markersize=10)
        PL.xlabel("GFP")
        PL.ylabel("FSC-A")
    # Sanity checks: each RFP component's mean must lie within MEAN_TOLERANCE of the BY or RM
    # reference mean, and its SD must not exceed SD_LIMIT.
    mean_ok = [True]*n_samples
    var_ok = [True]*n_samples
    for j in range(n_samples):
        if abs(rfp_gmm.means[j] - MEAN_RFP_BY) > MEAN_TOLERANCE and abs(rfp_gmm.means[j] - MEAN_RFP_RM) > MEAN_TOLERANCE: mean_ok[j] = False
        if rfp_gmm.covars[j][0,0]**0.5 > SD_LIMIT: var_ok[j] = False
    print filename, "Mean ok: ", mean_ok, "Var ok:", var_ok, "Frac unaccounted: %.1f"%(100*frac_unaccounted)
    keys = ['gfp_means','gfp_modes_hist','gfp_modes_mean']
    for k in keys:
        print "%s\t%.2f\t%.2f"%(k, stats[k][0], stats[k][1])
    PL.show()





def plot_file(filename, mixture_threshold=0.8, rfp_min=10**2, random_seed=0, multifcs=False, interactive=True):
    """Fit a GFP/RFP Gaussian mixture to one FCS file for interactive inspection.

    Despite its name, this function draws nothing. It reads the file (with
    read_multifcs when `multifcs` is True, otherwise read_fcs), filtering on
    forward-scatter area/width and minimum RFP. It then fits a mixture with
    fit_gmm on the log FITC-A (GFP) and mCherry-A (RFP) columns and computes
    the per-event component posteriors.

    Parameters
    ----------
    filename : path to the FCS file.
    mixture_threshold : not used.
    rfp_min : minimum RFP passed to the reader.
    random_seed : seeds SP.random and is passed to fit_gmm.
    multifcs : read the file with read_multifcs instead of read_fcs.
    interactive : if True (the default, matching the previous behaviour),
        stop in the pdb debugger before returning so the fit can be inspected.

    Returns
    -------
    (gmm, p): the fitted mixture and its predict_proba output. The previous
    version always returned None.
    """
    SP.random.seed(random_seed)
    reader = read_multifcs if multifcs else read_fcs
    # array of events x columns, filtered on forward-scatter area and width (skipping too big and too complex cells)
    cols, data = reader(filename, fsc_a_max=100000, fsc_w_max=100000, rfp_min=rfp_min, log=True)
    # 2 mixture components, or 1 for "heritable-candidate" files. The previous
    # `2 - filename.count(...)` dropped to 0 or below if the substring appeared more than once.
    n_samples = 1 if "heritable-candidate" in filename else 2

    green_col, red_col = cols.index("FITC-A"), cols.index("mCherry-A")
    training_data = data[:,(green_col, red_col)]
    gmm = fit_gmm(training_data, n_samples, random_seed=random_seed)
    # NOTE(review): the model is given both columns but posteriors are requested for the RFP
    # column only; this is consistent only if fit_gmm fits on column 1 internally — confirm.
    p = gmm.predict_proba(training_data[:,1:])
    if interactive:
        import pdb
        pdb.set_trace()
    return gmm, p


