from gfpvar.tools.common import *
from leo.common import *
import scipy as SP
import glob
import os
import sys
import pdb

def read_pileup(filename, pileupfile, base_qual_cutoff=20, map_qual_cutoff=20):
    """Return the pileup stored at ``filename``, falling back to creating it.

    When ``filename`` already exists it is loaded with ``cl``.  Otherwise the
    work is delegated to ``create_pileup(pileupfile, filename,
    base_qual_cutoff, map_qual_cutoff)`` and its return value is passed back
    unchanged.
    """
    if not os.path.exists(filename):
        return create_pileup(pileupfile, filename, base_qual_cutoff, map_qual_cutoff)
    return cl(filename)


def filter_nearby(locs, cutoff, I):
    """Greedily thin candidate indices so kept loci are more than ``cutoff`` apart.

    Walks ``I`` in order.  An index is kept when its locus lies more than
    ``cutoff`` beyond the most recently kept locus.  The first index is always
    kept.  ``I`` is presumably ordered by increasing ``locs`` value (e.g. the
    output of ``SP.where`` over position-sorted loci) -- TODO confirm.

    Returns the kept indices as a list, in the order they appear in ``I``.
    """
    kept = []
    last = None  # no locus kept yet; a 0 sentinel would wrongly drop loci near 0
    for i in I:
        if last is None or locs[i] - last > cutoff:
            last = locs[i]
            kept.append(i)
    return kept


def _capped_counts(data, fixed, max_coverage):
    """Return the per-site allele counts that smooth() pools across neighbours.

    Each row of ``data`` is multiplied by 0.05 when ``fixed`` is True for that
    site, then 1 is added to every allele as a prior pseudocount.  Rows whose
    total then exceeds ``max_coverage`` are rescaled to sum to ``max_coverage``
    so that high-coverage sites cannot dominate the pooled estimate.
    """
    counts = data * (0.05 ** fixed)[:, None] + 1
    totals = counts.sum(axis=1)
    high = totals > max_coverage
    counts[high] = counts[high] * (max_coverage / totals[high])[:, None]
    return counts


def _smooth_pass(counts, loc, rec_rate, neighbours):
    """Return distance-weighted sums of ``counts`` around every site.

    For each site i, ``neighbours(i)`` gives the indices j to pool.  Each one
    contributes ``exp(-rec_rate * |loc[i] - loc[j]|) * counts[j]``.  Sites with
    no neighbours keep an all-zero row.
    """
    res = SP.zeros(counts.shape)
    for i in range(len(counts)):
        J = neighbours(i)
        if len(J) == 0:
            continue
        weights = SP.exp(-rec_rate * abs(loc[i] - loc[J]))
        res[i] = (weights[:, None] * counts[J]).sum(axis=0)
    return res


def smooth(data, loc, rec_rate=80./(12000000), rec_cutoff=0.9, outlier_cutoff=0.2, include_bad=False, max_num_median_coverage=2.):
    """Estimate smoothed reference-allele frequencies along one chromosome.

    Runs two passes of distance-weighted pooling over neighbouring sites:
      1. Pool every site within ``length_cutoff`` of each site.  Flag as
         ``bad`` the sites whose raw frequency differs from the pooled one by
         more than ``outlier_cutoff``.
      2. Pool again using only non-bad sites, thinned with
         ``filter_nearby(cutoff=200)``, and recompute ``bad``.

    data -- per-site counts; column 0 is used as the reference count (the
        caller passes assign_ref output: [reference, top alternative]).
    loc -- site positions aligned with ``data`` (an array; the caller passes
        ints).
    rec_rate -- decay rate of the neighbour weight exp(-rec_rate * distance).
    rec_cutoff -- neighbours whose weight is not above this are ignored, i.e.
        the window is ``length_cutoff = -log(rec_cutoff) / rec_rate``.
    outlier_cutoff -- maximum |raw - smoothed| frequency before a site is bad.
    include_bad -- unused; kept for interface compatibility.
    max_num_median_coverage -- pooled per-site counts are capped at this
        multiple of the median site coverage.

    Returns (init_mean, post_mean, res, bad, loc, coverage):
      - the raw and smoothed reference frequency per site;
      - the pooled counts from pass 2;
      - the final outlier mask;
      - the input positions;
      - the raw per-site coverage.
    A site with no usable neighbours in pass 2 has an all-zero ``res`` row and
    a NaN ``post_mean``.  Because NaN fails the comparison, such sites are not
    flagged bad.
    """
    length_cutoff = -SP.log(rec_cutoff)/rec_rate
    # No newline, so the caller's next print continues this line (as the old
    # Python 2 "print ...," did).
    sys.stdout.write("length cutoff=%d " % length_cutoff)
    coverage = data.sum(axis=1)
    max_coverage = SP.median(coverage)*max_num_median_coverage
    fixed = (data.prod(axis=1) == 0)  # at least one allele has zero count
    # Depends only on the neighbour, never on the focal site: compute once.
    counts = _capped_counts(data, fixed, max_coverage)

    def window(i):
        return abs(loc[i] - loc) < length_cutoff

    # Pass 1: pool everything in the window (site i itself is included
    # whenever length_cutoff > 0, so the denominator is never zero).
    res = _smooth_pass(counts, loc, rec_rate, lambda i: SP.where(window(i))[0])
    post_mean = res[:,0]/(res.sum(axis=1)+1E-6)
    init_mean = data[:,0]/(coverage + 1E-10)
    bad = (abs(init_mean - post_mean) > outlier_cutoff)  # diverges from estimate

    # Pass 2: pool only pass-1 non-outliers, thinned so clusters don't dominate.
    keep = ~bad
    res = _smooth_pass(counts, loc, rec_rate,
                       lambda i: filter_nearby(loc, cutoff=200, I=SP.where(window(i) & keep)[0]))
    post_mean = res[:,0]/res.sum(axis=1)
    bad = (abs(init_mean - post_mean) > outlier_cutoff)
    return init_mean, post_mean, res, bad, loc, coverage


