from parts2014_gfpvar.tools.common import *
from parts2014_gfpvar.analysis.qtl.util import *
from util import *

# Name of the reference sample substituted for a missing (None) high or low sample.
# Indexed by the boolean `sequencing` flag of call_all_set_qtls: False -> "", True -> "Sample2_ref".
NEUTRAL_SAMPLE = ["", "Sample2_ref"] # comparison against this sample when only low or high is measured
# The two slack values below are added to the observed |AF change| and |SD change| in calc_af_diff
# to define the lenient significance test used for extending a peak region.
AF_CHANGE_EXTRA = 0.02 # extra AF slack for calling a peak region to avoid overcalling
SD_CHANGE_EXTRA = 0.3 # extra SD slack for calling a peak region to avoid overcalling


""" Run the QTL calling function on all the sample pairs
@param signal_cutoff minimum absolute value allele frequency change to call a QTL
@param sd_cutoff minimum number of SDs difference to call a QTL
@param length_cutoff minimum length of QTL region
"""
def call_all_set_qtls(signal_cutoff, sd_cutoff, length_cutoff, singletons=False, sequencing=True):
    """ Call QTLs for every sample pair returned by get_qtl_sample_pairs.
    @param signal_cutoff minimum absolute value allele frequency change to call a QTL (passed to call_set_qtl as its af_cutoff)
    @param sd_cutoff minimum number of SDs difference to call a QTL
    @param length_cutoff minimum length of QTL region
    @param singletons forwarded to get_qtl_sample_pairs when selecting the sample pairs
    @param sequencing forwarded to get_qtl_sample_pairs; also selects which NEUTRAL_SAMPLE entry stands in for a missing sample
    """
    sample_pairs = get_qtl_sample_pairs(singletons, sequencing)
    for pair in sample_pairs:
        sample_high, sample_low = sample_pairs[pair]
        # When only one of high/low was measured, compare it against the neutral reference sample
        if sample_high is None: sample_high = NEUTRAL_SAMPLE[sequencing]
        if sample_low is None: sample_low = NEUTRAL_SAMPLE[sequencing]
        # signal_cutoff is the allele frequency cutoff of the per-pair caller
        call_set_qtl(pair, sample_high, sample_low, signal_cutoff, sd_cutoff, length_cutoff)


""" Call QTLs for a pair of samples.
@param set_name name of the set of samples (used in QTL output)
@param sample_high GFP high sample. If None, is replaced with AF=0.5, SD=0 to test for difference from neutrality of the other sample
@param sample_low GFP low sample. If None, is replaced with AF=0.5, SD=0  to test for difference from neutrality of the other sample
@param af_cutoff see call_all_set_qtls (signal_cutoff)
@param sd_cutoff see call_all_set_qtls
@param length_cutoff see call_all_set_qtls"""
def call_set_qtl(set_name, sample_high, sample_low, af_cutoff, sd_cutoff, length_cutoff):
    """ Call QTLs for a pair of samples and hand them to output_seq_qtl.

    Every chromosome present in both samples is scanned left to right. A maximal run of
    consecutive loci that are leniently significant by all measures forms a candidate region.
    The region is stored as a QTL when at least one of its loci is strictly significant by all
    measures and the distance between its first and last locus exceeds length_cutoff.

    @param set_name name of the set of samples (used in QTL output)
    @param sample_high GFP high sample
    @param sample_low GFP low sample
    @param af_cutoff minimum absolute value allele frequency change to call a QTL
    @param sd_cutoff minimum number of SDs difference to call a QTL
    @param length_cutoff minimum length of QTL region (difference of locus positions)
    """
    LOG.debug("Calling QTLs for set %s (GFP high sample %s, low sample %s) with cutoffs AF=%.3f, SDs=%.3f, length=%d"%(set_name, sample_high, sample_low, af_cutoff, sd_cutoff, length_cutoff))
    qtls = []
    # Parenthesised tuple: valid in both Python 2 and Python 3 comprehensions
    data = [get_sample_seq_data(sample) for sample in (sample_high, sample_low)]

    for chrm in data[0]: # for each chromosome
        if chrm not in data[1]: continue # must be present in both samples
        chrdata = [d[chrm] for d in data] # get data for this chromosome
        diffs, sds, strong_sig, lenient_sig, all_locs = calc_af_diff(chrdata, af_cutoff, sd_cutoff) # calculate differences for posteriors
        n_loci = len(all_locs)

        for i in range(n_loci): # scan loci left to right
            if not lenient_sig[i].all(): continue # regions can only start at a leniently significant locus

            start, end, peak = i, i, i # start, end, strongest signal of peak region
            is_qtl = bool(strong_sig[i].all()) # the first locus can itself be the strictly significant one
            # Extend to the right only while the NEXT locus is still leniently significant,
            # so that the region never contains a non-significant locus
            while (end < n_loci - 1) and lenient_sig[end + 1].all():
                end += 1
                if strong_sig[end].all(): is_qtl = True # actual QTL if the stretch has at least one properly significant change
                if abs(diffs[end]) > abs(diffs[peak]): peak = end # store location of strongest peak

            region_length = all_locs[end] - all_locs[start]
            if is_qtl and (region_length > length_cutoff): # strong enough change exists, and the region is long enough
                qtls.append((set_name, sample_high, sample_low, chrm, all_locs[peak], diffs[peak], sds[peak], all_locs[start], all_locs[end])) # store QTL
            elif is_qtl: LOG.debug("Skipping QTL chrm %s %d (af change=%.2f) as too short (%d)"%(chrm, all_locs[peak], diffs[peak], region_length))

            lenient_sig[start:end+1, :] = strong_sig[start:end+1, :] = False # remove the stretch just observed from future consideration

    output_seq_qtl(set_name, af_cutoff, sd_cutoff, length_cutoff, qtls)



