import scipy as SP
import pylab as PL
from parts2014_gfpvar.tools.common import *
from util import *
import pdb

""" Calculate heritability based on difference in means and individual variances of two equally-sized populations
@param delta difference in means, mu1-mu2
@param sdp standard deviation of the parental (clonal) population
@param sdk standard deviation of the kid (segregating) population
@return fraction of total variation in the combined population (50% kids 50% parent) due to genotype.
"""
def simple_h2(delta, sdp, sdk, prior=None, prior_type="additive", mu=None):
    """ Heritability estimate from the difference in means and the SDs of two equally-sized populations.

    The combined population (50% kids, 50% parents) is given the total variance
    0.5*(sdp**2 + sdk**2) + (0.5*delta)**2, where the last term is the variance
    contributed by the difference in means.

    @param delta difference in means, mu1-mu2
    @param sdp standard deviation of the parental (clonal) population
    @param sdk standard deviation of the kid (segregating) population
    @param prior None for no shrinkage, otherwise a pair; prior[0] is combined with sdp and prior[1] with sdk
    @param prior_type "additive": prior entries are SDs and variances are averaged 2:1 (observed:prior).
           "multiplicative": prior entries are log10 SD/mean ratios and variances are averaged 2:1 on the log scale.
           Any other value leaves sdp and sdk unshrunk.
    @param mu pair of population means (mu[0] goes with sdp, mu[1] with sdk); only read, and then required, for the multiplicative prior
    @return fraction of total variation in the combined population due to genotype; 0. when the total variance is zero
    """
    if prior is not None: # if prior from genome-wide distribution of variances is given, use it. Expect a value for SD parents and SD kids
        if prior_type == "additive":
            # weighted arithmetic mean of variances (weight 2 on the observation, 1 on the prior), then back to SD
            sdp = ((2*sdp*sdp + prior[0]*prior[0])/3.)**0.5
            sdk = ((2*sdk*sdk + prior[1]*prior[1])/3.)**0.5
        elif prior_type == "multiplicative":
            # weighted geometric mean of variances: (sd**2)**(2/3) * (mu*10**prior)**(2/3), then back to SD.
            # mu*(10**prior) is the expected magnitude of the SD implied by the prior
            sdp = ((sdp**4)*((mu[0]*(10**prior[0]))**2))**(1./6)
            sdk = ((sdk**4)*((mu[1]*(10**prior[1]))**2))**(1./6)

    between_var = (0.5*delta)**2 # variance due to the difference in means
    total_var = (0.5*(sdp**2 + sdk**2) + between_var)
    if total_var == 0:
        # No variation at all, so nothing can be attributed to genotype. Plain floats would
        # otherwise raise ZeroDivisionError below; numpy scalars already end up at 0.
        return 0.
    if sdp > sdk: # parents more diverse than kids
        # fraction of total *change* in variance due to genotype; assume any difference in variances between populations is not real
        return between_var/total_var
    # kids more diverse than parents: 1. - fraction of variance that is due to private variation (clone-clone)
    frac_clonal = min(1, (sdp**2)/total_var)
    return 1. - frac_clonal


