import scipy as SP
import numpy as NP
import glob
import sys
from StringIO import StringIO
import struct
from parts2014_gfpvar.tools.common import *

RND_SEED = 1
# Columns read from each FCS file by default.
DEFAULT_COLNAMES = ["FITC-A", "mCherry-A", "FSC-A", "SSC-W"]
# Named feature subsets (presumably used by clustering code elsewhere).
DEFAULT_CLUSTER_FEATURES = ["mCherry-A", "FSC-A", "SSC-W"]
DEFAULT_BUDDING_CLUSTER_FEATURES = ["FSC-A", "SSC-W"]
DEFAULT_RFP_CLUSTER_FEATURES = ["mCherry-A"]
# Default gates as (column name, min, max). The bounds look log10-scaled
# (cf. the log10 bounds in the first commented alternative). Alternative gate
# sets are kept for reference; the R1/R3/R6 labels are as in the original.
#DEFAULT_FILTERS = [("FSC-A", SP.log10(5000),SP.log10(90000)), ("SSC-W", SP.log10(50000), SP.log10(95000))]
#DEFAULT_FILTERS = [("FSC-A", 4.2, 5.1), ("SSC-W", 4.8,4.95), ("mCherry-A", 2.25, 4)] # R3
#DEFAULT_FILTERS = [("FSC-A", 4.2, 5.35), ("SSC-W", 4.8,4.95), ("mCherry-A", 2, 4)] # R6
DEFAULT_FILTERS = [("FSC-A", 4.2, 5.1), ("SSC-W", 4.8, 4.95), ("mCherry-A", 2.25, 4.1)]  # R1
UNBUDDED_FSCA_MEAN = 10**4.64
# NOTE(review): identical to UNBUDDED_FSCA_MEAN - confirm this is intended.
BUDDED_FSCA_MEAN = 10**4.64
UNBUDDED_SSCW_MEAN = 10**4.81
BUDDED_SSCW_MEAN = 10**4.88
BY_RFP_MEAN = 10**3.5
RM_RFP_MEAN = 10**2.8

# Prior means of log10(mCherry) for three components, as a (3, 1) array:
# BY, RM, and a noise component half a decade below RM.
# Built with numpy directly: the numpy aliases in the scipy namespace
# (SP.zeros, SP.log10) are deprecated and would break this module at import.
DEFAULT_PRIOR_RFP_MEANS = NP.array([
    [NP.log10(BY_RFP_MEAN)],        # BY
    [NP.log10(RM_RFP_MEAN)],        # RM
    [NP.log10(RM_RFP_MEAN) - 0.5],  # Noise
])


"""
Attempts to parse an FCS (flow cytometry standard) file

@param filename path to the FCS file
@return tuple(vars,events)
    	vars: a dictionary with the KEY/VALUE pairs found in the HEADER
    	this includes the standard '$ABC' style FCS variable as well as any 
    	custom variables added to the header by the machine or operator
	
    	events: an [N x D] matrix of the data (as a Python list of lists)
    	i.e. events[99][2] would be the value at the 3rd dimension
    	of the 100th event
"""
def fcsextract(filename):
    fcs_file_name = filename

    fcs = open(fcs_file_name,'rb')
    header = fcs.read(58)
    version = header[0:6].strip()
    text_start = int(header[10:18].strip())
    text_end = int(header[18:26].strip())
    data_start = int(header[26:34].strip())
    data_end = int(header[34:42].strip())
    analysis_start = int(header[42:50].strip())
    analysis_end = int(header[50:58].strip())

    ###print "Parsing TEXT segment"
    # read TEXT portion
    fcs.seek(text_start)
    delimeter = fcs.read(1)
    # First byte of the text portion defines the delimeter
    ###print "delimeter:",delimeter
    text = fcs.read(text_end-text_start+1)

    #Variables in TEXT poriton are stored "key/value/key/value/key/value"
    keyvalarray = text.split(delimeter)
    fcs_vars = {}
    fcs_var_list = []
    # Iterate over every 2 consecutive elements of the array
    for k,v in zip(keyvalarray[::2],keyvalarray[1::2]):
        fcs_vars[k] = v
        fcs_var_list.append((k,v)) # Keep a list around so we can print them in order

    #from pprint import pprint; pprint(fcs_var_list)
    if data_start == 0 and data_end == 0:
        data_start = int(fcs_vars['$DATASTART'])
        data_end = int(fcs_vars['$DATAEND'])

    num_dims = int(fcs_vars['$PAR'])
    ###print "Number of dimensions:",num_dims

    num_events = int(fcs_vars['$TOT'])
    ###print "Number of events:",num_events

    # Read DATA portion
    fcs.seek(data_start)
    #print "# of Data bytes",data_end-data_start+1
    data = fcs.read(data_end-data_start+1)

    # Determine data format
    datatype = fcs_vars['$DATATYPE']
    if datatype == 'F':
        datatype = 'f' # set proper data mode for struct module
        ###print "Data stored as single-precision (32-bit) floating point numbers"
    elif datatype == 'D':
        datatype = 'd' # set proper data mode for struct module
        ###print "Data stored as double-precision (64-bit) floating point numbers"
    else:
        assert False,"Error: Unrecognized $DATATYPE '%s'" % datatype
    
    # Determine endianess
    endian = fcs_vars['$BYTEORD']
    if endian == "4,3,2,1":
        endian = ">" # set proper data mode for struct module
        ### print "Big endian data format"
    elif endian == "1,2,3,4":
        ###print "Little endian data format"
        endian = "<" # set proper data mode for struct module
    else:
        assert False,"Error: This script can only read data encoded with $BYTEORD = 1,2,3,4 or 4,3,2,1"

    # Put data in StringIO so we can read bytes like a file    
    data = StringIO(data)

    ###print "Parsing DATA segment"
    # Create format string based on endianeness and the specified data type
    format = endian + str(num_dims) + datatype
    datasize = struct.calcsize(format)
    ###print "Data format:",format
    ###print "Data size:",datasize
    events = []
    # Read and unpack all the events from the data
    for e in range(num_events):
        event = struct.unpack(format,data.read(datasize))
        events.append(event)
    
    fcs.close()
    return fcs_vars, events


