import os
import sys
import glob
import time

import numpy as NP
import scipy as SP
import scipy.spatial as SPA

from parts2014_gfpvar.tools.common import *
from parts2014_gfpvar.tools.common import cdm, cl
from parts2014_gfpvar.tools.io import get_array_data

    
def get_image_metadata(dat_filename, img_filename, img_count):
    """ Get plate, site, image number, row, and column from image filename.

    The plate number is the text after the last "_" of the grandparent directory
    name, minus its first character. The site number is the text after the last
    "-" of the parent directory name. Row and column are the first two 3-digit
    fields of the image filename.

    @param dat_filename CP output filename, including leading directories
    @param img_filename rrrcccsss.tif filename only, no leading info
    @param img_count string(image number)
    @return plate number, site number, image number, row number, column number (all int)
    """
    path_parts = dat_filename.split("/")
    plate_dir, site_dir = path_parts[-3], path_parts[-2]
    plate_number = int(plate_dir.split("_")[-1][1:])  # e.g. "..._p3" -> 3
    site_number = int(site_dir.split("-")[-1])        # e.g. "site-2" -> 2
    row_number = int(img_filename[:3])
    col_number = int(img_filename[3:6])
    return plate_number, site_number, int(img_count), row_number, col_number


def get_indices(gfp_header, rfp_header):
    """ Locate the required columns in the GFP and RFP CellProfiler output headers.

    @param gfp_header sequence (e.g. array) of GFP column names
    @param rfp_header sequence (e.g. array) of RFP column names
    @return None if any required column is missing from either header; otherwise
        (gfp_indices, rfp_indices) where
          gfp_indices = positions of [GFP mean intensity, AreaShape_Area, AreaShape_Perimeter]
          rfp_indices = positions of [Location_Center_X, Location_Center_Y,
                                      rescaled nuclear RFP mean intensity, AreaShape_Area]
        If a name occurs more than once, its first position is used.
    """
    gfp_columns = ["Intensity_MeanIntensity_0_Raw_protein_GFP", "AreaShape_Area", "AreaShape_Perimeter"]
    rfp_columns = ["Location_Center_X", "Location_Center_Y", "Intensity_MeanIntensity_02_Rescaled_nuclear_RFP", "AreaShape_Area"]
    # Plain lists avoid relying on the NumPy aliases in the scipy namespace
    gfp_names, rfp_names = list(gfp_header), list(rfp_header)
    if any(c not in gfp_names for c in gfp_columns) or any(c not in rfp_names for c in rfp_columns):
        return None
    return [gfp_names.index(c) for c in gfp_columns], [rfp_names.index(c) for c in rfp_columns]


def output_combined_img_file(filename, ofh, array_data, debug=False):
    """ Rewrite one GFP CP output file and its matching RFP file as tab-delimited cell rows.

    The RFP file path is obtained by replacing "GFP" with "RFP" in filename. The whole
    site is skipped (logged with LOG.info, nothing written) when the two files have a
    different number of rows, when either is empty, when a required column is missing
    (see get_indices), or when any row's first three fields differ between the files.

    @param filename GFP CP output file to read from (comma-delimited, header row first)
    @param ofh file handle to write to
    @param array_data PxRxCx6 array of array contents; [plate, row, col, 0] is the ORF name
    @param debug if True, also print each cell's raw values
    """
    # ndmin=2 keeps a header-only file two-dimensional instead of collapsing it to 1-D
    gfp = NP.loadtxt(filename, dtype=object, delimiter=",", ndmin=2)
    rfp = NP.loadtxt(filename.replace("GFP", "RFP"), dtype=object, delimiter=",", ndmin=2)

    indices = None
    if len(gfp) > 0 and len(rfp) > 0:
        indices = get_indices(gfp[0], rfp[0])
    badsite = (indices is None) or (len(gfp) != len(rfp))
    if not badsite:
        # both files must describe the same images/objects row by row
        badsite = any(tuple(g[0:3]) != tuple(r[0:3]) for g, r in zip(gfp[1:], rfp[1:]))
    if badsite:
        LOG.info("Bad site - %s"%filename)
        return

    (i_gfp, i_area, i_perim), (i_x, i_y, i_rfp, i_nucl) = indices
    for g, r in zip(gfp[1:], rfp[1:]):
        plate, site, image, row, col = get_image_metadata(filename, g[0], g[1])
        orf = array_data[plate, row, col, 0]
        if debug:
            print(" ".join(str(v) for v in (orf, plate, site, image, row, col, r[i_x], r[i_y],
                                            g[i_gfp], r[i_rfp], g[i_area], g[i_perim], r[i_nucl])))
        well = "%s%d" % (chr(ord('A') - 1 + row), col)  # row 1 -> "A"
        # int(float(...)) accepts areas written either as "523" or "523.0"
        ofh.write("%s\t%d\t%d\t%d\t%s\t%.3f\t%.3f\t%.3f\t%.3f\t%d\t%.3f\t%d\n" % (
            orf, plate, site, image, well,
            float(r[i_x]), float(r[i_y]),
            NP.log10(float(g[i_gfp])), NP.log10(float(r[i_rfp])),
            int(float(g[i_area])), float(g[i_perim]), int(float(r[i_nucl]))))


