from parts2014_gfpvar.tools.common import *
from parts2014_gfpvar.tools.io import get_seq_chr_names
from util import *


def call_all_seq_qtls(recalc=False, rec_cutoff=0.9):
    """ Take all sequencing sets, calculate allele frequency differences between high and low populations, and call QTLs.

    Every distinct value in metadata column 2 names a QTL set, which must consist of exactly two
    samples: one labelled "GFP_high" and one "GFP_low" in column 3. Other sets are logged and skipped.
    @param recalc if True, recompute differences and QTL calls even if the differences pickle already exists
    @param rec_cutoff parameter used in allele frequency inference; it also selects the differences pickle
                      that calc_af_diffs writes and call_qtls reads, so it is forwarded to both
    """
    meta = get_all_seq_metadata()
    for seq_set in SP.unique(meta[:,2]): # for each QTL set (pair of GFP-high, GFP-low)
        I = SP.where(meta[:,2] == seq_set)[0] # this accordingly should be length-2
        if (len(I) != 2) or (set(meta[I,3]) != set(["GFP_high", "GFP_low"])):
            LOG.debug("Skipping set %s - only %d samples (%s), expected 2 - high and low GFP"%(seq_set, len(I), str(meta[I,1])))
            continue
        # Same file name as calc_af_diffs writes and call_qtls reads; the old check used a name nothing produced.
        diff_file = "%s/seq/diffs/%s_rec-%.2f.pickle"%(DATA_DIR, seq_set, rec_cutoff)
        if os.path.exists(diff_file) and not recalc:
            continue # differences already exist - assume QTL calls also exist
        # Set names containing "a" are treated as array data: their "AF" differences are assumed to be given.
        # NOTE(review): this substring test also matches names such as "Sample..." - confirm the set naming convention.
        if seq_set.count("a") == 0:
            hi = I[SP.where(meta[I,3] == "GFP_high")[0][0]]
            lo = I[SP.where(meta[I,3] == "GFP_low")[0][0]]
            calc_af_diffs(meta[hi,0], meta[lo,0], seq_set, save=True, rec_cutoff=rec_cutoff)
        call_qtls(seq_set, rec_cutoff=rec_cutoff) # must read the diffs computed with the same rec_cutoff


def call_clean_set_seq_qtls(af_lenient=0.1, sd_lenient=3, af_stringent=0.15, sd_stringent=5, length_cutoff=0, rec_cutoff=0.9, save=False):
    """ Collect the QTLs of the discovery sets and optionally write a cleaned table.

    @param af_lenient, sd_lenient, af_stringent, sd_stringent, length_cutoff - QTL calling thresholds (see call_qtls);
           they select which QTL calls read_qtls loads
    @param rec_cutoff only logged here; it is not passed to read_qtls
    @param save whether to write the QTLs to DATA_DIR/qtl/clean_*.tab. In the written file:
           - QTLs whose locus type starts with "SGA" are dropped.
           - QTLs from array sets (set name contains "a") are kept only if |AF_peak| > 0.4, with the AF value quoted
             because it is not a true allele frequency.
    @return QTL records (as returned by read_qtls) whose set name q[0] is in the last metadata column;
            the filters above only affect the written file, not the return value
    """
    LOG.debug("Calling clean set of QTLs. af_lenient=%.2f, sd_lenient=%.1f, af_stringent=%.2f, sd_stringent=%.1f, length_cutoff=%d, rec_cutoff=%.2f, save=%s"%(af_lenient, sd_lenient, af_stringent, sd_stringent, length_cutoff, rec_cutoff, str(save)))
    qtls = read_qtls("all", af_lenient, sd_lenient, af_stringent, sd_stringent, length_cutoff)
    meta = get_all_seq_metadata()
    chrgenes = read_chr_genes()
    callsets = SP.unique(meta[:,-1]) # all the discovery sets
    result = [q for q in qtls if q[0] in callsets] # all QTLs from the discovery sets

    if save:
        out_file = "%s/qtl/clean_afl-%.2f_afs-%.2f_sdl-%.1f_sds-%.1f_len-%d.tab"%(DATA_DIR, af_lenient, af_stringent, sd_lenient, sd_stringent, length_cutoff)
        with open(out_file, 'w') as ofh:
            # Rows are gene + q[1:], i.e. call_qtls' columns with the set name replaced by the gene (q[8] is AF_peak)
            ofh.write("Gene\tChrm\tPeak\tStart\tCentre_start\tCentre_end\tEnd\tLength\tAF_peak\tSD_peak\tnumSD_peak\tCentre_genes\n")
            for q in result:
                gene = meta[SP.where(meta[:,2] == q[0])[0][0],4]
                qtype = get_qtl_type(q[1], int(q[2]), gene, chrgenes[q[1]])
                LOG.debug("Handling QTL from callset %s, qtype is %s"%(q[0], qtype))
                if qtype[0:3] == "SGA": # SGA locus or self linkage - do not output
                    continue
                # qtype could still be rDNA, one of the main peaks, or unknown - output
                if q[0].count("a") > 0: # array QTL: apply the extra filter and mark the "AF" as not really an AF
                    if abs(float(q[8])) > 0.4:
                        ofh.write(gene + "\t" + "\t".join(q[1:8]) + '\t"' + q[8] + '"\t' + "\t".join(q[9:]) + "\n")
                else:
                    ofh.write(gene + "\t" + "\t".join(q[1:]) + "\n")
    return result


