from util import *
from parts2014_gfpvar.analysis.qtl.util import *
from parts2014_gfpvar.tools.util import get_glucose_smith_qtls_peakavoid


def get_eqtl_pqtl_overlap(debug=False, nearby=50000):
    pq = read_qtls("clean")
    eq = get_orf_mrna_qtls()
    vals = []
    filtered, total, rep = 0, 0, 0
    if debug: print "ORF\teQTL_ch\teQTL_l\teQTL_L\teQTL_effect\tpQTL_loc\tpQTL_effect\tDiagnosis"

    for orf in eq: # for each ORF that we sorted and has eQTLs
        for (chrm, loc, lod, effect) in eq[orf]: # for each QTL for one of these orfs
            if (chrm == "XV" and abs(loc - 170000) < 30000) or (chrm == "II" and abs(loc - 556000) < 30000) or (chrm == "II" and abs(loc - 328000) < 30000) or (chrm == "III" and abs(loc - 200000) < 20000): 
                filtered += 1
                continue
            total += 1
            p_effect = get_average_selection_signal(orf, chrm, loc)[0]
            # 1. record overlap of effect size / allele frequencies at the eQTL locus
            vals.append([effect, p_effect, SP.nan, lod]) # eQTL effect, difference in signal from pQTL, effect size in nearby pQTL (SP.nan if none)
            # 2. see whether a pQTL is nearby
            I = SP.where((pq[:,0] == orf) & (pq[:,1] == chrm))[0] # all pQTLs for ORF on this chromosome
            if debug: print "%s\t%s\t%d\t%.1f\t%.3f"%(orf, chrm, loc, lod, effect),
            if len(I) == 0: # If no pQTLs, skip
                if debug:
                    if effect*p_effect > 0 and abs(p_effect) > 0.1: print "\t\t\t\t(%.2f)\t\t\tooo"%p_effect
                    else:  print "\t\t\t\t(%.2f)\t\t\t..."%p_effect
                continue
            q = pq[I[SP.argmin(abs(SP.array(pq[I,2],int) - loc))]] # nearest pQTL
            if (int(q[3]) > loc + nearby) or (int(q[6]) < loc - nearby): # if not near enough, skip
                if debug:
                    if effect*p_effect > 0 and abs(p_effect) > 0.1: print "\t\t\t\t(%.2f)\t\t\tooo"%p_effect
                    else:  print "\t\t\t\t(%.2f)\t\t\t..."%p_effect
                continue
            #if debug: print "\t\t%s\t\t%s\t\t%s\t%s\t%s"%(q[2],q[6],q[3],q[4],q[8]),
            if debug: print "\t\t%s\t\t%s\t\t"%(q[2],q[8]),
            if q[8].count('"') > 0: vals[-1][2] = float(q[8].replace('"', ''))/2. # effect size of pQTL (from arrays)
            else: vals[-1][2] = float(q[8]) # effect size of pQTL
            if vals[-1][2]*effect > 0: # if right direction, replicated :)
                if debug: print "\t***"
                rep += 1
            else:
                vals[-1][2] = SP.nan
                if debug: print "\t---", orf, chrm, loc

    if debug: print "\n\nTotal: %d, rep: %d, filtered: %d"%(total, rep, filtered)
    return SP.array(vals)



def get_pqtl_eqtl_signal(debug=True, by=False):
    ROMAN = "0 I II III IV V VI VII VIII IX X XI XII XIII XIV XV XVI".split()
    mrna_data = read_mrna_data()
    eqtls = get_glucose_smith_qtls_peakavoid(avoid_peaks=[], lod_cutoff=4.)
    qtls = SP.loadtxt("%s/paper/table_S7.tab"%DATA_DIR, delimiter="\t", dtype=object)
    vals = []
    
    for q in qtls:
        if q[0] == "Gene": continue  # first one is header
        if q[8].count('"') > 0: q[8] = float(q[8].replace('"',''))/2. # array signal - get to about same scale
        orf, chrm, loc, effect, pval = q[0], q[1], int(q[2]), float(q[8]), float(q[13])
        orf_eqtls = []
        if orf in eqtls: # if there are eQTLs for this ORF
            for echrm, marker, lod in eqtls[orf]: # go through them
                if ROMAN[int(echrm)] != chrm or abs(loc - int(marker)) > 200000: # if not the same as pQTL
                    orf_eqtls.append([echrm,marker,lod])
        vals.append([effect, get_mrna_effect_signal(orf, chrm, loc, covariates=orf_eqtls, mrna_data=mrna_data, by=by), pval])
        if ((vals[-1][0] > 0.3) and (vals[-1][1] < 0.1)) or ((vals[-1][0] < -0.3) and (vals[-1][1] > -0.1)):
            if debug: print orf, chrm, loc, vals[-1], pval

    return SP.array(vals)

 
def read_beyer_reduced_data(genes=("Gcy1", "Msc1")):
    """Read per-individual values for selected genes from protein/beyer_reduced.tab.

    The file is loaded as a tab-delimited object array. Row 0 is treated as the
    header, whose cells from column 1 onward label the columns. In every other
    row, column 0 is the gene name.

    Args:
        genes: gene names (matched against column 0) to keep. The default is a
            tuple rather than a list to avoid a shared mutable default; any
            container supporting `in` works.

    Returns:
        {gene: {individual: float}}. "NA" cells are read as NaN. If a gene
        appears on several rows, the last row wins.
    """
    d = SP.loadtxt("%s/protein/beyer_reduced.tab"%DATA_DIR, delimiter="\t", dtype=object)
    inds = d[0][1:]  # header labels, so inds[j] names data column 1 + j
    res = {}
    for row in d[1:]:
        if row[0] in genes:
            # NOTE(review): j starts at 1, so column 1 (label inds[0]) is never read.
            # That is correct if column 1 is a non-numeric annotation column; if it
            # holds the first individual's value, that individual is being dropped.
            # Confirm against the file before changing.
            res[row[0]] = {inds[j]: float(row[1 + j].replace("NA", "NaN"))
                           for j in range(1, len(inds))}
    return res