def output_cells_from_images(in_pattern="%s/img/combined_data/201*/site-?/combined_GFP_features.csv"%DATA_DIR, outfilename="%s/img/summary/paper_cells.tab"%DATA_DIR, n=None, overwrite=False):
    """ Write cells from all CP output files matching a pattern into one tab-delimited file.

    @param in_pattern glob pattern selecting the GFP CP output files to include
    @param outfilename output file name
    @param n if not None, process at most n files (for debugging)
    @param overwrite if False and outfilename already exists, do nothing
    @effect files are processed in sorted path order so the output is reproducible;
        the output file is closed even if processing fails part-way
    """
    if os.path.exists(outfilename) and (not overwrite):
        return
    array_data = get_array_data(skip_nan=False)  # read array metadata
    files = sorted(glob.glob(in_pattern))
    with open(outfilename, 'w') as ofh:
        ofh.write("#ORF\tPlate\tSite\tImage\tWell\tX-coordinate\tY-coordinate\tlog10(Average_GFP_pixel_intensity)\tlog10(Average_nuclear_RFP_pixel_intensity)\tArea\tPerimeter\tNuclear_Area\n")
        for f, filename in enumerate(files):
            if (n is not None) and (f >= n):
                break  # debug limit reached
            print("%d/%d\t%s" % (f, len(files), "/".join(filename.split("/")[-3:-1])))
            output_combined_img_file(filename, ofh, array_data, debug=False)


def combine_well_data(data, ref=None):
    """ Append new cell measurements for an ORF to previously collected ones.

    Before appending, the new block is shifted column-wise so that its median equals
    the median of ref (forces equal medians across batches).

    @param data NxK matrix of feature values for N new cells
    @param ref MxK matrix of feature values for M cells already measured, or None
    @return (M+N)xK matrix: the rows of ref followed by the median-shifted rows of data.
        If ref is None or has no rows, data is returned unchanged; if data has no rows,
        ref is returned unchanged (so an empty batch never breaks concatenation).
    """
    if ref is None or len(ref) == 0:
        return data
    if len(data) == 0:
        return ref
    shifted = data - NP.median(data, axis=0) + NP.median(ref, axis=0)  # force equal medians
    return NP.concatenate([ref, shifted])


def _store_well_cells(res, key, cells):
    """ Fold the filtered cell rows of one contiguous (orf, plate, well) run into res[key]. """
    if key is None:  # no data line read yet
        return
    if key not in res:
        res[key] = None  # register the well even if none of its cells passed the limits
    if len(cells) > 0:
        res[key] = combine_well_data(NP.array(cells), res[key])


