import scipy as SP
import pylab as PL
import sys
from util import *
from fcs import *
from parts2014_gfpvar.tools.sklearn_ext import GMM
from parts2014_gfpvar.tools.common import *

#MIN_EVENTS = 15000
#MEAN_RFP_BY = 3.45
#MEAN_RFP_RM = 2.8
#SD_LIMIT = 0.45
#UNACCOUNTED_LIMIT = 0.5
#MEAN_INITS = [[0, SP.array([[2.4],[2.4]]), SP.array([[2.4],[2.4],[2.4]])], [0,SP.array([[0.5*(MEAN_RFP_BY+MEAN_RFP_RM)],[1]]), SP.array([[MEAN_RFP_BY],[MEAN_RFP_RM],[1]])]]
#MEAN_TOLERANCE = 0.3
#SMOOTH_WINDOW_SIZE = 0.05
#RESULT_KEYS_SINGLE = ['unaccounted']
#RESULT_KEYS_DOUBLE = "rfp_means rfp_vars rfp_modes gfp_means gfp_vars gfp_modes_hist gfp_modes_mean counts".split(" ")



class cClusteringSettings:
    """Clustering options, convertible to and from a filename-friendly string.

    Attributes:
        prior: how cluster prior means are chosen (key "prior").
        n_restarts: number of random restarts (key "restarts").
        n_rnd_points: number of random points (key "points").
        threshold: posterior probability threshold (key "threshold").
        fsca_slice_size: FSC-A slice size (key "fscasclicesize").
        n_rfp_modes: number of RFP modes (key "rfpmodes").
        n_rnd_files: number of random files (key "rndfiles").
        kvpairs: key -> value mapping rendered by __str__. Values are typed when set by
            __init__, and are the raw parsed strings after read_from_string.
    """

    def __init__(self, prior="data", n_restarts=20, points=100000, threshold=0.80, fscasclicesize=0.10, n_rfp_modes=3, n_rnd_files=48):
        self.prior = prior
        self.n_restarts = n_restarts
        self.n_rnd_points = points
        self.threshold = threshold
        self.fsca_slice_size = fscasclicesize
        self.n_rfp_modes = n_rfp_modes
        self.n_rnd_files = n_rnd_files
        self.kvpairs = {"prior":prior, "restarts":n_restarts, "points":points, "threshold":threshold, "fscasclicesize":fscasclicesize, "rfpmodes":n_rfp_modes, "rndfiles":n_rnd_files}

    def read_from_string(self, s):
        """Read option values from a string such as
        stats_prior-data_restarts-20_points-100000_threshold-0.80_fscasclicesize-0.10

        The string is split on "_". The first token (e.g. "stats") is ignored, and every
        other token must look like key-value. Only the first "-" separates key from
        value, so values may themselves contain "-". kvpairs is replaced by the parsed
        string pairs. Unknown keys are kept in kvpairs but set no attribute.

        Raises:
            ValueError: if a token after the first contains no "-".
        """
        self.kvpairs = {}

        for kv in s.split("_")[1:]:
            if "-" not in kv:
                raise ValueError("malformed settings token %r in %r (expected key-value)" % (kv, s))
            key, val = kv.split("-", 1)
            self.kvpairs[key] = val  # keep the raw string so __str__ reproduces the input
            if key == "prior": self.prior = val
            elif key == "restarts": self.n_restarts = int(val)
            elif key == "points": self.n_rnd_points = int(val)
            elif key == "threshold": self.threshold = float(val)
            elif key == "fscasclicesize": self.fsca_slice_size = float(val)
            elif key == "rndfiles": self.n_rnd_files = int(val)
            elif key == "rfpmodes": self.n_rfp_modes = int(val)

    def __str__(self):
        """Render kvpairs as key-value tokens joined by "_", omitting any "filters" entry."""
        return "_".join(["-".join(map(str,kv)) for kv in self.kvpairs.items() if kv[0] != "filters"])