def call_qtls(seq_set, af_lenient=0.1, sd_lenient=3, af_stringent=0.15, sd_stringent=5, length_cutoff=0, rec_cutoff=0.9, peak_cutoff=0.03):
    """ Call QTLs for a sequencing set and write them to DATA_DIR/qtl/<seq_set>_afl-..._len-<length_cutoff>.tab

    A putative region is a maximal run of consecutive loci that all pass the lenient thresholds.
    It is called a QTL if at least one of its loci passes the stringent thresholds and its span
    (locs[end] - locs[start]) exceeds length_cutoff.
    @param seq_set unique name for a pair of GFP-high, GFP-low samples (e.g. Sample4_44_sR1)
    @param af_lenient minimum allele frequency change to start a putative QTL region
    @param sd_lenient minimum SD of allele frequency change to start a putative QTL region
    @param af_stringent minimum allele frequency change in a putative QTL region to call it a QTL
    @param sd_stringent minimum SD of allele frequency change in a putative QTL region to call it a QTL
    @param length_cutoff minimum length of a putative QTL region to be called a QTL
    @param rec_cutoff parameter used in allele frequency inference; selects the differences pickle to read
    @param peak_cutoff "relaxation" from maximum allele frequency change at the QTL peak that is considered a candidate region for the causal gene
    """
    diffs = cl("%s/seq/diffs/%s_rec-%.2f.pickle"%(DATA_DIR, seq_set, rec_cutoff))
    chrnames, chrgenes = get_seq_chr_names(), read_chr_genes()
    qtls = []

    for chrm in diffs:
        locs, delta, sd = diffs[chrm]
        lenient = (abs(delta) > af_lenient) & (abs(delta)/sd > sd_lenient)
        stringent = (abs(delta) > af_stringent) & (abs(delta)/sd > sd_stringent)
        n = len(locs)

        i = 0
        while i < n: # scan loci left to right
            if not lenient[i]:
                i += 1
                continue
            # start, end, strongest signal of the run; the run's first locus counts towards all three
            start = end = peak = i
            is_qtl = bool(stringent[i])
            while (end + 1 < n) and lenient[end + 1]: # extend only over loci that are themselves leniently significant
                end += 1
                if stringent[end]: is_qtl = True # actual QTL if the stretch has at least one properly significant change
                if abs(delta[end]) > abs(delta[peak]): peak = end # store location of strongest peak

            if is_qtl and (locs[end] - locs[start] > length_cutoff): # if strong enough change exists, and the region is long enough,
                centre = SP.where((locs >= locs[start]) & (locs <= locs[end]) & (abs(delta[peak]) <= abs(delta) + peak_cutoff))[0] # all loci with change within peak_cutoff
                genes = get_sorted_region_genes(chrgenes[chrnames[chrm]], locs[peak], locs[centre.min()], locs[centre.max()])
                qtls.append((chrnames[chrm], locs[peak], delta[peak], sd[peak], abs(delta[peak])/sd[peak], locs[start], locs[end], locs[centre.min()], locs[centre.max()], genes)) # store QTL
            elif is_qtl: LOG.debug("Skipping QTL chrm %s %d (af change=%.2f) as too short (%d)"%(chrm, locs[peak], delta[peak], locs[end] - locs[start]))

            i = end + 1 # the stretch just observed is not considered again

    out_file = "%s/qtl/%s_afl-%.2f_afs-%.2f_sdl-%.1f_sds-%.1f_len-%d.tab"%(DATA_DIR, seq_set, af_lenient, af_stringent, sd_lenient, sd_stringent, length_cutoff)
    with open(out_file, 'w') as ofh:
        ofh.write("#Set\tChrm\tPeak\tStart\tCentre_start\tCentre_end\tEnd\tLength\tAF_peak\tSD_peak\tnumSD_peak\tCentre_genes\n")
        for (chrm, peak, d, s, sds, start, end, c_start, c_end, genes) in qtls:
            ofh.write("%s\t%s\t%d\t%d\t%d\t%d\t%d\t%d\t%.3f\t%.3f\t%.1f\t%s\n"%(seq_set, chrm, peak, start, c_start, c_end, end, end-start, d, s, sds, ",".join(genes)))
    

