import numpy as NP
import pylab as PL
import scipy as SP
import scipy.stats as ST

from parts2014_gfpvar.analysis.cyto.cluster import *
from parts2014_gfpvar.tools.common import *
from parts2014_gfpvar.analysis.cyto.quantify import *

# Default GFP feature names looked up in the data feature list when no
# explicit list is given. The order matters: summary matrices and plot titles
# index features by their position in this list.
DEFAULT_GFP_FEATURES = [
    "cell_Intensity_IntegratedIntensity_0_Raw_protein_GFP",
    "cell_Intensity_MeanIntensity_0_Raw_protein_GFP",
    "cell_Intensity_MedianIntensity_0_Raw_protein_GFP",
    "gfpobj_Intensity_IntegratedIntensity_0_Raw_protein_GFP_median",
    "gfpobj_Intensity_MaxIntensity_0_Raw_protein_GFP_median",
    "gfpobj_Intensity_MeanIntensity_0_Raw_protein_GFP_median",
    "gfpobj_Intensity_MedianIntensity_0_Raw_protein_GFP_median",
]


def quantify_gfp(data, data_features, meta, gfp_features=DEFAULT_GFP_FEATURES, min_well_cells=0, debug=True):
    """ Calculate quantitative GFP traits per well and parental background
    @param data n_cells x n_features matrix of floating point values
    @param data_features list of all feature names for the data
    @param meta metadata of the cells
    @require meta[:,2] is the plate, meta[:,3] is the well of the cell
    @param gfp_features list of GFP feature names to use
    @param min_well_cells wells in which either parent has fewer assigned cells than this are dropped
    @param debug forwarded to cluster_parent, which shows a diagnostic scatter plot when true
    @return map of (plate, well) -> parent -> {'count', 'mean', 'median', 'var'};
            'count' is the number of cells of that parent in the well, the other
            entries are arrays with one value per GFP feature, computed on log10 values
    """
    # 0. get GFP data (log scale)
    Igfp = [data_features.index(f) for f in gfp_features]
    d_gfp = NP.log10(data[:, Igfp])
    # 1. cluster cells to parents
    Ip = cluster_parent(data, data_features, debug=debug)

    # 2. quantify contents of each well
    plates, wells = meta[:, 2], meta[:, 3]
    well_names = NP.unique(wells)  # loop invariant: computed once, not once per plate
    result = {}
    for p in NP.unique(plates):  # for each plate
        LOG.debug("Quantifying plate %s" % p)
        on_plate = (plates == p)
        for w in well_names:  # for each well
            Iw = on_plate & (wells == w)  # take the cells in that well
            n_cells = int(Iw.sum())
            if n_cells == 0:
                continue  # (skipping empty wells)
            well_res = {}
            for parent in Ip:  # for BY and RM parent
                I = NP.where(Ip[parent] & Iw)[0]  # cells in the well from the parent
                vals = d_gfp[I]
                well_res[parent] = {
                    'count': len(I),
                    # mean and variance ignore NaN entries (the intent of the old
                    # 'skipna' argument, which released NumPy does not accept)
                    'mean': NP.nanmean(vals, axis=0),
                    # the median stays NaN-propagating, as it always was
                    'median': NP.median(vals, axis=0),
                    'var': NP.nanvar(vals, axis=0),
                }
            # do not look at wells that have incomplete data
            if (len(well_res) < 2) or min(well_res['RM']['count'], well_res['BY']['count']) < min_well_cells:
                LOG.debug("%s - %d cells; skipped (incomplete data)" % (w, n_cells))
                continue
            LOG.debug("%s - %d cells; BY: %d; RM: %d" % (w, n_cells, well_res['BY']['count'], well_res['RM']['count']))
            result[(p, w)] = well_res
    LOG.debug("Cells assigned to parents - BY: %d; RM: %d" % (Ip['BY'].sum(), Ip['RM'].sum()))
    return result