def compare_beyer_pqtl_eqtl_signal(debug=True, orfs=[("YOR121C", "Gcy1",15,172000), ("YML128C", "Msc1",15,172000)]):
    snps_segs = cl("%s/mrna/snps_segregants.pickle"%DATA_DIR)
    snploc = cl("%s/mrna/snploc.pickle"%DATA_DIR)
    segs = cl("%s/mrna/individuals_segregants.pickle"%DATA_DIR)
    mrna_orfs, expr, mrna_segs = cl("%s/mrna/expr.pickle"%DATA_DIR)
    beyer_data = read_beyer_reduced_data(genes=[o[1] for o in orfs])

    for orf, name, qtl_chrm, qtl_loc in orfs:
        beyer = beyer_data[name]
        g = -1
        for i in range(len(mrna_orfs)):
            if mrna_orfs[i][1].count(orf) > 0: g = i
        vals = []
        gt = []
        I = SP.where(snploc[:,0] == qtl_chrm)[0]
        s = I[SP.argmax(abs(snploc[I,1] - qtl_loc))]
        for ind in beyer:
            idx = -1
            for j in range(len(mrna_segs)):
                if mrna_segs[j].count("glucose") > 0 and mrna_segs[j].split()[1].replace("_", ".") == ind:
                    idx = j
            vals.append([beyer[ind], expr[g,idx]])
            gt.append(snps_segs[idx,s,0])
        vals = SP.array(vals)
        gt = SP.array(gt)
        I1 = SP.where(gt == 1)[0]
        I0 = SP.where(gt == 0)[0]
        PL.plot(vals[I0,0], vals[I0,1], "b.")
        PL.plot(vals[I1,0], vals[I1,1], "r.")
        print orf, "Beyer:", SP.median(vals[I0,0]), SP.median(vals[I1,0])
        print orf, "mRNA:", SP.median(vals[I0,1]), SP.median(vals[I1,1])
        PL.show()


def get_peak_comparison(debug=False, signal_cutoff=0.15, smoother=25000, peak_cutoff_pqtl=4, peak_cutoff_eqtl=SP.log2(30), paper_numbers=False):
    ROMAN_NAMES = "I II III IV V VI VII VIII IX X XI XII XIII XIV XV XVI".split()
    pqtls = read_qtls("clean")
    chrlens, chrmids = get_cumulative_chrm_lengths(), []
    pqtl_count = SP.zeros(int(max(chrlens.values())/smoother) + 1)
    for q in pqtls:
        (orf, chrm, p,s,sc, ec, e,l, delta, sd, nsds) = q[0:11]
        if abs(float(delta.replace('"', ''))) > signal_cutoff:
            start, end = chrlens[chrm] + int(p) - smoother/2, chrlens[chrm] + int(p) + smoother/2
            pqtl_count[(start/smoother):(end/smoother + 1)] += 1 

    eqtl_count = SP.log2(get_all_mrna_qtls(smoother) + 1)
    prev, eqtl_peak, pqtl_peak = 0, [], []

    for i,c in enumerate(sorted(chrlens.values())[0:-1]):
        v1 = pqtl_count[prev/smoother:c/smoother] > peak_cutoff_pqtl
        v2 = eqtl_count[prev/smoother:c/smoother] > peak_cutoff_eqtl
        if i not in (1,2): # skip chromosomes 2 and 3 for this purpose - all the peak genes are not segregating in our cross
            eqtl_peak.extend(v2)
            pqtl_peak.extend(v1)
        prev = c
    if debug:
	if not paper_numbers:
	    print len(pqtl_count), len(eqtl_count)
	    print "Corr=", SP.corrcoef(eqtl_count, pqtl_count)[0,1]
	    print "Corr peak=", SP.corrcoef(eqtl_peak, pqtl_peak)[0,1]
	    for c in 4,5:
		f = 1.*sum(pqtl_count >= c)/len(pqtl_count)
		print "Observed %d peaks to have at least %d pQTLs, and 8 of 11 eQTL peaks to overlap. Genomic fraction in >=%d pQTL = %.2f; expected from random 13: %.1f (%.1f-fold fewer), expected from random 11: %.1f (%.1f-fold fewer)"%(sum(pqtl_count >= c), c, c, f, 13*f, 8./(13.*f), 11*f, 8./(11.*f))
	else:
	    c = 4
	    f = 1.*sum(pqtl_count >= c)/len(pqtl_count)
	    print "Observed %d/%d bins (%.1f%%) to have at least %d pQTLs, and 8 of 11 eQTL peaks to be in such a bin. Expected from random 11: %.1f (%.1f-fold fewer)"%(sum(pqtl_count >= c), len(pqtl_count), 100*f, c, 11*f, 8./(11.*f))
    return pqtl_count, eqtl_count, pqtl_peak, eqtl_peak

    
if __name__ == '__main__':
    # Script entry point: print the per-eQTL overlap table and its summary counts.
    # Other analyses in this module can be run by hand instead:
    #   get_pqtl_eqtl_signal()
    #   compare_beyer_pqtl_eqtl_signal()
    get_eqtl_pqtl_overlap(debug=True)
