import scipy as SP
from parts2014_gfpvar.tools.common import *
from parts2014_gfpvar.analysis.qtl.comparison import get_peak_comparison, get_eqtl_pqtl_overlap, get_pqtl_eqtl_signal
from parts2014_gfpvar.tools.io import *
from parts2014_gfpvar.analysis.mrna.util import *

def print_rna_protein_correlation(min_cov=20):
    sample_gene_covs = get_sample_mrna()
    cyto, cval = SP.loadtxt("%s/paper/table_S2-combined.tab"%DATA_DIR, delimiter="\t", dtype=object), {}
    for i in range(1, len(cyto)):  cval[cyto[i][0]] = SP.log10(float(cyto[i][9])) # BYxBY
    means = {s:ST.nanmean(sample_gene_covs[s].values()) for s in sample_gene_covs}
    vals, I, bads = [], [], {}
    for o in sample_gene_covs["B2"]:
        if o in sample_gene_covs["B3"] and o in cval:
            if sample_gene_covs["B2"][o] > min_cov and sample_gene_covs["B3"][o] > min_cov:
                vals.append([SP.log10(sample_gene_covs["B2"][o]+0.5) + SP.log10(sample_gene_covs["B3"][o]+0.5), cval[o]])
                I.append(True)
                if abs((SP.log2(sample_gene_covs["B2"][o]+0.5) - SP.log2(means["B2"])) - (SP.log2(sample_gene_covs["B3"][o]+0.5) - SP.log2(means["B3"]))) <= 1:
                    I[-1] = False # not bad only if not discrepant
    vals = SP.array(vals).T
    I0 = SP.where(I)[0]
    I1 = SP.where(~SP.array(I, bool))[0]
    print "mRNA-protein correlation r=%.2f, rho=%.2f"%(ST.pearsonr(vals[1][I1],vals[0][I1])[0], ST.spearmanr(vals[1][I1],vals[0][I1])[0])