"""
Take data d (n_samples x n_features, with GMM fit only on feature 1, and fit a Gaussian mixture of n_samples + 1 components
Return fit GMM object, posterior probabilities of cluster indicators, and lists of cluster assignments
"""
def fit_gmm(data, n_components, features=[1], prior_means=None, random_seed=RND_SEED, cvtype='diag', n_iter=10, thresh=0.1, n_noise_components=1, n_restarts=1):
    best_logprob = -1e20
    gmm = None

    for i in range(n_restarts): # for each random restart
        LOG.debug("Restart %d of %d; current best logprob=%.1f"%(i+1, n_restarts, best_logprob))
        _gmm = GMM(n_components=n_components, min_covar=SP.outer([0.01]*(n_components - n_noise_components) + [1]*n_noise_components, SP.ones([1,1])), random_state=random_seed+i, cvtype=cvtype) # min_covar is a kx1 vector with all 0.001 except for one large entry
        if prior_means is not None: # if prior means given, use them
            _gmm._set_means(prior_means)
            #        elif features in MEAN_INITS: # else if particular configuration foreseen, use that
            #_gmm._set_means(MEAN_INITS[features][n_components-n_noise_components]) # otherise initialise means to be our prior knowledge about where they ought to be
        _gmm.fit(data[:, features], n_iter=n_iter, thresh=thresh) # fit a k-means mixture for the feature (default=RFP level). One component per sample, plus one more for all the outliers
        
        logprob = _gmm.eval(data[:, features])[0].sum() # calculate data log-likelihood
        if logprob > best_logprob: # if better than current best model, pick it as the one to return
            gmm = _gmm
            best_logprob=logprob
    return gmm





""" Create relevant clusters based on data from entire plate. """ 
def create_plate_clusters(data=None, n_clusters=5, n_noise_clusters=1, prior_means=None, n_restarts=5, colnames=DEFAULT_COLNAMES, cluster_features=DEFAULT_CLUSTER_FEATURES, debug=False):
    if n_noise_clusters == n_clusters: n_noise_clusters = 0 # if don't actually want to cluster RFP, just have one cluster for all.
    # 0. preprocess data
    Ifeatures = [colnames.index(f) for f in cluster_features] # get features for clustering
    means,sds = data.mean(axis=0), (data.var(axis=0)**0.5) # standardise data
    data = (data-means)/sds
    prior_means = (prior_means - means[Ifeatures])/(sds[Ifeatures]) # and prior

    # 1. fit mixed model
    gmm = fit_gmm(data, n_clusters, Ifeatures, prior_means, cvtype='full', n_iter=100, thresh=0.01,n_noise_components=n_noise_clusters, n_restarts=n_restarts)
    
    if not debug: return gmm
    
    # Extra - debug clustering by plotting the results
    import pylab as PL
    p = gmm.predict_proba(data[:,Ifeatures]) # prediction on data
    colors = 'rgbyk'
    threshold = 0.6
    PL.figure()
    for i in range(n_clusters):
        I = SP.where(p[:,i] > threshold)[0]
        PL.plot(data[I,2],data[I,3], ".", markersize=12, alpha=0.02, color=colors[i])
    I = SP.where(p.max(axis=1) < threshold)[0]
    PL.plot(data[I,2],data[I,3], "k.", markersize=12, alpha=0.01)        
    PL.figure()
    PL.plot(data[:,2],data[:,3], "k.", markersize=12, alpha=0.01)
    if len(cluster_features) > 2:
        PL.figure()
        for i in range(n_clusters):
            I = SP.where(p.argmax(axis=1) == i)[0]
            PL.plot(data[I,1],data[I,3], ".", markersize=12, alpha=0.02, color=colors[i])            
    PL.show()
    pdb.set_trace()
    return gmm




