from parts2014_gfpvar.tools.common import *
from parts2014_gfpvar.tools.io import *
from parts2014_gfpvar.analysis.array.util import get_all_metadata
import scipy as SP
import pylab as PL
import scipy.stats as ST
import glob
import pdb


def write_sample_sequence_data(sample, rec_cutoff):
    """ Write observed allele counts and inferred allele frequency data for a sample into a file
    @param sample sample to write; used as key into the "seq" metadata table
    @param rec_cutoff cutoff forwarded to get_sample_seq_data to select which stored summary is loaded
    @effects creates file [DATA_DIR]/seq/txt_output/[orf]_[gfppop]_[rep]_af.tab (underscores in gfppop become dashes, only the last two characters of rep are used) with one line per locus: chromosome name, location, coverage, BY and RM allele counts, 1 - ML frequency and its sd, 1 - posterior mean frequency and its sd, bad locus flag. Nothing is written for samples whose metadata comment is "discard" or whose discovery set starts with "none". """
    (samplename, rep, gfppop, orf, assay, diagnosis, comment, discovery_set) = get_all_metadata(sample_set="seq")[sample]
    if (comment == "discard") or (discovery_set[0:4] == "none"):
        LOG.debug("Skipping sample %s (comment=%s, discovery set=%s)"%(sample, comment, discovery_set))
        return
    LOG.debug("Outputting data for %s"%sample)
    data = get_sample_seq_data(sample, rec_cutoff)
    chrnames = get_seq_chr_names()
    out_path = "%s/seq/txt_output/%s_%s_%s_af.tab"%(DATA_DIR, orf, gfppop.replace("_","-"), rep[-2:])

    # the with-block closes the file even if formatting a row raises half way through
    with open(out_path, 'w') as ofh:
        # NOTE(review): the 9th column is labelled var(f_est) but the value written below is a standard deviation;
        # header text left untouched in case downstream parsers depend on it
        ofh.write("#Chrm\tLoc\tCoverage\t#BY\t#RM\tf_ML\tsd(f_ML)\tf_est\tvar(f_est)\tbad_locus?\n")
        for chrm in sorted(data):
            ml_m, post_m, post_params, bad_loc, loci, coverage = data[chrm] # ml_m is the BY allele frequency
            # number of observed alleles; both counts are rounded to nearest so that floating point error in
            # coverage*frequency cannot truncate one of them to a value one too small
            n_by = SP.array((coverage*ml_m + 0.5), int)
            n_rm = SP.array((coverage*(1-ml_m) + 0.5), int)
            sd_ml = ((1-ml_m)*ml_m/coverage)**0.5 # binomial standard deviation of the ML estimate
            # sqrt(a*b)/(a+b)/sqrt(a+b+1) over the two columns of post_params; this is the standard deviation
            # of a Beta(a, b) distribution - presumably post_params holds the posterior Beta parameters
            param_sum = post_params.sum(axis=1)
            sd_post = (post_params.prod(axis=1)**0.5)/param_sum/((param_sum + 1)**0.5)

            for i in range(len(loci)):
                # frequencies are written as 1 - (BY frequency)
                ofh.write("%s\t%d\t%d\t%d\t%d\t%.3f\t%.3f\t%.3f\t%.3f\t%s\n"%(chrnames[chrm], loci[i], coverage[i], n_by[i], n_rm[i], 1.-ml_m[i], sd_ml[i], 1.-post_m[i], sd_post[i], str(bad_loc[i])))


def get_sample_seq_data(sample, rec_cutoff=0.90):
    """ Return summaries for sample sequencing.
    @param sample name of sample; None short-circuits to a None result
    @param rec_cutoff cutoff encoded (two decimals) in the name of the pickle to load; None selects the pickle without a cutoff suffix
    @return whatever cl() loads from the selected pickle - per the callers in this file a map of chromosome -> [ml_mean, post_mean, post_params, bad_loc?, loci, coverage] (all length-L arrays)
    """
    if sample is None:
        return None

    if rec_cutoff is not None:
        pickle_path = "%s/seq/afs/%s_af_rec-%.2f.pickle"%(DATA_DIR, sample, rec_cutoff)
    else:
        pickle_path = "%s/seq/afs/%s_af.pickle"%(DATA_DIR, sample)
    return cl(pickle_path)