def combine_img_cell_data(filename="%s/img/summary/paper_cells.tab"%DATA_DIR, limits={'Area':(400,3000), 'Perimeter':(60, 430), 'Nuclear_Area':(0,800)}, outfilename="%s/img/summary/combined_cells.pickle"%DATA_DIR, recalc=True):
    """ Combine data from individual cells into per-(ORF, plate, well) matrices.

    Lines are grouped into runs of consecutive rows sharing the same (ORF, Plate, Well)
    fields. When a key reappears later in the file, its new cells are merged with the
    earlier ones via combine_well_data (median-aligned).

    @param filename tab-delimited file of individual cell measurements with a header line
        (as written by output_cells_from_images)
    @param limits map of column name -> (minval, maxval); a cell is kept only if every
        listed column lies within the inclusive bounds (the dict is not modified)
    @param outfilename pickle file the result is saved to with cdm; None to skip saving
    @param recalc if False and outfilename exists, load and return it with cl instead
    @return map (orf, plate, well) -> float matrix of the columns from the 6th on
        (X-coordinate onward) for the kept cells, or None if no cell of that key passed
    """
    print(outfilename)
    if (not recalc) and (outfilename is not None) and os.path.exists(outfilename):
        return cl(outfilename)
    print("recalculating, since recalc= %s" % recalc)
    with open(filename, 'r') as fh:
        lines = fh.readlines()
    header = lines[0].strip().split("\t")
    print(header)
    bounds = [(header.index(name), lims) for name, lims in limits.items()]

    res = {}
    key, cells = None, []
    for line in lines[1:]:  # go through datapoints
        if not line.strip():
            continue  # tolerate blank lines (e.g. at end of file)
        d = line.rstrip("\r\n").split("\t")
        current = (d[0], d[1], d[4])  # ORF, Plate, Well
        if current != key:  # a new well starts: store the previous one first
            _store_well_cells(res, key, cells)
            key, cells = current, []
        if all(lo <= float(d[i]) <= hi for i, (lo, hi) in bounds):
            cells.append([float(v) for v in d[5:]])
    _store_well_cells(res, key, cells)  # the last well in the file

    if outfilename is not None:
        cdm(res, outfilename)
    return res



def classify_rfp(x, thresh_rfp=-0.95, thresh_nucl=250, min_area=50):
    """ Classify N cells into BYxBY or RMxBY cross based on their RFP features.

    BY query has bright, large nucleus; RM query small and dim. Columns of x are
    expected in the cell-file order from X-coordinate on: 0 X, 1 Y, 2 log10 GFP,
    3 log10 RFP, 4 Area, 5 Perimeter, 6 Nuclear_Area. Cells with NaN GFP or
    Area <= min_area are in neither class. Cells exactly at a threshold, or bright
    with a small nucleus (or vice versa), are also unclassified.

    @param x NxK feature matrix
    @param thresh_rfp log10(RFP level) threshold for classification. RM is below, BY above it.
    @param thresh_nucl nuclear area threshold for classification. RM is below, BY above it.
    @param min_area minimum (exclusive) cell area for a cell to be considered
    @return (I_BY, I_RM), each the tuple returned by numpy.where; I_x[0] is the array of
        row indices of x in that class
    """
    x = NP.asarray(x)
    gfp, rfp, area, nucl = x[:, 2], x[:, 3], x[:, 4], x[:, 6]
    usable = (area > min_area) & ~NP.isnan(gfp)
    Iby = NP.where(usable & (rfp > thresh_rfp) & (nucl > thresh_nucl))
    Irm = NP.where(usable & (rfp < thresh_rfp) & (nucl < thresh_nucl))
    return Iby, Irm