def create_byxrm_plate_clusters_twophase(data, colnames=DEFAULT_COLNAMES, prior="data", n_restarts=10, n_rnd_points=100000, debug=False, outfile=None, overwrite=True, threshold=0.8, n_rfp_modes=2, filters=DEFAULT_FILTERS, plot=False):
    """Cluster a plate in two phases: budding state on (FSC-A, SSC-W), then RFP within each budding cluster.

    Parameters:
        data: 2D array of events whose columns are named by colnames; standardised here per column.
        colnames: names of the columns of data.
        prior: prior mode for the budding clusters, passed to create_plate_budding_clusters.
        n_restarts: GMM restarts for every mixture fitted.
        n_rnd_points: passed through to plot_subclustering.
        debug: show the QC figure and drop into pdb after fitting.
        outfile: optional pickle path. The result is saved there with cdm, and the QC figure
            is saved next to it with a .pdf extension.
        overwrite: if False and outfile exists, the stored clustering is loaded instead of refitted.
        threshold: posterior threshold for assigning events to a budding cluster.
        n_rfp_modes: number of RFP components per budding cluster (plus one noise component).
        filters: event filters recorded in the result under 'filters'.
        plot: only used when a stored clustering is loaded; re-draws and saves its QC figure.

    Returns:
        A freshly built dict with keys 'budding_GMM', 'rfp_GMMs' (one per budding cluster),
        'means', 'sds' (standardisation statistics), 'features' (colnames), 'filters' and
        'fsca_slice'; or the stored dict loaded from outfile.

    Raises:
        ValueError: if no event passes `threshold` for one of the two budding clusters.
    """
    # QC figure path: the pickle path with its extension swapped for .pdf
    pdf_file = None if outfile is None else os.path.splitext(outfile)[0] + ".pdf"

    if (outfile is not None) and os.path.exists(outfile) and (not overwrite): # if already have clustering, update with filters, and plot if desired
        r = cl(outfile)
        if 'filters' not in r:
            r['filters'] = filters
            r['fsca_slice'] = None
            cdm(r, outfile) # update and save filters
        if plot:
            gmm, rfp_gmms, means, sds, colnames = r['budding_GMM'], r['rfp_GMMs'], r['means'], r['sds'], r['features']
            data = (data-means)/sds
            Ifeature = [colnames.index("FSC-A"), colnames.index("SSC-W")]
            p = gmm.predict_proba(data[:,Ifeature]) # first, cluster budded/unbudded
            plot_subclustering(data, p, rfp_gmms, c1=1, c2=3, threshold=threshold, colnames=colnames, n_rnd_points=n_rnd_points) # scatter plot of clusters
            PL.savefig(pdf_file)
        return r

    # 0. standardise data
    means,sds = data.mean(axis=0), (data.var(axis=0)**0.5)
    data = (data-means)/sds

    # 1. separate based on budding
    Ifeature = [colnames.index("FSC-A"), colnames.index("SSC-W")]
    gmm = create_plate_budding_clusters(data=data[:,Ifeature], prior=prior, n_restarts=n_restarts, colnames=["FSC-A", "SSC-W"])

    # 2. separate each budding cluster further based on RFP. This has proved more robust than a 4-way clustering right away.
    p = gmm.predict_proba(data[:,Ifeature])
    rfp_gmms = []
    for c,title in enumerate(['budded','unbudded']):
        I = SP.where(p[:,c] > threshold)[0] # events confidently in budding cluster c
        if len(I) == 0:
            raise ValueError("no events assigned to %s cluster %d at threshold %s" % (title, c, threshold))
        # NOTE(review): create_plate_clusters re-standardises data[I] with this subset's own
        # means/sds, but the returned GMM is later applied to plate-standardised data
        # (plot_subclustering, read_clustered_well_data). Confirm this mismatch is acceptable.
        rfp_gmms.append(create_plate_clusters(data[I], n_clusters=n_rfp_modes+1, n_noise_clusters=1, prior_means=DEFAULT_PRIOR_RFP_MEANS, n_restarts=n_restarts, colnames=colnames, cluster_features=DEFAULT_RFP_CLUSTER_FEATURES))

    # 3. create result structure (same keys a loaded result is given), plot QC plot of the clusters
    result = {'budding_GMM':gmm, 'rfp_GMMs': rfp_gmms, 'means':means, 'sds':sds, 'features':colnames, 'filters':filters, 'fsca_slice':None}
    plot_subclustering(data, p, rfp_gmms, c1=1, c2=3, threshold=threshold, colnames=colnames, n_rnd_points=n_rnd_points) # scatter plot of clusters
    if debug:
        import pdb
        PL.show()
        pdb.set_trace()

    # 4. save result if desired, return
    if outfile is not None:
        outdir = os.path.dirname(outfile)
        if outdir and not os.path.exists(outdir):
            os.makedirs(outdir)
        cdm(result, outfile)
        PL.savefig(pdf_file)
    return result