""" Read all data in an FCS file
@param filename path to fcs file
@param colnames_tostore list of column names to read from FCS. If None, all FCS features are returned.
@param log whether to apply lambda x:log10(x+1) to the read data
@return tuple(colnames_tostore, data) where data is Nxlen(colnames_tostore) 
"""
def read_fcs(filename, colnames_tostore=DEFAULT_COLNAMES, log=False):
    vars, events = fcsextract(filename)
    events = SP.array(events)
    colnames = [vars["$P%dN"%(i+1)].replace("GFP-A", "FITC-A") for i in range(events.shape[1])]
    interesting_cols = range(len(colnames)) # keep all columns by default
    if colnames_tostore is not None: # if another selection given
        interesting_cols = [colnames.index(c) for c in colnames_tostore] # filter on them
        
    d = events[:, interesting_cols] # retain only interesting columns
    I = SP.where(d.min(axis=1) > 0)[0]
    if len(I) > 0: d = d[I] # retain rows with nonnegative values
    if log: return colnames_tostore, SP.log10(d+1) # log10 if asked
    return colnames_tostore, d




""" Return rows of matrix data that pass all filters (linear gates).
@param data NxC array of floats
@param colnames length-C array of data column names
@param filters list of triplets of (column name, min_value, max_value)
@return MxC array of data points that pass all filters
"""
def filter_fcs(data, colnames, filters):
    I = SP.ones(data.shape[0], bool) # default - retain all
    for f in filters:
        col = list(colnames).index(f[0]) # for each filter, get column
        I = I & (data[:,col] > f[1]) & (data[:,col] < f[2]) # and update which rows to retain
    #LOG.debug("Filtered out %d of %d data points"%(len(data) - sum(I), len(data)))
    return data[I] # return wanted data


def plot_data(data, fsc,ssc,mch, filters):
    """Show FSC-vs-SSC and FSC-vs-mCherry scatter plots with the gate boxes.

    @param data NxC array of data points
    @param fsc, ssc, mch column indices of FSC, SSC and mCherry in data
    @param filters dict of (min, max) tuples under the keys 'f' (FSC),
        's' (SSC) and 'm' (mCherry)
    """
    import pylab as PL
    PL.figure(figsize=(14,9))
    # One panel per (subplot position, y column, gate key, y label).
    panels = ((121, ssc, 's', "SSC"),
              (122, mch, 'm', "mCherry"))
    for position, y_col, y_key, y_label in panels:
        PL.subplot(position)
        PL.plot(data[:, fsc], data[:, y_col], ".", markersize=14, alpha=0.15)
        # Outline the current gate as a red rectangle, one edge at a time.
        x_lo, x_hi = filters['f']
        y_lo, y_hi = filters[y_key]
        edges = (([x_lo, x_lo], [y_lo, y_hi]),
                 ([x_hi, x_hi], [y_lo, y_hi]),
                 ([x_lo, x_hi], [y_lo, y_lo]),
                 ([x_lo, x_hi], [y_hi, y_hi]))
        for xs, ys in edges:
            PL.plot(xs, ys, 'r-')
        PL.xlabel("FSC")
        PL.ylabel(y_label)
    PL.show()
    