def calc_var(gfp, loc, Iby, Irm, near_max=60**2, far_min=160**2, far_max=240**2):
    """ Collect squared GFP differences between pairs of cells for one ORF.

    @param gfp Nx1 matrix of log10(cell GFP levels); pairs involving a NaN are skipped
    @param loc Nx2 matrix of x,y coordinates of the cells (in pixels)
    @param Iby, Irm iterables of row indices for cells classified as BY query cross and
        RM query cross (see classify_rfp); a pair counts for a cross only if both cells
        are in it (if a pair qualifies for both, RM wins)
    @param near_max squared distance below which a pair is "near"
    @param far_min, far_max squared-distance open interval for a pair to be "far" (we want
        "far" to be not mother-daughter, but not too far to control for local effects)
    @return (near, far, bestreciprocal). Each is a list of three lists of squared GFP
        differences: [BY-BY pairs, RM-RM pairs, any pair], each in increasing (i, j) pair
        order. bestreciprocal holds near same-cross pairs whose cells are each other's
        nearest neighbour; its third ("any pair") list is always empty.
    """
    N = len(loc)
    near, far, bestreciprocal = [[], [], []], [[], [], []], [[], [], []]
    if N < 2:  # no pairs; argmin below would also fail on an empty matrix
        return near, far, bestreciprocal

    dists = SPA.distance_matrix(loc, loc)**2  # use squared distances throughout
    NP.fill_diagonal(dists, 1e9)  # so a cell is never its own nearest neighbour
    nearest = NP.argmin(dists, axis=0)  # closest other cell to each cell
    gfpvar = SPA.distance_matrix(gfp, gfp)**2  # cell-cell squared differences in GFP

    # All pairs i < j, in the same row-major order as a nested loop over i then j
    pi, pj = NP.triu_indices(N, k=1)
    d2, g2 = dists[pi, pj], gfpvar[pi, pj]

    def members(indices):
        mask = NP.zeros(N, dtype=bool)
        idx = NP.asarray(list(indices), dtype=int)
        mask[idx[(idx >= 0) & (idx < N)]] = True
        return mask

    is_by, is_rm = members(Iby), members(Irm)
    group = -NP.ones(len(pi), dtype=int)  # -1: pair not from a single parental background
    group[is_by[pi] & is_by[pj]] = 0
    group[is_rm[pi] & is_rm[pj]] = 1  # assigned second, so RM takes precedence

    valid = ~NP.isnan(g2)
    is_near = valid & (d2 < near_max)
    is_far = valid & (d2 > far_min) & (d2 < far_max)
    mutual = (nearest[pi] == pj) & (nearest[pj] == pi)

    for b in (0, 1):
        in_b = group == b
        near[b] = g2[is_near & in_b].tolist()
        far[b] = g2[is_far & in_b].tolist()
        bestreciprocal[b] = g2[is_near & in_b & mutual].tolist()
    near[2] = g2[is_near].tolist()
    far[2] = g2[is_far].tolist()
    return near, far, bestreciprocal


def calc_img_orfstats(combined_data, outfilename="%s/img/summary/paper_orfstats.tab"%DATA_DIR, recalc_meanvar=True):
    """ Calculate mean and variance statistics for all ORFs.

    A key needs at least 20 BYxBY and 20 RMxBY cells (see classify_rfp); otherwise it is
    skipped. For each remaining key, GFP means/medians are computed per cross. Variance
    estimates are the median squared GFP difference over near, far and best-reciprocal cell
    pairs (see calc_var). An estimate is NaN when it rests on 2 or fewer pairs.
    Results are pickled to img/summary/means.pickle and img/summary/vars.pickle. Every 100
    keys, a checkpoint of the variances computed so far is written to tmp_variances.pickle
    in the current directory.

    @param combined_data map of (orf, plate, well) -> NxK matrix of cell features (or None),
        as returned by combine_img_cell_data
    @param outfilename tab-delimited file written via save_img_orfstats; None for no output.
        The default is the same path save_img_orfstats uses by default.
    @param recalc_meanvar if False and both pickles exist, load them instead of recomputing
    @return (variances, varcounts, means, meancounts).
        variances: map key -> length-9 array of variance estimates
            [near BY, near RM, near any, far BY, far RM, far any,
             bestreciprocal BY, bestreciprocal RM, bestreciprocal any]
        varcounts: map key -> list of the 9 corresponding numbers of pairs
        means: map key -> (mean BYxBY log10 GFP, median BYxBY, mean RMxBY, median RMxBY)
        meancounts: map key -> (number of BYxBY cells, number of RMxBY cells)
    """
    num_cutoff = 20  # at least 20 cells per parental across the 4 sites
    meanfile = "%s/img/summary/means.pickle" % DATA_DIR
    varfile = "%s/img/summary/vars.pickle" % DATA_DIR
    if (not recalc_meanvar) and os.path.exists(meanfile) and os.path.exists(varfile):
        (means, meancounts), (variances, varcounts) = cl(meanfile), cl(varfile)
        if outfilename is not None:
            save_img_orfstats(variances, varcounts, means, meancounts, outfilename=outfilename)
        return variances, varcounts, means, meancounts

    means, meancounts, variances, varcounts = {}, {}, {}, {}
    t = time.time()
    keys = list(combined_data.keys())
    for m, k in enumerate(keys):  # for each orf/plate/well
        if m % 100 == 0:  # periodic progress report and checkpoint
            print("%d/%d, %.1f elapsed" % (m, len(keys), time.time() - t))
            cdm(variances, "tmp_variances.pickle")

        x = combined_data[k]  # all feature values for the ORF
        if x is None or len(x) == 0:
            continue

        Iby, Irm = classify_rfp(x)  # classify cells into BY/RM
        if (len(Iby) == 0) or (len(Iby[0]) < num_cutoff): continue  # too few BYxBY cells
        if (len(Irm) == 0) or (len(Irm[0]) < num_cutoff): continue  # too few RMxBY cells
        by, rm = Iby[0], Irm[0]

        means[k] = (x[by, 2].mean(), NP.median(x[by, 2]), x[rm, 2].mean(), NP.median(x[rm, 2]))
        meancounts[k] = (len(by), len(rm))

        near, far, bestreciprocal = calc_var(x[:, 2:3], x[:, 0:2], by, rm, near_max=60**2)
        pair_sets = [v for varsets in (near, far, bestreciprocal) for v in varsets]
        varcounts[k] = [len(v) for v in pair_sets]
        # median squared difference, only if enough pairs for a robust estimate
        variances[k] = NP.array([NP.median(v) if len(v) > 2 else NP.nan for v in pair_sets])

    cdm((means, meancounts), meanfile)
    cdm((variances, varcounts), varfile)  # pickle mean and variance estimates
    if outfilename is not None:
        save_img_orfstats(variances, varcounts, means, meancounts, outfilename=outfilename)
    return variances, varcounts, means, meancounts