def calc_sample_purity(sample, locs=(("II", 230000, 240000, 0), ("V", 30000, 40000, 0))):
    """ Calculate fraction of sequenced reads that come from true segregants in haploid selection. For that, take the average observed allele frequency around each SGA marker and, under the assumption that the contaminating material is at 50% allele frequency, convert it to the fraction of the sample that carries the expected allele frequency. This is done for each SGA locus separately.
    @param sample sequencing sample to test
    @param locs list of SGA loci to test, each item (chrm, start, end, expected reference allele frequency); an expected frequency of exactly 0.5 is not usable (division by zero)
    @return for each location in locs, (observed - 0.5)/(expected - 0.5): 1 when the observed frequency equals the expected one (pure sample), 0 when it is 0.5 (all contaminating material). NaN when a region has no usable loci. """
    data = get_sample_seq_data(sample)
    chr_names = get_seq_chr_names() # fetched once; the mapping is only read below
    result = []

    for chrm, start, end, expected_af in locs:
        # key under which the chromosome with this name is stored; IndexError if the name is unknown
        chrm_key = [key for key in chr_names if chr_names[key] == chrm][0]
        d = data[chrm_key] # for all data on the chromosome and
        I = SP.where((abs(d[4] - 0.5*(start + end)) < 0.5*(end - start)) & ~d[3])[0] # strictly inside the region, and not a bad locus
        obs_af = d[0][I].mean() # calculate average observed allele frequency
        result.append((obs_af - 0.5)/(expected_af - 0.5)) # rescale so that 0.5 -> 0 and expected_af -> 1, i.e. the non-contaminated fraction
    return result




def read_kapa_result(filename, filter_small=True):
    """ Read per-well fluorescence curves from a tab-delimited qPCR raw data export.
    @param filename path of the raw data text file; the matching .samples file is read via get_kapa_well_samples
    @param filter_small if True, drop wells whose maximum value stays below 8
    @return (map of well -> length-36 array of values indexed by cycle number, result of get_kapa_well_samples(filename)) """
    well_values = {}

    with open(filename, 'r') as ifh: # closed on exit, also when a malformed line raises
        # lines 1-2 are "Raw Data                Experiment - 2013-03-29_Leo_Kapa-rest (Run on LCS480 1.5.0.39)
        #                SamplePos       SampleName      Prog#   Seg#    Cycle#  Time    Temp    465-510"
        for i in range(2): next(ifh)

        for l in ifh: # for each line in input file that looks something like "B3      Sample 15       2       2       23      2827150 59.93   40.59"
            if not l.strip(): continue # blank (e.g. trailing) lines carry no measurement
            d = l.strip().split("\t")
            well, cycle, val = d[0], int(d[4]), float(d[7])
            if well not in well_values:
                well_values[well] = SP.zeros(36) # cycle number is used directly as index, so it has to be below 36
            well_values[well][cycle] = val

    if filter_small:
        # collect first, delete afterwards - the map must not change size while it is being iterated
        for well in [w for w in well_values if well_values[w].max() < 8]:
            del well_values[well]

    return well_values, get_kapa_well_samples(filename)



def get_kapa_well_samples(filename):
    """ Read the well -> sample assignment that accompanies a qPCR raw data file.
    The assignment is read from the file obtained by replacing ".txt" with ".samples" in filename. Lines starting with "#" are comments; a comment containing "Relative_conc" sets the relative concentration to the number after its "=" sign (and echoes it to stdout). Every other non-blank line is "sample name<TAB>well".
    @param filename path of the raw data text file
    @return (map of well -> sample name, relative concentration (1. if the file does not set it)) """
    res = {}
    rel_conc = 1.
    with open(filename.replace(".txt", ".samples"), 'r') as ifh: # closed on exit, also on parse errors
        for l in ifh:
            if not l.strip(): continue # blank lines have no well column to read
            if l[0] == "#":
                if "Relative_conc" in l:
                    rel_conc = float(l.strip().split("=")[1])
                    print(rel_conc)
                continue
            d = l.strip().split("\t")
            res[d[1]] = d[0]

    return res, rel_conc



def calc_intercepts(values):
    """ Calculate the x-intercept of the linear phase of every well's log2 curve.
    For each well, log2(value + 1) is fitted with a straight line against cycle index, using only the cycles where the log2 value lies between 2.5 and 4.5; the cycle coordinate where that line reaches 0 is reported.
    @param values map of well -> array of values, indexed by cycle
    @return map of well -> x-intercept of the fitted line """
    def _x_intercept(curve):
        log_signal = SP.log2(curve + 1)
        linear_phase = SP.where(abs(log_signal - 3.5) <= 1.)[0] # cycle indexes with log2 value within 3.5 +- 1
        fit = ST.linregress(linear_phase, log_signal[linear_phase]) # fit[0] is the slope, fit[1] the intercept
        return -fit[1]/fit[0] # where the fitted line crosses zero

    return dict((well, _x_intercept(values[well])) for well in sorted(values))