def calc_af_diffs(s1, s2, set_name, save=False, rec_cutoff=0.90):
    """ Calculate per-locus allele frequency differences between two samples.

    @param s1, s2 sample identifiers passed to get_sample_seq_data (call_all_seq_qtls passes GFP-high, then GFP-low)
    @param set_name name of the output pickle, DATA_DIR/seq/diffs/<set_name>_rec-<rec_cutoff>.pickle
    @param save whether to write the result to that pickle
    @param rec_cutoff passed to get_sample_seq_data and used in the pickle name
    @return dict chromosome -> (locs, diffs, sds) arrays. Only loci present in both samples with a
            non-NaN mean difference are included, in ascending location order.
            - diffs = (1-m1) - (1-m2), i.e. the RM-frequency change, since the means are BY allele frequencies.
            - sds = sqrt(v1 + v2).
    NOTE(review): assumes every chromosome of s1's data also exists in s2's data; otherwise a KeyError is raised.
    """
    p1, p2 = get_sample_seq_data(s1, rec_cutoff), get_sample_seq_data(s2, rec_cutoff)
    result = {}
    LOG.debug("Calculating allele frequency diffs. Sample1=%s, sample2=%s, set_name=%s, rec_cutoff=%.2f, save=%s"%(s1,s2,set_name, rec_cutoff, str(save)))

    def beta_var(b):
        # a*b/((a+b)^2 (a+b+1)) - the Beta variance, assuming b holds (alpha, beta) per row; 1e-6 guards against division by zero
        return b.prod(axis=1)/((b.sum(axis=1)**2 + 1e-6)*(b.sum(axis=1) + 1))

    def first_index(locations):
        # location -> index of its first occurrence (same index SP.where(...)[0][0] would pick), built in one pass
        index = {}
        for i, loc in enumerate(locations): index.setdefault(loc, i)
        return index

    for chrm in p1:
        if chrm.count("1224") > 0: continue # skip mitochondrion
        m1, b1, l1 = p1[chrm][1], p1[chrm][2], p1[chrm][4] # mean BY frequency, beta parameters, locations
        m2, b2, l2 = p2[chrm][1], p2[chrm][2], p2[chrm][4]
        v1, v2 = beta_var(b1), beta_var(b2)
        idx1, idx2 = first_index(l1), first_index(l2)
        locs, diffs, sds = [], [], []
        for l in sorted(set(idx1) & set(idx2)): # loci observed in both samples
            i1, i2 = idx1[l], idx2[l]
            if SP.isnan(m1[i1] - m2[i2]): continue
            locs.append(l)
            diffs.append((1.-m1[i1]) - (1. - m2[i2])) # the means are BY allele frequencies. Calculate the differences in _RM_ frequencies, as all rest of the statistics about this.
            sds.append((v1[i1] + v2[i2])**0.5)
        result[chrm] = (SP.array(locs), SP.array(diffs), SP.array(sds))

    if save: cdm(result, "%s/seq/diffs/%s_rec-%.2f.pickle"%(DATA_DIR, set_name, rec_cutoff))
    return result


def main():
    """ Run the full pipeline: recompute every set's differences and QTL calls, then save the clean discovery-set QTL table. """
    pipeline = (
        (call_all_seq_qtls, {"recalc": True}),
        (call_clean_set_seq_qtls, {"save": True}),
    )
    for step, options in pipeline:
        step(**options)


if __name__ == '__main__':
    main()