def main_numbers():
    t2s = SP.loadtxt("%s/paper/table_S2-singles.tab"%DATA_DIR, delimiter="\t", dtype=object)
    t2i = SP.loadtxt("%s/paper/table_S2-img.tab"%DATA_DIR, delimiter="\t", dtype=object, comments="$")
    t2 = SP.loadtxt("%s/paper/table_S2-combined.tab"%DATA_DIR, delimiter="\t", dtype=object)
    t4 = SP.loadtxt("%s/paper/table_S4.tab"%DATA_DIR, delimiter="\t", dtype=object)
    t5 = SP.loadtxt("%s/paper/table_S5.tab"%DATA_DIR, delimiter="\t", dtype=object)
    t6 = SP.loadtxt("%s/paper/table_S6.tab"%DATA_DIR, delimiter="\t", dtype=object)
    t7 = SP.loadtxt("%s/paper/table_S7.tab"%DATA_DIR, delimiter="\t", dtype=object)

    # Abstract
    print "Abstract\n========"
    print "%d different yeast proteins quantified, %d reproducible and nice"%(len(SP.unique(t2s[1:,4])), len(SP.unique(t2[1:,0])))
    h2ind,h2pop = SP.array(t4[:,3], float), SP.array(t4[:,5],float)
    print "Median h2 single cells=%.3f, population=%.3f, over 50=%d/%d (%.1f%%)"%(SP.median(h2ind), SP.median(h2pop), sum(h2pop >= 0.5), len(h2pop), 100.*sum(h2pop>0.5)/len(h2pop))
    Iqtl = SP.where(abs(SP.array([x[8].replace('"', '') for x in t7], float)) >= 0.2)[0]
    counts = {"0":0,"1":0} # 25 sorted, two had 0, keeping them in to have right median calculation
    for i in Iqtl:
        if t7[i,0] not in counts: counts[t7[i,0]] = 0
        counts[t7[i,0]] += 1
    print "Total %d pQTLs for %d genes, median number found=%d"%(len(Iqtl), len(counts), SP.median(counts.values()))
    print

    # Results - screen
    print "Screen results"
    print "=============="
    i1, i2, cutoff = 9, 35, 0.2
    #print t2[0,i1], t2[0,i2]
    #by,rm = SP.array(t2[1:,i1], float), SP.array(t2[1:,i2], float)
    by,rm = SP.log10(SP.array(t2[1:,i1], float)), SP.log10(SP.array(t2[1:,i2], float))
    d = rm - by
    #I = SP.where((d >= SP.log10(1+cutoff)) | (d <= SP.log10(1-cutoff)))[0]
    #for i in I: print t2[1+i,0]
    n1, n2 = sum(d >= SP.log10(1+cutoff)), sum(d <= SP.log10(1-cutoff))
    print "%d quantified; %d RMxBY %d%% larger than BYxBY, %d RMxBY %d%% lower than BYxBY; total %d (%d%%)"%(len(t2)-1, n1, 100*cutoff, n2, 100*cutoff, n1+n2, 100*(n1+n2)/(len(t2)-1))
    print_rna_protein_correlation()
    print "Single cell linear h2 median=%.1f; population average linear h2 median=%.1f"%(100*SP.median(SP.array(t4[:,3],float)), 100*SP.median(SP.array(t4[:,5],float)))

    # Imaging comparison
    min_pairs = 20
    i_parentkid, i_clone, i_total, i_nparentkid, i_nclone, i_ntotal = [8,17], 10, [10,19,21], [9,18], 11, [11,20,22]
    parental, clonal, total, n_parental, n_clonal, n_total = SP.array(t2i[1:,i_parentkid], float).mean(axis=1), SP.array(t2i[1:,i_clone], float), SP.array(t2i[1:,i_total], float).max(axis=1), SP.array(t2i[1:,i_nparentkid], float).min(axis=1), SP.array(t2i[1:,i_nclone], float), SP.array(t2i[1:,i_ntotal], float).min(axis=1)
    I = SP.where((n_parental >= min_pairs) & (n_clonal >= min_pairs) & (n_total >= min_pairs))[0]
    parental, clonal, total = parental[I], clonal[I], total[I]
    n1,n2,n3, n = sum(parental <= clonal), sum(parental <= total), sum(clonal <= total), len(parental)
    print "%d of %d (%.1f%%) parent-kid variance smaller than clone"%(n1, n, 100*n1/n)
    print
    
    # Cis effects
    I = SP.where(SP.array(map(len, t5[:,0])) > 6)[0]
    cis = SP.array(t5[I][:,1:], float)
    cis = cis[SP.where(~(SP.isnan(cis).any(axis=1)))[0],:]
    n1 = sum(cis[:,5:7].mean(axis=1) - cis[:,7:9].mean(axis=1) > SP.log10(1 + cutoff))
    n2 = sum(cis[:,5:7].mean(axis=1) - cis[:,7:9].mean(axis=1) < SP.log10(1 - cutoff))
    ncistrans = sum(abs(cis[:,1]) > abs(cis[:,2]))
    print "%d cis measurements made, %d greater than %d%% difference (%d positive, %d negative), %d cis greater than trans"%(len(cis), n1+n2, 100*cutoff, n1, n2, ncistrans)
    print

    # Sequencing coverage
    print "pQTL mapping"
    print "============"
    print "Median sequencing coverage %dx"%(SP.median(SP.array(t6[:,9],int)))

    # Hxt2 level
    i = SP.where(t2[:,0] == "YMR011W")[0][0]
    by,rm = float(t2[i,9]), float(t2[i,35])
    print "Hxt2 - average level in BYxBY cross = %d, in RMxBY cross = %d, %.1f%% difference"%(by, rm, 100*rm/by - 100)

    # pQTL numbers
    print "Total %d pQTLs for %d genes, median number found=%d, min=%d, max=%d"%(len(Iqtl), len(counts), SP.median(counts.values()), min(counts.values()), max(counts.values()))

    # Peak overlap
    get_peak_comparison(debug=True, signal_cutoff=0.15, smoother=25000, peak_cutoff_pqtl=4, peak_cutoff_eqtl=SP.log2(30), paper_numbers=True)

    # Individual eQTL-pQTL overlap
    print "\npQTL-eQTL concordance"
    print "====================="
    epq = get_eqtl_pqtl_overlap(nearby=0, debug=False)
    epq_cutoff = 0.1
    print "%d eQTLs with LOD>5; %d with pQTL of effect at least %.2f in same direction\n\n\n"%(len(epq), sum((abs(epq[:,1]) >= epq_cutoff) & (epq[:,0]*epq[:,1] > 0)), epq_cutoff)

    # pQTL-eQTL overlap
    get_pqtl_eqtl_signal(by=False, debug=False)



def mrna_enrichment():
    sample_gene_covs = get_sample_mrna()
    means = {}
    for s in sample_gene_covs: means[s] = ST.nanmean(sample_gene_covs[s].values())
    corrfactor = means["X2"]*means["X3"]/(means["B2"]*means["B3"])
    log10corrfactor = SP.log10(corrfactor)
    vals, orfs, min_cov = [], [], 20
    p11,p12,p21,p22 = "B2","B3","X2","X3"
    for o in sample_gene_covs[p11]:
        if o in sample_gene_covs[p12] and o in sample_gene_covs[p21] and o in sample_gene_covs[p22]:
            if sample_gene_covs[p11][o] > min_cov and sample_gene_covs[p12][o] > min_cov and sample_gene_covs[p21][o] > min_cov and sample_gene_covs[p22][o] > min_cov:
                m11, m12 = SP.log2(sample_gene_covs[p11][o]+0.5) - SP.log2(means[p11]), SP.log2(sample_gene_covs[p12][o]+0.5) - SP.log2(means[p12])
                m21, m22 = SP.log2(sample_gene_covs[p21][o]+0.5) - SP.log2(means[p21]), SP.log2(sample_gene_covs[p22][o]+0.5) - SP.log2(means[p22])
                if (abs(m11 - m12) < 1) and (abs(m21 - m22) < 1):
                    vals.append(0.5*(m11+m12-m21-m22))
                    orfs.append(o)
    I = SP.argsort(vals)
    n_orfs = 100
    print "RM larger"
    for i in I[0:n_orfs]:
        print orfs[i]#, vals[i]
    print "\n\nBY larger"
    for i in I[::-1][0:n_orfs]:
        print orfs[i]#, vals[i]
    

def main():
    """Script entry point: print the paper's summary numbers via main_numbers()."""
    return main_numbers()


if '__main__' == __name__:
    # Only run the report when executed as a script, not when imported.
    main()