""" Plot data from a 4-way clustering given 2-way cluster probabilities p, and two further mixtures trained on the clusters  """ 
def plot_subclustering(data, p, rfp_gmms, c1, c2, threshold=0.8, colnames=DEFAULT_COLNAMES, n_rnd_points=100000):
    fsca_slice_size = 0.1
    f,s,m = colnames.index("FSC-A"), colnames.index("SSC-W"), colnames.index("mCherry-A")
    I = SP.where(p[:,1] > threshold)[0] # take data in the cluster
    fsca_sorted = sorted(data[I,f])
    n = len(I)
    fsca_slice = [fsca_sorted[n/2 - int(n*0.5*fsca_slice_size)], fsca_sorted[n/2 + int(n*0.5*fsca_slice_size)]]
    PL.figure(figsize=(14,10))
    PL.subplot(221)
    for c in range(2): # for budded and unbudded clusters
        I = SP.where(p[:,c] > threshold)[0] # take data in the cluster
        p_rfp = rfp_gmms[c].predict_proba(data[I][:,1:2]) # cluster it further
        for j in range(min(2, p_rfp.shape[1])):
            Ir = SP.where(p_rfp[:,j] > threshold)[0] # plot each of the two clusters
            PL.plot(data[I,c1][Ir],data[I,c2][Ir], ".", markersize=12, alpha=0.02)
        Ir = SP.where(p_rfp.max(axis=1) < threshold)[0] # unassigned data points
        PL.plot(data[I,c1][Ir],data[I,c2][Ir], "k.", markersize=12, alpha=0.02)
    PL.ylabel("SSC-A")
    PL.xlabel("mCherry-A")
    PL.subplot(222)
    for c in range(2):
        I = SP.where(p[:,c] > threshold)[0] # take data in the unbudded cluster 
        p_rfp = rfp_gmms[c].predict_proba(data[I][:,1:2]) # cluster it further
        for j in range(min(2, p_rfp.shape[1])):
            Ir = SP.where(p_rfp[:,j] > threshold)[0] # plot each of the two clusters
            if len(I) > 0 and len(Ir) > 0:
                PL.hist(data[I,c1][Ir], bins=30, alpha=0.3)
    PL.xlabel("mCherry-A")
    PL.title("All cells")
    PL.subplot(223)
    unaccounted = []
    for c in range(2): # for budded and unbudded clusters
        I = SP.where(p[:,c] > threshold)[0] # take data in the cluster
        p_rfp = rfp_gmms[c].predict_proba(data[I][:,1:2]) # cluster it further
        for j in range(min(2, p_rfp.shape[1])):
            Ir = SP.where(p_rfp[:,j] > threshold)[0] # plot each of the two clusters
            PL.plot(data[I,c1][Ir],data[I,f][Ir], ".", markersize=12, alpha=0.02)
        Ir = SP.where(p_rfp.max(axis=1) < threshold)[0] # unassigned data points
        PL.plot(data[I,c1][Ir],data[I,f][Ir], "k.", markersize=12, alpha=0.02)
        If = SP.where((data[I,f] >= fsca_slice[0]) & (data[I,f] <= fsca_slice[1]))[0]
        unaccounted.append(100*sum(p_rfp[If,0:2].max(axis=1) < threshold)/len(If))
    (x1,x2), (y1,y2) = (min(data[:,c1]), max(data[:,c1])), fsca_slice
    PL.plot([x1,x2],[y1,y1],'r-')
    PL.plot([x1,x2],[y2,y2],'r-')
    PL.plot([x1,x1],[y1,y2],'r-')
    PL.plot([x2,x2],[y1,y2],'r-')
    PL.ylabel("FSC-A")
    PL.xlabel("mCherry-A")
    PL.title("Unaccounted: blue-green: %d%% red-teal: %d%%"%(tuple(unaccounted)))
    PL.subplot(224)
    for c in range(2):
        I = SP.where(p[:,c] > threshold)[0] # take data in the unbudded cluster 
        p_rfp = rfp_gmms[c].predict_proba(data[I][:,1:2]) # cluster it further
        for j in range(min(2, p_rfp.shape[1])):
            Ir = SP.where(p_rfp[:,j] > threshold)[0] # plot each of the two clusters
            if len(I) > 0 and len(Ir) > 0:
                d = data[I,:][Ir]
                If = SP.where((d[:,f] >= fsca_slice[0]) & (d[:,f] <= fsca_slice[1]))[0]
                if len(If) > 0: PL.hist(d[If,m], bins=30, alpha=0.3)
    PL.xlabel("mCherry-A")
    PL.title("FSC-A slice %.3f to %.3f"%(tuple(fsca_slice)))
    


