import scipy as SP
import glob
import os
import time
import pdb
from parts2014_gfpvar.tools.common import *
from parts2014_gfpvar.analysis.img.cp_io import *
from parts2014_gfpvar.analysis.img.util import *
from parts2014_gfpvar.analysis.img.quantify import *


""" Classify cells from a screen using precomputed classifers
@param experiment name of the experiment
@param screen name of the screen within the experiment
@param classifier identifier of the classifier to use
@require exists(<base>/training/classifiers/<classifier>.pickle)
@param threshold required posterior probability of a class to call a localisation pattern from it
@param n_min_cells minimum number of cells from one parent in a well to make a call
@param outfilename file name to store classification results in
@param overwrite if False, prediction is loaded from precomputed results; requires outfilename to exist
"""
def classify(experiment="Pilot_screen_BYxRM", screen="all-h2", classifier="trainingset-small_trees-1000", threshold=0.5, n_min_cells=30, outfilename=None, overwrite=False):
    """ Classify cells from a screen using precomputed classifiers, then compare
    the per-class call fractions of the two parents ("BY" and "RM") in every well.

    @param experiment name of the experiment (used to build the default outfilename)
    @param screen name of the screen within the experiment
    @param classifier identifier of the classifier set to use; the pickle at
           <DATA_DIR>/img/training/classifiers/<classifier>.pickle is loaded when
           predictions have to be (re)computed
    @param threshold prediction value a class has to exceed for a cell to be called
    @param n_min_cells minimum number of cells (inclusive) that every parent must
           have in a well for the well to be compared
    @param outfilename file used to cache (classes, prediction); if None it defaults
           to <DATA_DIR>/img/combined/<experiment>/<screen>/prediction.pickle
    @param overwrite if False and outfilename exists, predictions are loaded from it
           instead of being recomputed
    @return dict with keys
            "classes"    class identifiers, one per prediction column
            "prediction" per-cell prediction matrix (cells x classes)
            "wells"      sorted well identifiers
            "absdiffs"   per-well summed |BY - RM| call fraction (nan if too few cells)
            "fracs"      per-well dict parent -> fraction of called cells per class
            "min_totals" per-well smallest per-parent cell count
    @raise ValueError if a cached prediction does not have one row per cell of the screen
    """
    LOG.debug("Classifying screen %s with classifier %s (threshold %.2f, at least %d cells per well)"%(screen, classifier, threshold, n_min_cells))

    # 0. set up data
    basedir = "%s/img"%DATA_DIR
    if outfilename is None:
        # Default cache location; for the default experiment/screen this is the
        # path that used to be hard-coded here.
        # NOTE(review): layout combined/<experiment>/<screen>/ is inferred from that
        # single hard-coded path - confirm for other screens.
        outfilename = "%s/combined/%s/%s/prediction.pickle"%(basedir, experiment, screen)
    data, meta, features, meta_features = read_combined_cp_matrix(screen)

    # 1./2. obtain per-class predictions, from cache if allowed, else by classifying
    if (not overwrite) and os.path.exists(outfilename):
        (classes, prediction) = cl(outfilename)
        if len(prediction) != len(data):
            # A cache written for another screen would otherwise be indexed with
            # this screen's cell indices below.
            raise ValueError("Cached prediction in %s has %d rows but screen %s has %d cells; rerun with overwrite=True"%(outfilename, len(prediction), screen, len(data)))
    else:
        # map of class -> classifier; only loaded when it is actually needed
        classifiers = cl("%s/training/classifiers/%s.pickle"%(basedir, classifier))
        classes = list(classifiers.keys())
        prediction = SP.zeros([len(data), len(classes)])
        for i, c in enumerate(classes):
            LOG.debug("Classifying class %s (%d/%d)"%(c, i+1, len(classes)))
            t0 = time.time()
            prediction[:, i] = classifiers[c].predict(data)
            LOG.debug("...done, t=%.1f"%(time.time() - t0))
        cdm((classes, prediction), outfilename)

    # 3. get parental assignment, and compare distributions
    # presumably Ip maps parent name -> boolean mask over cells; it is iterated
    # by key and AND-ed with a per-cell mask below
    Ip = cluster_parent(data, features)
    well_ids = meta[:, 3]  # column 3 of meta is used as the well identifier
    wells = sorted(SP.unique(well_ids))
    absdiffs = []
    allfracs = []
    allmintotals = []

    # header: one tab-prefixed, 15-character-truncated column per class
    print("\t")
    print(" ".join("\t%s"%c[0:15] for c in classes))

    for w in wells:
        in_well = (well_ids == w)
        fracs, totals = {}, []
        for parent in Ip:
            idx = SP.where(in_well & Ip[parent])[0]  # cells of this parent in the well
            fracs[parent] = (prediction[idx, :] > threshold).mean(axis=0)  # fraction called per class
            totals.append(len(idx))
        min_total = min(totals) if totals else 0
        allfracs.append(fracs)
        allmintotals.append(min_total)

        if min_total >= n_min_cells:  # enough cells from every parent
            absdiffs.append(abs(fracs["BY"] - fracs["RM"]).sum())
            if absdiffs[-1] > 0.5:
                # report strongly differing wells as BY%/RM% per class
                columns = ["\t%02d/%02d"%(100*by, 100*rm) for by, rm in zip(fracs["BY"], fracs["RM"])]
                print(" ".join([str(w), str(totals)] + columns))
        else:  # not enough cells: keep a placeholder so absdiffs stays aligned with wells
            absdiffs.append(SP.nan)

    absdiffs = SP.array(absdiffs)
    valid = absdiffs[~SP.isnan(absdiffs)]
    if len(valid) > 0:
        import pylab as PL  # imported lazily so plotting is only required when used
        PL.hist(valid, bins=40)
        PL.show()
    else:
        LOG.debug("No well has at least %d cells per parent; skipping histogram"%n_min_cells)

    return {
        "classes": classes,
        "prediction": prediction,
        "wells": wells,
        "absdiffs": absdiffs,
        "fracs": allfracs,
        "min_totals": allmintotals,
    }