def cluster_parent(data, data_features, clustering_features=("nucleus_AreaShape_Area", "nucleus_Intensity_IntegratedIntensity_02_Rescaled_nuclear_RFP"), threshold=0.8, debug=False):
    """ Cluster data based on RFP features to get BY and BYxRM cells
    @param data array of data
    @param data_features list of all data feature names
    @param clustering_features list of feature names used in clustering; the second one is put on log10 scale
    @param threshold minimum posterior probability for a cell to be classified to a cluster
    @param debug if true, show a scatter plot of the clustering features coloured by assignment
    @return Ip: map of parent to binary indicator vector that says for each cell, whether it is from the given parental background
    """
    LOG.debug("Clustering %d cells from image data based on features: %s; threshold=%.2f"%(data.shape[0], " ".join(clustering_features), threshold))
    Irfp = [data_features.index(f) for f in clustering_features]
    # Private float copy: the caller's array is never modified, and the log
    # values assigned below are not truncated if data has an integer dtype.
    d = NP.array(data[:, Irfp], dtype=float)
    d[:, 1] = NP.log10(d[:, 1])  # put RFP intensity to log scale
    Inan = NP.isnan(d).any(axis=1)
    Iok = NP.where(~Inan)[0]
    rfp_clustering = fit_gmm(d[Iok], n_components=3, features=[0,1], prior_means=None, random_seed=1, n_restarts=20)
    # Only score rows without NaN; rows with NaN get probability 0 for every
    # cluster, so they are never assigned to a parent.
    p_ok = NP.asarray(rfp_clustering.predict_proba(d[Iok]))
    p = NP.zeros((d.shape[0], p_ok.shape[1]))
    p[Iok] = p_ok

    # Clusters 0 and 1 are taken as the two parents; of these, the one with the
    # larger mean of the first clustering feature is labelled BY.
    # NOTE(review): assumes _means is indexed [cluster][feature] and that the
    # third cluster never corresponds to a parent - confirm against fit_gmm.
    i_by, i_rm = 0, 1
    if rfp_clustering._means[0][0] < rfp_clustering._means[1][0]:
        i_by, i_rm = 1, 0  # switch BY, RM labels

    result = {'BY': p[:, i_by] > threshold, 'RM': p[:, i_rm] > threshold}

    if debug:
        I1, I2 = NP.where(result["BY"])[0], NP.where(result["RM"])[0]
        I0 = NP.where(~(result["BY"] | result["RM"]))[0]
        PL.plot(d[I0,0], d[I0,1], "k.", alpha=0.1, markersize=10)  # unassigned cells
        PL.plot(d[I1,0], d[I1,1], "r.", alpha=0.2, markersize=10)  # BY
        PL.plot(d[I2,0], d[I2,1], "b.", alpha=0.2, markersize=10)  # RM
        PL.show()

    return result



def res_to_mat(res):
    """ Pack a per-well result map into a dense array
    @param res map of well -> parent ("BY", "RM") -> {'mean', 'median', 'var', 'count'}
    @return (d, wells): d is an array indexed [summary stat, parent, well, feature]
            with summary stats ordered mean, median, var, count and parents ordered
            BY, RM; wells is the list of keys of res in the order used along the
            well axis. The scalar 'count' is repeated across the feature axis.
    """
    traits = ['mean', 'median', 'var', 'count']
    parents = ["BY", "RM"]
    wells = list(res.keys())  # a real list, so callers can index it

    # Take the number of features from the data, so results computed with a
    # custom feature list also fit; fall back to the default list when there is
    # nothing to inspect.
    n_features = len(DEFAULT_GFP_FEATURES) if not wells else None
    if n_features is None:
        first = NP.asarray(res[wells[0]][parents[0]]['mean'])
        n_features = first.shape[0] if first.ndim else len(DEFAULT_GFP_FEATURES)

    d = NP.zeros([len(traits), len(parents), len(wells), n_features])
    for t, trait in enumerate(traits):
        for i, well in enumerate(wells):
            for p, parent in enumerate(parents):
                d[t, p, i, :] = res[well][parent][trait]
    return d, wells


def debug_features(d, trait="mean"):
    """ Scatter the BY value against the RM value of each well, one figure per feature
    @param d array indexed [parent, well, feature]; parent 0 goes on the axis
             labelled BY, parent 1 on the axis labelled RM
    @param trait name of the summary statistic held in d, used in the figure
                 titles; for "count" only the first feature is plotted
    """
    # 'trait' used to be read from an undefined name; it is now an argument.
    n_panels = 1 if trait == "count" else len(DEFAULT_GFP_FEATURES)
    for t in range(n_panels):
        PL.plot(d[0,:,t], d[1,:,t], ".", markersize=12, alpha=0.3)
        PL.xlabel("BY")
        PL.ylabel("RM")
        if trait == "count":
            PL.title("Count")
        else:
            PL.title("%s - %s"%(trait, DEFAULT_GFP_FEATURES[t]))
        PL.show()
    