""" Cluster data based on the budding pattern into unbudded (low SSC-W), budded (high SSC-W), and noise (not in first two clusters).
Data is assumed to be in log scale.
"""
def create_plate_budding_clusters(colnames=["FSC-A", "SSC-W"], data=None, prior="data", n_restarts=10):
    prior_means = SP.zeros([3,len(colnames)]) # clusters for budded, unbudded, noise

    if prior in ["data", "manual"]: # estimate prior from data
        n_points = data.shape[0]
        I = SP.argsort(data[:,colnames.index("SSC-W")])  # read data, and split in half based on SSC-W 
        prior_means[0] = data[I[n_points/2:],:].mean(axis=0) # high SSC-W (budded)
        prior_means[1] = data[I[0:n_points/2],:].mean(axis=0) # low SSC-W (unbudded)
        prior_means[2] = data.mean(axis=0) # Noise
    if prior == "manual":
        import pylab as PL
        PL.figure(figsize=(12,10))
        prior_means[0] = [-0.46, -1.24]
        prior_means[1] = [0.2, 0.73]
        prior_means[2] = [3,3]
        pm = SP.array(prior_means)
        print "Write down cluster centeres for unbudded (low SSC), budded (high SSC), and noise (default ", pm, ")"
        I = SP.random.choice(range(len(data)), 3000, replace=False)
        PL.plot(data[I,0], data[I,1], ".", markersize=14, alpha=0.15)
        PL.plot(pm[:,0], pm[:,1], "r.", markersize=14)
        PL.xlabel(colnames[0])
        PL.ylabel(colnames[1])
        PL.show()
        print "reeneter (Y/N)?",
        answer = sys.stdin.readline().strip()
        if answer.lower() == "y":
            print "Unbudded:",       
            prior_means[0] = map(float, sys.stdin.readline().strip().split())
            print "Budded:",
            prior_means[1] = map(float, sys.stdin.readline().strip().split())
            print "Noise:",
            prior_means[2] = map(float, sys.stdin.readline().strip().split())
    if prior not in ["data", "manual"]: # else use prior guess for clusters
        prior_means[0] = [SP.log10(BUDDED_FSCA_MEAN), SP.log10(BUDDED_SSCW_MEAN)] # BY budded
        prior_means[1] = [SP.log10(UNBUDDED_FSCA_MEAN), SP.log10(UNBUDDED_SSCW_MEAN)] # BY unbudded
        prior_means[2] = [SP.log10(UNBUDDED_FSCA_MEAN), SP.log10(UNBUDDED_SSCW_MEAN)] # Noise

    return create_plate_clusters(data, n_clusters=3, n_noise_clusters=1, prior_means=prior_means, n_restarts=n_restarts, colnames=colnames, cluster_features=colnames)
    