def get_locus_overlap(l1,l2):
    """ Find the loci shared by two locus lists, and where they sit in each list.
    @param l1 sequence of hashable locus positions
    @param l2 sequence of hashable locus positions
    @return (res_loc, I) where res_loc holds the shared loci in the order they appear in l1,
            and I[0][k], I[1][k] are the indices of res_loc[k] in l1 and l2 respectively
            (for a locus repeated in l2, its first occurrence is used)
    """
    # One pass over l2 to remember the first index of every locus
    first_index_in_l2 = {}
    for j, locus in enumerate(l2):
        first_index_in_l2.setdefault(locus, j) # setdefault keeps the earliest index for repeats

    res_loc = []
    I = [[],[]]
    for i, locus in enumerate(l1):
        j = first_index_in_l2.get(locus)
        if j is None: continue # locus only present in l1
        res_loc.append(locus)
        I[0].append(i)
        I[1].append(j)
    return res_loc, I
    

def calc_af_diff(chrdata, af_cutoff, sd_cutoff):
    """ Compare the allele frequency estimates of two samples on one chromosome.
    @param chrdata per-sample chromosome records for the two samples; of each record, item 4 is used as
           the locus positions and item 2 as a 2D array with one row per locus
           (NOTE(review): item 2 looks like two-column Beta posterior parameters - confirm)
    @param af_cutoff minimum absolute allele frequency change for strict significance
    @param sd_cutoff minimum absolute number of SDs of change for strict significance
    @return (diff, sds, strong_sig, lenient_sig, all_locs), all restricted to loci shared by both samples:
            diff is mean of first sample minus mean of second, sds is diff divided by the square root of the
            summed variances, strong_sig / lenient_sig are [n_loci, 2] arrays (column 0: AF test,
            column 1: SD test) without / with the extra slack, all_locs are the shared loci
    """
    # 0. Per-sample locus positions, and mean / variance at every locus
    locs, mu, var = [], [], []
    for sample_data in chrdata:
        counts = sample_data[2]
        total = counts.sum(axis=1)
        locs.append(sample_data[4])
        mu.append(counts[:,0]/total)
        var.append(counts.prod(axis=1)/((total**2 + 1e-10)*(total + 1))) # 1e-10 guards against zero row sums

    # 1. Restrict both samples to shared loci, then take the difference in means and in units of SD
    all_locs, I = get_locus_overlap(locs[0], locs[1])
    mu_first, mu_second = mu[0][I[0]], mu[1][I[1]]
    var_first, var_second = var[0][I[0]], var[1][I[1]]
    diff = mu_first - mu_second
    sds = diff/((var_first + var_second)**0.5)

    # 2. Significance of the change, strictly and with the extra slack, for AF (column 0) and SD (column 1)
    n_loci = len(all_locs)
    abs_diff, abs_sds = abs(diff), abs(sds)
    strong_sig = SP.ones([n_loci, 2])
    lenient_sig = SP.ones([n_loci, 2])
    strong_sig[:,0] = abs_diff >= af_cutoff
    strong_sig[:,1] = abs_sds >= sd_cutoff
    lenient_sig[:,0] = abs_diff + AF_CHANGE_EXTRA >= af_cutoff
    lenient_sig[:,1] = abs_sds + SD_CHANGE_EXTRA >= sd_cutoff

    return diff, sds, strong_sig, lenient_sig, all_locs



def main():
    """ Script entry point; does nothing. """
    return None


# Run main() only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()
