import pylab as PL
from parts2014_gfpvar.tools.io import *
from parts2014_gfpvar.analysis.seq.util import *
from util import *
import glob
import time


""" Get replication signals of all QTLs in qtl_set (called with the af_stringent and sd_stringent parameters) for other qtl sets
@param qtl_set file name identifier to use; default="all" for all QTLs, could also be a single sample, e.g. A19
@param af_stringent minimum allele frequency difference for call
@param sd_stringent minimum SD in allele frequencies for call
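@param qtl_region_span number of base pairs to optimise signal in the replication sample
@param rec_cutoff cutoff value used to pick the "<set>_rec-<cutoff>.pickle" diffs file that is read for each replicate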
@return map of qtl set->qtl->(QTL setname, rep setname, QTL AF, repro AF, QTL SD, rep SD, QTL #SD, rep #SD)
""" 
def calc_qtl_replication_stats(qtl_set="all", af_stringent=0.15, sd_stringent=5, qtl_region_span=1000, rec_cutoff=0.9):
    """ For every called QTL, look up the signal at the matching locus in the other sets of the same sample
    @param qtl_set identifier passed to read_qtls (default "all")
    @param af_stringent allele frequency threshold passed to read_qtls
    @param sd_stringent SD threshold passed to read_qtls
    @param qtl_region_span width in base pairs of the window around the QTL peak that is searched in each replicate (see get_rep_peak)
    @param rec_cutoff value used to build the "<set>_rec-<cutoff>.pickle" diffs file name
    @return map of qtl set->qtl tuple->list of [QTL setname, rep setname, QTL AF, rep AF, QTL SD, rep SD, QTL #SD, rep #SD]
    """
    meta = get_all_seq_metadata()
    qtls = read_qtls(qtl_set, af_stringent=af_stringent, sd_stringent=sd_stringent)
    chrnames = get_seq_chr_names()
    set_qtl_reps = {} # map of set->qtl->values in all replicates
    diffs_cache = {} # set name->loaded diffs; each pickle is read once instead of once per QTL

    # For each QTL, get its allele frequency change, and the allele frequency change in nearest locus in the other sample
    for q in qtls:
        if q[0] == "Set": continue # header row
        (qset, chrm, peak, start, conf_start, conf_end, end, length, af_peak, sd_peak, nsds_peak) = q[0:11]
        seq_chrm = [c for c in chrnames if chrnames[c] == chrm][0] # key of chrnames whose value is this chromosome name
        sample, _ = qset.split("_") # set names are expected to be "<sample>_<suffix>"; anything else raises here
        reps = []
        set_qtl_reps.setdefault(qset, {})[tuple(q)] = reps

        # candidate replicate sets: column 2 of the metadata rows whose column 1 is "<sample>_minus"
        all_sets = SP.unique(meta[SP.where(meta[:,1] == sample + "_minus")[0],2])
        for s in all_sets:
            if (s == qset) or (s.split("_")[0] != sample): continue # skip the discovery set itself and sets named after another sample
            if s not in diffs_cache:
                diffs_cache[s] = cl("%s/seq/diffs/%s_rec-%.2f.pickle"%(DATA_DIR, s, rec_cutoff))
            locs, delta, sds = diffs_cache[s][seq_chrm]
            rep_peak = get_rep_peak(int(peak), locs, delta, float(af_peak), qtl_region_span)
            reps.append([qset, s, float(af_peak), delta[rep_peak], float(sd_peak), sds[rep_peak], float(nsds_peak), abs(delta[rep_peak])/(sds[rep_peak])])
    return set_qtl_reps



def get_rep_peak(peak_loc, locs, delta, af_peak, qtl_region_span=1000):
    """ Pick the index in locs that best replicates a peak at peak_loc
    @param peak_loc position of the discovery peak
    @param locs array of positions in the replicate
    @param delta array of signal values, indexed like locs
    @param af_peak discovery signal; only its sign is used (anything convertible by float())
    @param qtl_region_span width of the search window centred on peak_loc
    @return index into locs/delta: within the window, the largest delta if af_peak > 0, else the smallest;
            if no position falls inside the window, the position nearest to peak_loc
    """
    distance = abs(locs - peak_loc)
    nearest = SP.argmin(distance)
    window = SP.where(distance < qtl_region_span/2)[0]
    if len(window) == 0:
        window = [nearest] # nothing inside the window: fall back to the closest position
    if float(af_peak) > 0:
        best = SP.argmax(delta[window])
    else:
        best = SP.argmin(delta[window])
    return window[best]
    

def calc_sample_pval(seq_set, qtl_region_span=1000, repro_af_cutoff=0.1, sd_cutoff=5, n_tosses=10000, rec_cutoff=0.90, save=True, recalc=False, seed=0):
    """ Build (or load) the null distribution of replication signals for one set by sampling random genome positions
    @param seq_set set name used in the diffs and nulls file names
    @param qtl_region_span window width passed to get_rep_peak
    @param repro_af_cutoff signal threshold used for the returned hit counts
    @param sd_cutoff not used by this function
    @param n_tosses number of random positions to sample
    @param rec_cutoff value used in the diffs and nulls file names
    @param save whether to store the sampled (pos, neg) arrays in the nulls file
    @param recalc if True, resample even when a nulls file already exists
    @param seed seed for SP.random
    @return (number of sampled positions whose positive-direction signal exceeds repro_af_cutoff,
             number whose negative-direction signal is below -repro_af_cutoff)
    """
    outfilename = "%s/seq/nulls/%s_rec-%.2f_span-%d_n-%d.pickle"%(DATA_DIR, seq_set, rec_cutoff, qtl_region_span, n_tosses)
    if os.path.exists(outfilename) and not recalc:
        cached = cl(outfilename) # load once; stored as (pos, neg)
        # same thresholds as the freshly computed result below
        return sum(cached[0] > repro_af_cutoff), sum(cached[1] < -repro_af_cutoff)
    SP.random.seed(seed)

    diffs = cl("%s/seq/diffs/%s_rec-%.2f.pickle"%(DATA_DIR, seq_set, rec_cutoff))
    # concatenate all entries of diffs into one coordinate axis by shifting each one past the largest position so far
    # NOTE(review): the concatenation order follows the iteration order of diffs
    all_loc, all_d = [],[]
    for k in diffs:
        l,m,s = diffs[k]
        all_loc.extend(list(l + (0 if len(all_loc) == 0 else max(all_loc))))
        all_d.extend(list(m))
    all_loc, all_d = SP.array(all_loc), SP.array(all_d)

    # pick n_tosses uniformly random sites on the concatenated axis
    locs = SP.array(SP.random.uniform(all_loc.min(), all_loc.max(), n_tosses), int)
    pos, neg = [], []
    for l in locs:
        pos.append(all_d[get_rep_peak(l, all_loc, all_d, 0.5, qtl_region_span)]) # best signal when looking for a positive peak
        neg.append(all_d[get_rep_peak(l, all_loc, all_d, -0.5, qtl_region_span)]) # best signal when looking for a negative peak
    pos, neg = SP.array(pos), SP.array(neg)
    if save:  cdm((pos, neg), outfilename)
    return sum(pos > repro_af_cutoff), sum(neg < -repro_af_cutoff)


def calc_all_sample_pval(qtl_region_span=1000, rec_cutoff=0.90, repro_af_cutoff=0.15, save=True, recalc=False):
    """ Run calc_sample_pval for every set that has a diffs file for the given rec_cutoff, logging the hit counts
    @param qtl_region_span window width forwarded to calc_sample_pval
    @param rec_cutoff value used to select the diffs files and forwarded to calc_sample_pval
    @param repro_af_cutoff signal threshold forwarded to calc_sample_pval
    @param save forwarded to calc_sample_pval
    @param recalc forwarded to calc_sample_pval
    """
    pattern = "%s/seq/diffs/*rec-%.2f.pickle"%(DATA_DIR, rec_cutoff)
    # the set name is the first two "_"-separated fields of each file name
    samples = []
    for path in glob.glob(pattern):
        basename = path.split("/")[-1]
        samples.append("_".join(basename.split("_")[0:2]))

    for sample in samples:
        if "a" in sample:
            continue # no p-values for array QTLs, since no matching biological replicates
        started = time.time()
        pos_hits, neg_hits = calc_sample_pval(sample, qtl_region_span=qtl_region_span, repro_af_cutoff=repro_af_cutoff, rec_cutoff=rec_cutoff, save=save, recalc=recalc)
        elapsed = time.time() - started
        LOG.debug("Calculated sample p-value for %s. qtl_region_span=%d, positive hits=%d, negative hits=%d, time=%.1f"%(sample, qtl_region_span, pos_hits, neg_hits, elapsed))


def get_callset(s, meta):
    """ Return the last-column value of the first row of meta whose column 2 equals s
    @param s set name to look up
    @param meta 2-D metadata array
    @return the last column of the first matching row (IndexError if no row matches)
    """
    matching_rows = SP.where(meta[:,2] == s)[0]
    first_row = matching_rows[0]
    return meta[first_row,-1]


# Only used in one notebook - collapsed version of get_pqtlset_repro below
def get_pqtl_repro(sample_set="all", af_stringent=0.15, sd_stringent=5, qtl_region_span=1000, return_norepro=False):
    """ Flatten the per-QTL maps of get_pqtlset_repro into two arrays
    @param sample_set, af_stringent, sd_stringent, qtl_region_span forwarded to get_pqtlset_repro
    @param return_norepro if False, QTLs without a replication signal are left out
    @return (array of [discovery AF, average replication AF] rows, array of [discovery #SD, average replication #SD] rows), same row order
    """
    vals, nsds, pvals, repcounts = get_pqtlset_repro(sample_set, af_stringent, sd_stringent, qtl_region_span)
    avals, ansds = [], []
    for k in vals:
        rep_signal = vals[k][1]
        # get_pqtlset_repro stores NaN when a QTL has no replicate, so test for NaN as well as the literal 0
        if (not return_norepro) and (rep_signal == 0 or SP.isnan(rep_signal)): continue # skip no repro ones for this purpose
        avals.append(vals[k])
        ansds.append(nsds[k])
    return SP.array(avals), SP.array(ansds)


def get_pqtlset_repro(sample_set="all", af_stringent=0.15, sd_stringent=5, qtl_region_span=1000):
    """ Average the replication signal of every discovery-set QTL over its matching replicate sets
    @param sample_set, af_stringent, sd_stringent, qtl_region_span forwarded to calc_qtl_replication_stats
    @return (vals, nsds, pvals, repcounts), each a map keyed by QTL tuple:
            vals[qtl] = [discovery AF difference, mean replicate AF difference]
            nsds[qtl] = [discovery #SD, mean replicate #SD]
            pvals[qtl] = mean empirical p-value over replicates
            repcounts[qtl] = number of replicates used
            The averaged entries are NaN when a QTL has no usable replicate.
    """
    set_qtl_rep = calc_qtl_replication_stats(sample_set, af_stringent, sd_stringent, qtl_region_span)
    meta = get_all_seq_metadata()
    callsets = SP.unique(meta[:,-1])
    n_null = 10000 # size of the null files matched below; also the p-value denominator
    sample_nulls = {}
    for f in glob.glob("%s/seq/nulls/*_rec-0.90_span-%d_n-%d.pickle"%(DATA_DIR, qtl_region_span, n_null)):
        # file names are "<set>_rec-..."; cut at the "_rec-" marker so set names containing "r" still parse
        sample = f.split("/")[-1].rsplit("_rec-", 1)[0]
        sample_nulls[sample] = [SP.array(x) for x in cl(f)] # a real list, so it can be indexed [0]/[1] below

    vals, nsds, pvals, repcounts = {},{},{},{}

    for callset in callsets:
        if (callset not in set_qtl_rep): # only look at QTLs from discovery set
            continue
        for qtl in set_qtl_rep[callset]: # qtl is a tuple (0-set, 1-chrm, 2-peak, 3-start, 4-centre-start, 5-centre-end, 6-end, 7-length, 8-AF difference, 9-SD(AF difference), 10-#SD difference, 11-genes
            disc_af = float(qtl[8])
            af_sum, nsd_sum, p_sum, n_reps = 0, 0, 0, 0
            for rep in set_qtl_rep[callset][qtl]: # (QTL setname, repro setname, QTL AF, repro AF, QTL SD, rep SD, QTL #SD, rep #SD)
                if rep[0].count("a") + rep[1].count("a") > 0: continue # skip all array stuff for replication
                if get_callset(rep[1], meta) != callset: continue # if the main callset and replication set are not a correct pair, skip
                n_reps += 1
                af_sum += rep[3] # replicate QTL signal (AF difference)
                nsd_sum += rep[7] # replicate QTL signal (#SD difference)
                # empirical p-value: fraction of null draws at least as extreme, in the direction of the discovery signal
                # NOTE(review): the null is looked up by the discovery set name (rep[0]), not the replicate set (rep[1]) - confirm this is intended
                if disc_af < 0:
                    p_sum += sum(sample_nulls[rep[0]][1] < rep[3])/float(n_null)
                else:
                    p_sum += sum(sample_nulls[rep[0]][0] > rep[3])/float(n_null)

            repcounts[qtl] = n_reps
            if n_reps == 0: # if no replicates, store that have no information on the QTL
                vals[qtl] = [disc_af, SP.nan]
                nsds[qtl] = [float(qtl[10]), SP.nan]
                pvals[qtl] = SP.nan
            else: # else average the signals across replicates
                vals[qtl] = [disc_af, af_sum/n_reps]
                nsds[qtl] = [float(qtl[10]), nsd_sum/n_reps]
                pvals[qtl] = p_sum/n_reps

    return vals, nsds, pvals, repcounts



def plot_pqtl_repro(vals, nsds, v1=0.15, v2=0.2, rv=0.1, sd_cutoff=8):
    """ Draw two figures summarising QTL replication
    Figure 1: number of replicated QTLs, and replication rate, as a function of the QTL call cutoff.
    Figure 2: scatter of discovery signal against average replication signal.
    @param vals array with rows [discovery signal, average replication signal]
    @param nsds array with rows [discovery #SD, average replication #SD], same row order as vals
    @param v1, v2 replication signal thresholds used to colour groups of points; v2 also positions the vertical guide lines of figure 2
    @param rv minimum absolute replication signal (with matching sign) for a QTL to count as replicated
    @param sd_cutoff rows are compared against log2(sd_cutoff)
    """
    sd_thresh = SP.log2(sd_cutoff)

    # skip ones that can't ascertain replication for
    keep = (abs(vals[:,1]) > 1e-16) & (nsds[:,0] > sd_thresh)
    vals, nsds = vals[keep], nsds[keep]

    rep_mag = abs(vals[:,1])
    sign_prod = vals.prod(axis=1) # positive when discovery and replication signals have the same sign
    I0 = SP.where((rep_mag <= v1) | (sign_prod <= 0) | (nsds[:,1] <= sd_thresh))[0]
    I1 = SP.where((rep_mag > v1) & (sign_prod > 0) & (nsds[:,1] > sd_thresh))[0]
    I2 = SP.where((rep_mag > v2) & (sign_prod > 0))[0]
    idx_norep = SP.where((rep_mag < rv) | (sign_prod <= 0))[0]
    idx_rep = SP.where((rep_mag >= rv) & (sign_prod > 0))[0]

    # figure 1: counts and replication rate over a range of call cutoffs
    PL.figure(figsize=(9,4))
    cutoffs = SP.arange(0.15,0.5,0.01)
    counts = [[sum(abs(vals[idx_rep,0]) > c), sum(abs(vals[:,0]) > c)] for c in cutoffs]
    counts = 1.*SP.array(counts)
    n_rep, n_all = counts[:,0], counts[:,1]
    PL.subplot(121)
    PL.plot(cutoffs, n_rep, lw=4, alpha=0.7)
    PL.plot([0.2,0.2],[0,max(n_rep)], 'r--', alpha=0.7)
    PL.xlabel("QTL call cutoff")
    PL.ylabel("Number of QTLs")
    PL.subplot(122)
    rate = n_rep/n_all
    PL.plot(cutoffs, rate, lw=4, alpha=0.7)
    PL.plot([0.2,0.2],[rate.min(),1], 'r--', alpha=0.7)
    PL.xlabel("QTL call cutoff")
    PL.ylabel("Replication rate")

    # figure 2: discovery vs replication scatter; unreplicated in red, replicated coloured by group
    PL.figure(figsize=(12,10))
    PL.plot(vals[idx_norep,0], vals[idx_norep,1], "r.", alpha=0.2, markersize=15)
    for color, group in zip('gbk', [I0, I1, I2]):
        shown = SP.intersect1d(group, idx_rep)
        PL.plot(vals[shown,0], vals[shown,1], color + ".", markersize=15, alpha=0.2)
    x_lim, y_lim = 0.8, 0.8
    # guide lines at +/-rv (horizontal) and +/-v2 (vertical)
    PL.plot([-x_lim,x_lim], [rv,rv], "r-")
    PL.plot([-x_lim,x_lim], [-rv,-rv], "r-")
    PL.plot([-v2,-v2], [-y_lim,y_lim], "r-")
    PL.plot([v2,v2], [-y_lim,y_lim], "r-")
    PL.xlim(-x_lim,x_lim)
    PL.ylim(-y_lim,y_lim)
    PL.xlabel("Discovery signal")
    PL.ylabel("Average replication signal")
    PL.show()


""" Output all discovery set QTLs and their replication statistics """
def output_qtl_replication_stats(sample_set="all", af_stringent=0.15, sd_stringent=5, qtl_region_span=1000, outfilename="%s/qtl/replication.tab"%DATA_DIR):
    """ Write one tab-separated line per QTL returned by get_pqtlset_repro
    Columns: metadata column 4 of the QTL's set, QTL fields 1-10, average replication AF, average replication #SD, average p-value, QTL field 11 (empty if absent).
    @param sample_set, af_stringent, sd_stringent, qtl_region_span forwarded to get_pqtlset_repro
    @param outfilename path of the file to (over)write
    """
    LOG.debug("Outputting QTL replication stats. af_stringent=%.2f, sd_stringent=%.1f, qtl_region=%d, outfilename=%s"%(af_stringent, sd_stringent, qtl_region_span, outfilename))
    meta = get_all_seq_metadata()
    vals, nsds, pvals, repcounts = get_pqtlset_repro(sample_set, af_stringent, sd_stringent, qtl_region_span)

    # "with" closes the handle even if a write fails; open() replaces the Python-2-only file()
    with open(outfilename, 'w') as ofh:
        for qtl in vals:
            gene = meta[SP.where(meta[:,2] == qtl[0])[0][0],4]
            qtl_genes = ""
            if len(qtl) > 11: qtl_genes = qtl[11]
            rep_stats = "\t%.3f\t%.1f\t%.4f\t"%(vals[qtl][1], nsds[qtl][1], pvals[qtl]) + "%s\n"%(qtl_genes)

            # NOTE(review): array QTLs at or below the 0.4 filter fall through to the unquoted branch rather than being dropped - confirm intended
            if qtl[0].count("a") > 0 and abs(float(qtl[8])) > 0.4: # if array QTL, denote that the "AF" differences are really not; also apply extra filter
                ofh.write(gene + "\t" + "\t".join(qtl[1:8]) + '\t"' + qtl[8] + '"\t' + "\t".join(qtl[9:11]) + rep_stats)
            else:
                ofh.write("\t".join([gene] + list(qtl[1:11])) + rep_stats)



""" This is only used in one visualisation
@return map set->[(discovery AF, replication AF) for each QTL of the set's callset that is replicated in that set], map set->[log2 of (discovery number of SDs, replication number of SDs)]
"""
def calc_qtlset_replication_vals(sample_set="all", af_stringent=0.15, sd_stringent=5, qtl_region_span=1000):
    """ Collect, per replicate set, the discovery and replication signals of its callset's QTLs
    @param sample_set, af_stringent, sd_stringent, qtl_region_span forwarded to calc_qtl_replication_stats
           (previously read from undefined names; defaults match get_pqtlset_repro)
    @return (rvals, rnsds), both keyed by every set name in metadata column 2:
            rvals[set] = array of [discovery AF, replication AF] rows
            rnsds[set] = log2 of an array of [discovery #SD, replication #SD] rows
            Sets that are skipped keep an empty list as value.
    """
    rvals, rnsds = {},{}
    meta = get_all_seq_metadata()
    callsets = SP.unique(meta[:,-1])
    set_qtl_rep = calc_qtl_replication_stats(sample_set, af_stringent, sd_stringent, qtl_region_span)

    for seq_set in SP.unique(meta[:,2]):
        rvals[seq_set], rnsds[seq_set] = [],[]
        if (seq_set in callsets) or (seq_set not in set_qtl_rep) or (seq_set.count("a") > 0): continue
        cs = get_callset(seq_set, meta) # callset value
        if cs.count("a") > 0: continue # skip arrays for now

        # keep only the replication entries measured in this set
        for qtl in set_qtl_rep[cs]:
            for rep in set_qtl_rep[cs][qtl]:
                if rep[1] == seq_set:
                    rvals[seq_set].append([rep[2], rep[3]])
                    rnsds[seq_set].append([rep[6], rep[7]])
        rvals[seq_set] = SP.array(rvals[seq_set])
        rnsds[seq_set] = SP.log2(SP.array(rnsds[seq_set]))

    return rvals, rnsds



 
def main():
    """ Write the QTL replication table for all samples with the default call parameters
    The other steps in this file are not run here and have to be called explicitly:
    calc_all_sample_pval(qtl_region_span=1000) to build the null files, and
    plot_pqtl_repro(*get_pqtl_repro()) for the replication figures.
    """
    output_qtl_replication_stats(sample_set="all", af_stringent=0.15, sd_stringent=5, qtl_region_span=1000)


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