def read_clustering(experiment, screen, plate):
    """Load the pickled plate clustering for (experiment, screen, plate).

    Searches DATA_DIR/cytometry/<experiment>/<screen>/stats* and loads
    cluster-gmm_plate-<plate>.pickle (via cl) from the first match in sorted order.
    glob itself returns matches in arbitrary order, hence the sort.

    Raises:
        IndexError: if no stats directory matches.
    """
    pattern = "%s/cytometry/%s/%s/stats*"%(DATA_DIR, experiment, screen)
    statdirs = sorted(glob.glob(pattern))
    if not statdirs:
        raise IndexError("no stats directory matches %s" % pattern)
    return cl("%s/cluster-gmm_plate-%s.pickle"%(statdirs[0], plate))




def read_clustered_well_data(experiment="Pilot_screen_BYxRM", screen=None, plate=None, well=None, debug=False, threshold=0.8):    
    filename = glob.glob("/Volumes/BACKUP/%s/%s/%s/*_%s_*.fcs"%(experiment, screen, plate, well))[0]
    filters = DEFAULT_FILTERS

    # 0. read, filter, and standardise data
    clustering = read_clustering("Pilot_screen_BYxRM", screen, plate)  # read the clustering used for generating these stats
    cols, data = read_fcs(filename, colnames_tostore=clustering['features'], log=True)
    if 'filters' in clustering: data = filter_fcs(data, cols, clustering['filters'])
    else: data = filter_fcs(data, cols, DEFAULT_FILTERS)

    if data.shape[0] < 100:
        LOG.info("Low number of datapoints for file %s"%filename)
    Icol = [cols.index(c) for c in clustering['features']]
    data_std = (data[:,Icol] - clustering['means'])/clustering['sds']

    # 1. cluster data for budding index
    budding_gmm, rfp_gmms = clustering['budding_GMM'], clustering['rfp_GMMs']
    Icluster = [clustering['features'].index(feature) for feature in DEFAULT_BUDDING_CLUSTER_FEATURES]
    p_bud = budding_gmm.predict_proba(data_std[:,Icluster]) 

    # 2. cluster unbudded cells further based on RFP
    unbudded_cluster = budding_gmm._means[0:2,1].argmin() # unbudded cluster is the one of the first two that has smaller mean SSC-W (first index)
    p_rfp = rfp_gmms[unbudded_cluster].predict_proba(data_std[:, [clustering['features'].index('mCherry-A')]]) # cluster cells based on RFP
    if (rfp_gmms[unbudded_cluster].n_components > 2) and (rfp_gmms[unbudded_cluster]._means[0] > rfp_gmms[unbudded_cluster]._means[1]): # if clusters in wrong order, reorder assignments
        p_rfp = p_rfp[:,[1,0,2]] # switch first two clusters 

    # 3. Assign cells to cluster
    I = {}
    I["RM unbudded"] = ((p_bud[:,unbudded_cluster] > threshold) & (p_rfp[:,0] > threshold)) # p_rfp has low RFP cluster first, corresponding to RM
    I["RM budded"]   = ((p_bud[:,unbudded_cluster] <= threshold) & (p_rfp[:,0] > threshold))
    I["BY unbudded"] = ((p_bud[:,unbudded_cluster] > threshold) & (p_rfp[:,1] > threshold))
    I["BY budded"]   = ((p_bud[:,unbudded_cluster] <= threshold) & (p_rfp[:,1] > threshold))
    

    # 4. calculate and well statistics based on the clusterings if required
    if debug:
        cWellStat(data, p_bud[:,unbudded_cluster], p_rfp, threshold, DEFAULT_COLNAMES, debug=True)


    return I, data, SP.array(DEFAULT_COLNAMES)