def plot_curves(values, samples, rows="ABCDEFGH"):
    """ Plot log2(value + 1) curves of wells in one figure and show it.
    Wells whose sample name contains "pM" are always drawn, as thick red lines. Other wells are drawn only if their row letter is in rows, coloured by column number (six colours, repeating). All remaining wells are left out.
    @param values map of well -> array of values, indexed by cycle
    @param samples map of well -> sample name; has to contain every well in values
    @param rows row letters of the non-"pM" wells to include
    @effects opens a figure window and blocks in PL.show() """
    palette = 'bgcymk'
    PL.figure(figsize=(12,10))

    for well_id in sorted(values):
        curve = SP.log2(values[well_id] + 1)
        if "pM" in samples[well_id]:
            PL.plot(curve, 'r-', linewidth=3)
            continue
        if well_id[0] not in rows:
            continue
        column = int(well_id[1:]) # well ids are row letter followed by column number
        PL.plot(curve, palette[(column - 1)%6] + '-')

    PL.xlabel("Cycle")
    PL.ylabel("Fluorescence")
    PL.title("qPCR curves - %d samples"%(len(values)))
    #PL.xlim(10,23)
    #PL.ylim(2,5)
    PL.show()


def analyse_kapa_results(filename, debug=True, plot=True, output=True, filter_small=True):
    """ Estimate sample concentrations from a qPCR raw data file, relative to reference wells.
    For every sample, each of its wells contributes 2**(mean reference intercept - well intercept); the mean of these is scaled by the relative concentration from the .samples file and a factor of 2. Samples whose name contains "pM" are not scaled by the relative concentration.
    @param filename path of the raw data text file; the name also selects the reference wells (A2 and B2 unless it contains one of the special-cased run names)
    @param debug if True (and output is True), print per-sample details to stdout
    @param plot if True, show curves for row pairs AB, CD, EF whose first well (A1, C1, E1) has a sample assigned
    @param output if True, write "sample<TAB>concentration nM" lines to the file obtained by replacing ".txt" with ".output" in filename
    @param filter_small forwarded to read_kapa_result
    @effects may open plot windows, write the .output file and print to stdout; returns nothing """
    values, (samples, rel_conc) = read_kapa_result(filename, filter_small=filter_small)
    all_intercepts = calc_intercepts(values)
    if plot:
        for rows in "AB", "CD", "EF":
            if "%s1"%(rows[0]) in samples:
                plot_curves(values, samples, "AB%s"%rows) # rows A and B are always included alongside the pair
    if not output:
        return

    ref_wells = ["A2", "B2"]
    if filename.count("2013-04-14_Leo_Kapa-B6-B16-B17") > 0: ref_wells = ["A11","A12"]
    elif filename.count("Adam") > 0: ref_wells = ["F8","G8", "H8"]
    # looked up before the output file is opened, so that a missing reference well (KeyError) does not
    # leave behind a truncated output file
    ref_intercepts = [all_intercepts[w] for w in ref_wells]
    ref_mean = SP.mean(ref_intercepts)
    print("%s %s"%(filename, ref_wells))

    with open(filename.replace(".txt",".output"), "w") as ofh: # closed (and flushed) on exit
        for s in sorted(set(samples.values())): # for each sample, in sorted order
            sample_wells = [w for w in samples if samples[w] == s]
            sample_intercepts = [all_intercepts[w] for w in sample_wells if w in all_intercepts] # wells removed by filtering have no intercept
            sample_rel_conc = [2**(ref_mean - sample_intercept) for sample_intercept in sample_intercepts]
            # NOTE(review): the factor 2. is presumably the concentration of the reference wells in pM - confirm
            sample_conc = SP.mean(sample_rel_conc)*rel_conc*2.
            if s.count("pM") > 0: sample_conc = sample_conc/rel_conc # undo the relative concentration scaling for "pM" samples
            ofh.write("%s\t%.3f nM\n"%(s, sample_conc/1000))
            if debug:
                # one line per sample, pieces separated by single spaces
                pieces = ("%s (wells %s): %.3f pM"%(s, " ".join(sample_wells), sample_conc), # assumption here is that it's a 25x difference
                          "(intercept: %.2f (%s),"%(SP.mean(sample_intercepts), " ".join(["%.2f"%x for x in sample_intercepts])),
                          "increase (%s)"%(" ".join(["%.2f"%x for x in sample_rel_conc])))
                print(" ".join(pieces))