def save_img_orfstats(variances, varcounts, means, meancounts, outfilename="%s/img/summary/paper_orfstats.tab"%DATA_DIR):
    """ Output ORF mean and variance statistics to a tab-delimited file.

    One line per key of variances, in sorted key order, containing:
      - ORF, plate and well;
      - then for BYxBY and RMxBY in turn, the mean, median and cell count;
      - then (variance, number of pairs) for best-reciprocal, near and far pairs;
      - and at the end, the far-pair estimate over any pair of cells.

    @param variances, varcounts, means, meancounts see calc_img_orfstats; variances[k] and
        varcounts[k] are laid out as [near BY, near RM, near any, far BY, far RM, far any,
        bestreciprocal BY, bestreciprocal RM, bestreciprocal any]
    @param outfilename output file path
    @return None
    """
    stat_order = (2, 0, 1)  # best reciprocal, near, far (block offsets in the 3x3 layout)
    any_far = 3 * 1 + 2     # far estimate over any pair of cells
    with open(outfilename, 'w') as ofh:
        ofh.write("#ORF     \tPlate\tWell\tBYxBY_mean\tBYxBY_median\tBYxBY_ncells")
        ofh.write("\tBYxBY_bestrec_var\tBYxBY_bestrec_npairs\tBYxBY_near_var\tBYxBY_near_npairs\tBYxBY_far_var\tBYxBY_far_npairs")
        ofh.write("\tRMxBY_mean\tRMxBY_median\tRMxBY_ncells")
        ofh.write("\tRMxBY_bestrec_var\tRMxBY_bestrec_npairs\tRMxBY_near_var\tRMxBY_near_npairs\tRMxBY_far_var\tRMxBY_far_npairs")
        ofh.write("\tAny_far_var\tAny_far_npairs\n")

        for k in sorted(variances):  # for each orf, in a reproducible order
            ofh.write("%s\t%s\t%s" % tuple(k))  # orf, plate, well
            for p in (0, 1):  # BY, then RM
                ofh.write("\t%.4e\t%.4e\t%d" % (means[k][2*p], means[k][2*p + 1], meancounts[k][p]))
                for s in stat_order:
                    ofh.write("\t%.4e\t%d" % (variances[k][3*s + p], varcounts[k][3*s + p]))
            ofh.write("\t%.4e\t%d\n" % (variances[k][any_far], varcounts[k][any_far]))
    

def main():
    """ Entry point when the module is run as a script; currently a no-op. """
    return None


if __name__ == '__main__':
    main()