def _scatter_with_identity(x, y, **line_kwargs):
    """ Scatter x against y and overlay a dashed red line across their joint value range """
    PL.plot(x, y, ".", alpha=0.5, markersize=15)
    v_min = min(x.min(), y.min())
    v_max = max(x.max(), y.max())
    PL.plot([v_min, v_max], [v_min, v_max], 'r--', **line_kwargs)


def compare_img_cyto(data, wells, cyto_screen="all-h2", cyto_plate="all-h2-Haploid-R1"):
    """ Plot microscopy GFP summaries against cytometry summaries of the same wells
    @param data array of [summary stat, parent, well, feature]
    @param wells list of wells to match data; element [1] of each entry is the key used to look up the cytometry results
    @param cyto_screen screen name passed to read_screen_results
    @param cyto_plate key of the plate to take from the screen results
    """
    summaries = ['mean','median','stddev']
    # 0. get cytometry data
    cyto_vals = read_screen_results(cyto_screen).plates[cyto_plate].orf_stats.values()
    # NOTE(review): 'text.fontsize' and 'text.size' may not be valid rcParams
    # keys in every matplotlib version - confirm against the installed one.
    PL.rcParams.update({'text.fontsize':40, 'xtick.labelsize':24, 'ytick.labelsize':24, 'text.size':40, 'axes.titlesize':30, 'axes.labelsize':26, 'figure.figsize':(16,10)})
    cyto_res = {} # well -> slice -> parent -> values
    for v in cyto_vals:
        st = get_screen_orf_summary(v, ["Haploid"])
        for w in st:
            cyto_res[w] = st[w]
    # [summary stat, parent, well], NaN until filled
    d_cyto = NP.full([3, 2, len(wells)], NP.nan)
    for w, well in enumerate(wells):
        for p, parent in enumerate(["BY","RM"]):
            # presumably the first three slice values line up with 'summaries' - TODO confirm
            d_cyto[:, p, w] = cyto_res[well[1]]['slice'][parent][0:3]

    # 1. compare absolute levels
    # Disabled: the empty range means no figure is produced here. Use
    # range(len(summaries)) to enable it for every summary stat.
    for s in range(0):
        for f in range(3):  # first three features only; data.shape[3] for all
            # average over the two parents, ignoring NaN
            x = NP.nanmean(data[s, :, :, f], axis=0)
            y = NP.nanmean(d_cyto[s, :, :], axis=0)
            if summaries[s] == "stddev":
                x, y = x**0.5, y**0.5
            _scatter_with_identity(x, y)
            PL.xlabel("%s(GFP feature) from microscopy"%(summaries[s]))
            PL.ylabel("%s(GFP) from cytometry"%(summaries[s]))
            PL.title("GFP feature %s (r=%.2f)"%(DEFAULT_GFP_FEATURES[f], ST.spearmanr(x,y)[0]))
            PL.show()

    # 2. compare differences between the parents (index 0 minus index 1)
    for s in range(0, 1):  # first summary stat only; range(len(summaries)) for all
        for f in range(3):  # first three features only; data.shape[3] for all
            if summaries[s] == "stddev":
                # difference of square roots, not square root of the difference
                x = data[s, 0, :, f]**0.5 - data[s, 1, :, f]**0.5
                y = d_cyto[s, 0, :]**0.5 - d_cyto[s, 1, :]**0.5
            else:
                x = data[s, 0, :, f] - data[s, 1, :, f]
                y = d_cyto[s, 0, :] - d_cyto[s, 1, :]
            _scatter_with_identity(x, y, linewidth=2)
            PL.xlabel("%s(BYxBY) - %s(BYxRM) from microscopy"%(summaries[s],summaries[s]))
            PL.ylabel("%s(BYxBY) - %s(BYxRM) from cytometry"%(summaries[s],summaries[s]))
            PL.title("GFP feature %s (r=%.2f)"%(DEFAULT_GFP_FEATURES[f], ST.spearmanr(x,y)[0]))
            PL.xlim(-0.3,0.3)
            PL.ylim(-0.3,0.3)
            PL.show()