def assign_ref(data, ref):
    """Collapse per-site allele counts to [reference count, best alternative count].

    data -- 2-D array with one row per site.  Columns are indexed by allele
        as A=0, C=1, G=2, T=3, *=4.
    ref -- reference base character for each row of ``data``.

    Returns an (n_sites, 2) float array.  Column 0 holds the reference
    allele's count.  Column 1 holds the largest count among the
    non-reference columns.
    """
    allele_index = {'A':0,'C':1,'G':2,'T':3,'*':4}
    out = SP.zeros([len(data), 2])
    for site in range(len(data)):
        row = data[site]
        r = allele_index[ref[site]]
        # Seed with the allele after the reference (cycling over the five
        # alleles), then let any strictly larger non-reference count win.
        candidates = [row[(r + 1) % 5]] + [value for col, value in enumerate(row) if col != r]
        out[site, 0] = row[r]
        out[site, 1] = max(candidates)
    return out


def smooth_sample(sample, rec_cutoff=0.93):
    """Smooth allele frequencies for every chromosome of one sample and store them.

    Loads the pileup with ``read_pileup``, preferring
    ``DATA_DIR/seq/<sample>/2_pileup/<sample>.combined.pickle`` when it exists
    and ``<sample>.pileup.pickle`` otherwise.  For each chromosome in the
    pileup's ``L`` column 0, it reduces counts with ``assign_ref`` and smooths
    them with ``smooth``.  The resulting dict is saved with ``cdm`` to
    ``DATA_DIR/seq/<sample>/3_af/<sample>_af_rec-<rec_cutoff:.2f>.pickle``.

    sample -- sample directory name under ``DATA_DIR/seq``.
    rec_cutoff -- forwarded to smooth(); also encoded in the output filename.
    """
    pileup_dir = "%s/seq/%s/2_pileup" % (DATA_DIR, sample)
    combined = "%s/%s.combined.pickle" % (pileup_dir, sample)
    if os.path.exists(combined):
        pileup = read_pileup(combined, sample)
    else:
        # NOTE(review): `sample` is passed as read_pileup's `pileupfile`
        # argument, which is only used when this pickle is missing -- confirm
        # that create_pileup expects a sample name there.
        pileup = read_pileup("%s/%s.pileup.pickle" % (pileup_dir, sample), sample)

    # Convert in both branches so fancy indexing by site below always works.
    alleles = SP.asarray(pileup['D'])
    locs = SP.array(pileup['L'])
    refseq = SP.array(pileup['refseq'])  # converted once, not per chromosome

    afs = {}  # chromosome -> smooth() result (init_mean, post_mean, res, bad, loc, coverage)
    for chrm in SP.unique(locs[:, 0]):
        print("%s %s" % (sample, chrm))
        on_chrm = SP.where(locs[:, 0] == chrm)[0]
        counts = assign_ref(alleles[on_chrm], refseq[on_chrm])
        afs[chrm] = smooth(counts, SP.array(locs[on_chrm, 1], int), rec_cutoff=rec_cutoff)

    out_dir = "%s/seq/%s/3_af" % (DATA_DIR, sample)
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)  # no shell involved, unlike `mkdir -p` via os.system
    cdm(afs, "%s/%s_af_rec-%.2f.pickle" % (out_dir, sample, rec_cutoff))


def main(sample_glob="Sample4_*", skip_existing=False):
    """Command-line entry point.

    If a sample name is given in ``sys.argv[1]``, that sample is smoothed in
    this process.  Otherwise every sample with a
    ``DATA_DIR/seq/<sample_glob>/2_pileup/*.combined.pickle`` file gets one
    ``submitjob python run_5_smooth.py <sample>`` job.

    sample_glob -- glob pattern selecting sample directories.  The default
        keeps the previous Sample4_* selection.
    skip_existing -- when True, samples that already have an allele-frequency
        pickle in their 3_af directory are not resubmitted.  The default
        (False) matches the previous always-resubmit behaviour.
    """
    import subprocess  # local import: only needed when submitting jobs

    if len(sys.argv) > 1:
        smooth_sample(sys.argv[1])
        return

    pattern = "%s/seq/%s/2_pileup/*.combined.pickle" % (DATA_DIR, sample_glob)
    # One job per sample, even if a sample directory holds several pickles.
    samples = sorted(set(f.split("/")[-3] for f in glob.glob(pattern)))
    for sample in samples:
        # "_af*" matches both the legacy "<sample>_af.pickle" name and the
        # "<sample>_af_rec-<cutoff>.pickle" files smooth_sample writes.
        if skip_existing and glob.glob("%s/seq/%s/3_af/%s_af*.pickle" % (DATA_DIR, sample, sample)):
            print("Already processed: %s" % sample)
            continue
        print("Processing %s" % sample)
        # Argument list, not a shell string, so directory names cannot inject
        # shell syntax.
        subprocess.call(["submitjob", "python", "run_5_smooth.py", sample])


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