""" Return a list of filters (tuples of column name, min value, max value)
@param data NxC SP.array of data points
@param colnames length-C list of features
@return list of tuples (colname, minvalue, maxvalue) """
def input_filters(data, colnames):
    # Plot FSC-SSC and FSC-mCherry scatter plots
    names = {"f":"FSC-A", "s":"SSC-W", "m":"mCherry-A"}
    filters = {"2":{"f":(4.75, 5.4), "s":(4.81, 4.95), "m":(2.2, 3.7)}, "3":{"f":(4.75, 5.4), "s":(4.8, 4.95), "m":(2.2, 3.8)},
               "1": {"f":(4.2, 5.1), "s":(4.8, 4.92), "m":(2.4, 4)}}
    print "Write down filters for FSC, SSC, mCherry. Current"
    for f in sorted(filters):
        print "\t", f, ":", filters[f], ["", " (default)"][f == "1"]
    fsc,ssc,mch = colnames.index("FSC-A"), colnames.index("SSC-W"), colnames.index("mCherry-A")
    plot_data(data, fsc, ssc, mch, filters["1"])

    print "Pick one of 1..%d, change the default setting 1 (c) or continue (enter)?"%(len(filters)),
    answer = sys.stdin.readline().strip()
    if len(answer) == 0: # if enter, just pick default
        answer = "1"        
    if answer.lower() in map(str, range(1, len(filters) + 1)):  # if one of predefined ones, set the filters and return
        result, filterset = [], answer.lower()
        for f in filters[filterset]: result.append((names[f], filters[filterset][f][0], filters[filterset][f][1]))
        return result
    
    # else, update as many as required
    f = [None,None,None]
    print "Changing default filter"
    while len(f) > 0:
        print "Enter filter (colname, minval, maxval); empty line to finish:",
        f = sys.stdin.readline().strip().split()
        if len(f) > 1:
            filters["1"][f[0]] = (float(f[1]), float(f[2]))
    # Done updating, return
    print "New filters", filters
    result = []
    for f in filters["1"]: result.append((names[f], filters["1"][f][0], filters["1"][f][1]))
    return result


""" Read all fcs files in a given plate of a screen of an experiment, and randomly pick n_rnd_points from them """
def read_plate_fcs_random_data(screen, plate, experiment="Pilot_screen_BYxRM", n_rnd_points=100000, colnames=DEFAULT_COLNAMES, filters=DEFAULT_FILTERS, n_files=48, external_hd=False):
    files = SP.array(glob.glob("%s/%s/%s/%s/*.fcs"%([DATA_DIR + "/cytometry","/Volumes/BACKUP/"][external_hd], experiment, screen, plate)))
    files = SP.random.choice(files, n_files, replace=(n_files > len(files)))

    # 0. if manual filter, plot data, create filters, repeat
    data = SP.zeros([0,len(colnames)])    
    if filters == "manual":
        for f in files[0:10]: data = SP.concatenate((data,read_fcs(f, colnames_tostore=colnames, log=True)[1]))
        filters = input_filters(data[SP.random.choice(range(len(data)), 3000, replace=False),:], colnames)
        
    data = SP.zeros([0,len(colnames)])
    alldata = SP.zeros([0,len(colnames)])
    for f in files:
        #LOG.debug(f)
        cols, dat = read_fcs(f, colnames_tostore=colnames, log=True)
        alldata = SP.concatenate((alldata,dat[SP.random.choice(len(dat), 1.*n_rnd_points/len(files))]))
        dat = filter_fcs(dat, cols, filters)
        if len(dat) == 0: continue
        data = SP.concatenate((data,dat[SP.random.choice(len(dat), 1.*n_rnd_points/len(files))]))

    # Further filter out top/bottom 2% of each feature (~15% total data) to avoid overfitting to a weird subset
    n_points = data.shape[0]
    for i, col in enumerate(colnames):
        I = SP.argsort(data[:,i])
        data = data[I[int(n_points*0.02):int(n_points*0.98)],:]

    # return data retrieved, column names of data, filters applied
    return data, cols, filters