def estimate_h2_gamma_params(repvar, n_iter=4):
    """ Iteratively fit gamma parameters (a, b) to per-item values derived from replicate variances, plotting every step.

    Each iteration computes tau = (2*b + 2/repvar)/(2*a + 1) and then resets
    a = mean(tau)**2/var(tau), b = mean(tau)/var(tau), i.e. the gamma parameters
    whose first two moments match those of tau.
    NOTE(review): tau looks like a precision (inverse variance) estimate - confirm.

    @param repvar array of variances; must support elementwise arithmetic and comparison (NaNs are skipped by ST.nanmean/ST.nanstd)
    @param n_iter number of moment-matching iterations
    @return a, b the parameters after the last iteration ((0, 0) if n_iter is 0)
    @effects prints diagnostics to stdout and opens a figure via PL.show()
    """
    a, b = 0, 0 # start off uninformed
    print("%s %s" % (sum(repvar < 1e-4), ST.nanmean(repvar)))
    repvar = repvar + ST.nanmean(repvar)/1000. # add 0.1%, removes near-0s.

    # One panel for the input, one per iteration, one for the reference curve.
    # At least 2 rows keeps the 2x3 layout for n_iter <= 4; more rows are added
    # only when the panels would not fit.
    n_cols = 3
    n_rows = max(2, (n_iter + 2 + n_cols - 1)//n_cols)
    PL.figure()
    PL.subplot(n_rows, n_cols, 1)
    PL.hist(SP.log10(repvar), bins=40, range=(-6,1))
    for i in range(n_iter):
        tau = (2*b + 2./repvar)/(2*a + 1)
        mu, v = ST.nanmean(tau), ST.nanstd(tau)**2 # calculate moments
        a, b = mu*mu/v, mu/v # update gamma parameters to match moments
        PL.subplot(n_rows, n_cols, i + 2)
        PL.hist(-SP.log10(tau)/2, bins=40, range=(-6,1)) # -log10(tau)/2 == log10(tau**-0.5)
        print("%s %s %s %s %s" % (1./mu, v, a, b, sum(1./tau < 1e-4)))

    # Last panel: the same formula with 2*b = 1/mean(repvar) and 2*a = 0.5.
    # NOTE(review): presumably a fixed reference prior for visual comparison - confirm.
    PL.subplot(n_rows, n_cols, n_iter + 2)
    tau = (1./ST.nanmean(repvar) + 2./repvar)/(0.5 + 1)
    PL.hist(-SP.log10(tau)/2, bins=40, range=(-6,1))
    PL.show()
    return a, b


def calc_priors(data, I):
    """ Average the per-row spread statistics over all data rows, for use as priors.

    @param data table whose row 0 is a header; every later row is indexed with I
    @param I index of the 14 columns, in the order mu_by_lin, mu_rm_lin, mse_by_log, s_by_log,
           mse_by_lin, s_by_lin, mse_rm_log, s_rm_log, mse_rm_lin, s_rm_lin, delta_log,
           mse_delta_log, delta_lin, mse_delta_lin (the last four are read but not used)
    @return four [by, rm] pairs of means over rows: mse on the log scale, log10(mse/mean) on the
            linear scale, s on the log scale, log10(sqrt(s)/mean) on the linear scale
    """
    # Each accumulator is a (by, rm) pair of lists.
    log_mse = ([], [])
    lin_mse = ([], [])
    log_s = ([], [])
    lin_s = ([], [])
    for row in data[1:]:
        (mu_by_lin, mu_rm_lin,
         mse_by_log, s_by_log, mse_by_lin, s_by_lin,
         mse_rm_log, s_rm_log, mse_rm_lin, s_rm_lin,
         delta_log, mse_delta_log, delta_lin, mse_delta_lin) = [float(x) for x in row[I]]

        log_mse[0].append(mse_by_log)
        log_mse[1].append(mse_rm_log)
        # linear-scale entries are stored as log10 of the spread relative to the mean
        lin_mse[0].append(SP.log10(mse_by_lin/mu_by_lin))
        lin_mse[1].append(SP.log10(mse_rm_lin/mu_rm_lin))

        log_s[0].append(s_by_log)
        log_s[1].append(s_rm_log)
        # the s's are variances, hence the square root before dividing by the mean
        lin_s[0].append(SP.log10((s_by_lin**0.5)/mu_by_lin))
        lin_s[1].append(SP.log10((s_rm_lin**0.5)/mu_rm_lin))

    return tuple([SP.mean(by), SP.mean(rm)] for by, rm in (log_mse, lin_mse, log_s, lin_s))


""" Output table of heritabilities for the individual ORFs
@param infilename file with parental means and variances combined across multiple replicates
@param outfilename file to output
@return None
@effects creates (or overwrites if exists) outfilename
"""
def create_heritabilities_output_table(infilename="%s/summary/paper/table_S2-combined.tab"%DATA_DIR, outfilename="%s/summary/cyto/table_S3.tab"%DATA_DIR, strongprior=False):
    """ Output table of heritabilities for the individual ORFs.

    For every data row of the input, writes the first three columns followed by raw
    heritabilities, prior-shrunk heritabilities, mean differences and pooled SDs, each
    for (single cell, population mean) x (linear, log) scale.

    @param infilename tab-delimited file with parental means and variances combined across multiple replicates; row 0 is a header
    @param outfilename file to output
    @param strongprior accepted but not used by this function
    @return None
    @effects creates (or overwrites if exists) outfilename
    """
    data = SP.loadtxt(infilename, delimiter="\t", dtype=object)
    # Input columns that are read, in the order listed in `assignments`.
    Ifeature = 3 + SP.array([6, 32, 1,4,7,10,27,30,33,36,52,53,58,59])
    assignments = "mu_by_lin, mu_rm_lin, mse_by_log, s_by_log, mse_by_lin, s_by_lin, mse_rm_log, s_rm_log, mse_rm_lin, s_rm_lin, delta_log, mse_delta_log, delta_lin, mse_delta_lin".split(", ")
    LOG.debug("Outputting heritabilities from file %s to file %s using features:"%(infilename, outfilename))
    for i,f in enumerate(Ifeature):
        # lets the header names actually picked be checked against the expected ones
        LOG.debug("%s (expected %s)"%(data[0][f], assignments[i]))

    # Priors are computed before the output is opened, so a failure here leaves any existing file untouched.
    mse_log_prior, mse_lin_prior, s_log_prior, s_lin_prior = calc_priors(data, Ifeature)

    # The with-block closes the file even if a row fails to parse.
    with open(outfilename, 'w') as ofh:
        # 0. Write header
        ofh.write("#ORF\tPlate\tWell")
        for trait in ["Heritability_raw", "Heritability_EBprior", "Mean_difference", "SD(Mean_difference)"]:
            for cells in ["single_cell","population_mean"]:
                for scale in ["linear", "log"]:
                    ofh.write("\t%s(%s,%s_scale)"%(trait, cells, scale))
        ofh.write("\n")

        # 1. Write data for each ORF
        for i in range(1, len(data)):
            # 1.0 Extract relevant features
            mu_by_lin, mu_rm_lin, mse_by_log, s_by_log, mse_by_lin, s_by_lin, mse_rm_log, s_rm_log, mse_rm_lin, s_rm_lin,delta_log, mse_delta_log, delta_lin, mse_delta_lin = map(float, data[i][Ifeature])
            mu = (mu_by_lin, mu_rm_lin)
            # the s's are variances; simple_h2 takes standard deviations
            sd_by_lin, sd_rm_lin = s_by_lin**0.5, s_rm_lin**0.5
            sd_by_log, sd_rm_log = s_by_log**0.5, s_rm_log**0.5

            values = [
                # 1.1 Raw heritabilities: single cell linear, single cell log, population linear, population log
                simple_h2(delta_lin, sd_by_lin, sd_rm_lin, prior=None),
                simple_h2(delta_log, sd_by_log, sd_rm_log, prior=None),
                simple_h2(delta_lin, mse_by_lin, mse_rm_lin, prior=None),
                simple_h2(delta_log, mse_by_log, mse_rm_log, prior=None),
                # 1.2 Empirical Bayes corrected heritabilities, same order
                simple_h2(delta_lin, sd_by_lin, sd_rm_lin, prior=s_lin_prior, prior_type="multiplicative", mu=mu),
                simple_h2(delta_log, sd_by_log, sd_rm_log, prior=s_log_prior, prior_type="additive", mu=mu),
                simple_h2(delta_lin, mse_by_lin, mse_rm_lin, prior=mse_lin_prior, prior_type="multiplicative", mu=mu),
                simple_h2(delta_log, mse_by_log, mse_rm_log, prior=mse_log_prior, prior_type="additive", mu=mu),
                # 1.3 Mean differences (the same values serve single cell and population columns)
                delta_lin, delta_log, delta_lin, delta_log,
                # 1.4 SDs pooled over the two populations.
                # NOTE(review): the mse_* values are square-rooted here but passed to simple_h2
                # unrooted above, and mse_delta_* is never written - confirm which is intended.
                (0.5*s_by_lin+0.5*s_rm_lin)**0.5,
                (0.5*(s_by_log +s_rm_log))**0.5,
                (0.5*(mse_by_lin + mse_rm_lin))**0.5,
                (0.5*(mse_by_log + mse_rm_log))**0.5,
            ]
            ofh.write("%s\t%s\t%s"%(tuple(data[i][0:3])))
            ofh.write("".join(["\t%.3f"%v for v in values]))
            ofh.write("\n")


""" Get the decomposition of GFP allele swap data (from BY-BY, BY-RM, RM-BY, RM-RM crosses) into cis and trans. Also return filtered data (NaNs for empty wells) and cross names
@param data map of cross->replicate->[8,12] GFP level
@param counts map of cross>replicate->[8,12] number of cells in well for the cross and replicate
@param cutoff minimum number of cells in a well; wells where any cross has fewer than this many cells are made NaN for all crosses in that replicate
@return data, cis, trans, interact, total, p. data - same map as before, NaNs replaced for lowly populated wells. cis - average cis effect (length-96 array; BY-RMgfp minus RM-BYgfp). trans - absolute value of total minus cis (length-96 array). interact - average of RM-BYgfp + BY-RMgfp - BY-BYgfp - RM-RMgfp (length-96 array). total - average total effect (length-96 array; RM-RMgfp minus BY-BYgfp). p - list of the four cross names.
"""
def get_gfpswap_decomposition(data, counts, cutoff=200):
    """ Decompose GFP allele swap data into cis, trans, interaction and total effects.

    @param data map of cross->array [replicate,8,12] of GFP level; must contain the keys
           "BY-BYgfp", "RM-BYgfp", "BY-RMgfp" and "RM-RMgfp". Modified in place.
    @param counts map of cross->replicate->[8,12] number of cells in the well
    @param cutoff minimum number of cells; a well with fewer cells in any cross is set to NaN
           in every cross for that replicate
    @return data, cis, trans, interact, total, p.
            data - the input map, with NaNs in lowly populated wells.
            cis - replicate average of BY-RMgfp minus RM-BYgfp (length-96 array).
            trans - abs(total - cis), so the sign is discarded (length-96 array).
            interact - replicate average of RM-BYgfp + BY-RMgfp - BY-BYgfp - RM-RMgfp (length-96 array).
            total - replicate average of RM-RMgfp minus BY-BYgfp (length-96 array).
            p - list of the four cross names, in the order BY-BY, RM-BY, BY-RM, RM-RM.
    """
    import numpy as np
    # scipy.stats.nanmean was dropped from later SciPy releases; numpy.nanmean is the
    # replacement. Very old numpy has no nanmean, so fall back to SciPy there.
    nanmean = getattr(np, "nanmean", None)
    if nanmean is None:
        from scipy.stats import nanmean

    # Blank out under-populated wells in every replicate that has counts (not just the
    # first two), so that no unfiltered replicate leaks into the averages below.
    for cross in data:
        for r in range(len(counts[cross])):
            sparse = np.where(np.asarray(counts[cross][r]) < cutoff)
            for other in data:
                data[other][r][sparse] = np.nan

    # Quantify relative effects (cis, trans, total); axis 0 is the replicate axis
    p = ["BY-BYgfp","RM-BYgfp","BY-RMgfp","RM-RMgfp"]
    cis = nanmean(data[p[2]] - data[p[1]], axis=0).reshape(96)
    total = nanmean(data[p[3]] - data[p[0]], axis=0).reshape(96)
    interact = nanmean(data[p[1]] + data[p[2]] - data[p[0]] - data[p[3]], axis=0).reshape(96)
    trans = abs(total - cis)
    return data, cis, trans, interact, total, p

#create_heritabilities_output_table()
