from parts2014_gfpvar.tools.common import *
from parts2014_gfpvar.tools.util import get_glucose_smith_qtls_peakavoid
from parts2014_gfpvar.tools.io import *
from parts2014_gfpvar.analysis.array.util import *
from parts2014_gfpvar.analysis.seq.util import *
import scipy as SP
import pylab as PL
import numpy as NP
import scipy.stats as ST 
import glob
import pdb


def read_qtls(seq_set, af_lenient=0.1, sd_lenient=3, af_stringent=0.15, sd_stringent=5, length_cutoff=0):
    """ Read the QTL table of one set that was called with lenient/stringent cutoffs
    @param seq_set name of the set to read; "all" first rebuilds the combined table from the per-set tables
    @param af_lenient, sd_lenient, af_stringent, sd_stringent, length_cutoff calling parameters, only used to construct the file names
    @return array of the tab-split lines of the table, first line of the file excluded """
    # all tables called with one parameter combination share this file name suffix
    suffix = "afl-%.2f_afs-%.2f_sdl-%.1f_sds-%.1f_len-%d.tab"%(af_lenient, af_stringent, sd_lenient, sd_stringent, length_cutoff)
    if seq_set == "all":
        # concatenate every table matching *?R*<suffix>, dropping all lines that contain "Set"
        os.system("cat %s/qtl/*?R*%s | grep -vh Set > %s/qtl/all_%s"%(DATA_DIR, suffix, DATA_DIR, suffix))
        # NOTE(review): if the grep above is what strips the header lines, the combined file has no header
        # and the [1:] below drops its first QTL - confirm
    with open("%s/qtl/%s_%s"%(DATA_DIR, seq_set, suffix), 'r') as ifh: # closed even if parsing fails
        rows = [l.strip("\n").split("\t") for l in ifh]
    return NP.array(rows[1:])
 
""" Get the list of sample pairs to call QTLs from
@param return_singletons whether to return "pairs" where only one population is measured
@note the assay is not a parameter: pairs are taken from the "qtl" sample set metadata, and the assay of each pair is reported in the result
@return map pair name-> [sample_high, sample_low, assay], where assay is in "Sequencing", "Array"
"""
def get_qtl_sample_pairs(return_singletons=False):
    """ Get the sample pairs to call QTLs from, using the metadata of the "qtl" sample set
    @param return_singletons whether to also return "pairs" where the high or the low population is missing
    @return map pair name -> [sample_high, sample_low, assay] """
    sample_metadata = get_all_metadata(sample_set="qtl") # map of sample name -> data
    pairs = {}
    for sample in sample_metadata: # one pass over the samples, grouping them by the pair they belong to
        meta = sample_metadata[sample]
        vals = pairs.setdefault(meta[1], [None, None, None]) # meta[1] is the name of the matching pair
        if meta[2] == "GFP_high":  vals[0] = sample # store the sample names
        elif meta[2] == "GFP_low": vals[1] = sample
        vals[2] = meta[4] # and assay type

    if return_singletons: # don't care if both populations are there
        return pairs
    return dict((pair, vals) for (pair, vals) in pairs.items() if None not in vals) # only complete pairs


""" Get metadata for all samples from file seq/samples/meta.tab
@note takes no parameters; always reads the sequencing sample metadata file
@return Nx6 SP.array, with columns Sample ID, well_population, sample set ID, GFP population, ORF, assay type
"""
def get_all_seq_metadata():
    """ Read the metadata of all samples listed in seq/samples/meta.tab
    @return array with one row of tab-separated fields per line of the file, header line excluded """
    with open('%s/seq/samples/meta.tab'%(DATA_DIR), 'r') as ifh: # closed even if parsing fails
        rows = [l.strip().split("\t") for l in ifh]
    return NP.array(rows[1:])


""" Retrieves data for a sequencing sample
@param sample sequencing sample to query
@param sample_set set of samples to consider (from arrays, sequencing, or QTL mapping (that includes both))
@return list of metadata items stored for the sample. NOTE(review): the lookup is a plain [] index, so a sample without metadata presumably raises rather than returning None - confirm"""
def get_sample_metadata(sample, sample_set="qtl"):
    """ Look up the metadata of a single sample
    @param sample name of the sample to query
    @param sample_set set of samples to search, passed on to get_all_metadata
    @return the metadata entry stored under the sample name """
    all_metadata = get_all_metadata(sample_set=sample_set)
    return all_metadata[sample]


""" Get data for QTL mapping for a given sample.
@param sample sample to get data for
@param assay type of assay used for the sample ("Array" reads array data, anything else reads sequencing data)
@return tuple of (ML allele frequency, posterior allele frequency, posterior params, bad_locus?, loci, coverage) """
def get_sample_data(sample, assay):
    """ Fetch the QTL mapping data of a sample with the reader matching its assay
    @param sample sample to get data for
    @param assay "Array" selects the array reader, any other value the sequencing reader
    @return whatever the selected reader returns for the sample """
    reader = read_array_data_for_qtlcall if assay == "Array" else get_sample_seq_data
    return reader(sample)


""" Read the QTLs for a given set
@param set_name name of the set to read (e.g. A19_R1)
@return list of QTLs, each list member a list [set, sample_high, sample_low, chromosome, peak coordinate, af change, sd(af change), peak start, peak end] """
def read_qtl(set_name, af_cutoff, sd_cutoff, length_cutoff):
    """ Read the QTLs called for a given set
    @param set_name name of the set to read (e.g. A19_R1)
    @param af_cutoff, sd_cutoff, length_cutoff calling parameters, only used to construct the file name
    @return list of QTLs, each a list [first three columns as strings, chromosome, peak (int), two floats, two ints] """
    # NOTE(review): this reads from <DATA_DIR>/qtl/ while output_qtl writes to <DATA_DIR>/qtl/calls/ - confirm which is intended
    chrnames = get_seq_chr_names()
    qtls = []
    with open("%s/qtl/%s_af-%.3f_sd-%.3f_len-%d.tab"%(DATA_DIR, set_name, af_cutoff, sd_cutoff, length_cutoff), 'r') as ifh:
        next(ifh) # skip the header line
        for l in ifh:
            d = l.strip("\n").split("\t")
            # the file stores the value of the chromosome name map; translate it back to the key
            chrm = [c for c in chrnames if chrnames[c] == d[3]][0]
            qtls.append(d[0:3] + [chrm, int(d[4])] + [float(v) for v in d[5:7]] + [int(v) for v in d[7:9]])
    return qtls


""" Write a list of decorated QTLs to the correct output file
@param set_name name of the set (used in output file name construction)
@param qtls list of QTLs, each member a list of set_name, sample_high, sample_low, chrm, peak, delta, sd_delta, peak start, peak end
@return None 
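@effect outputs QTL tab file: one line per QTL with the QTL fields, the QTL type, and the n_genes peak genes sorted by nearness to peak, each decorated with "*" if having a deleterious mutation or "=" if having only tolerated ones """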
def output_qtl(set_name, af_cutoff, sd_cutoff, length_cutoff, qtls, n_genes=20):
    """ Write a list of QTLs, decorated with QTL type and nearby genes, to the calls file of the set
    @param set_name name of the set (used in output file name construction)
    @param af_cutoff, sd_cutoff, length_cutoff calling parameters (used in output file name construction)
    @param qtls list of QTLs, each a sequence of set_name, sample_high, sample_low, chrm, peak, delta, sd_delta, peak start, peak end
    @param n_genes number of genes nearest to the peak to list for each QTL
    @return None """
    chrnames = get_seq_chr_names()
    chrgenes = read_chr_genes()
    gene_snps = read_gene_snps()
    path = "%s/qtl/calls/%s_af-%.3f_sd-%.3f_len-%d.tab"%(DATA_DIR, set_name, af_cutoff, sd_cutoff, length_cutoff)
    with open(path, 'w') as ofh: # closed even if a QTL fails to format
        # NOTE(review): the header names 8 columns while each row has 11 (set and the two samples come first);
        # left unchanged in case downstream readers depend on it
        ofh.write("Chrm\tPeak_location\tAF(GFP_high)-AF(GFP_low)\tSD(af_change)\tPeak_start\tPeak_end\tQTL_type\tGenes\n")
        for qtl in qtls:
            qtl_set, sample_high, sample_low, chrm, peak, delta, sd_delta, start, end = qtl
            if chrm in chrnames: chrm = chrnames[chrm] # sequencing QTLs carry the sequencing chromosome names
            ofh.write("%s\t%s\t%s\t%s\t%d\t%.3f\t%.3f\t%d\t%d"%(qtl_set, sample_high, sample_low, chrm, peak, delta, sd_delta, start, end))
            gene_name = get_sample_metadata(sample_high)[3]
            ofh.write("\t%s"%(get_qtl_type(chrm, peak, gene_name, chrgenes[chrm])))

            # Rank the genes of the chromosome by the distance of their midpoint to the peak. Done here on the
            # full gene records, as both names are needed: get_sorted_region_genes returns one name per gene only
            nearest = sorted(chrgenes[chrm], key=lambda g: abs(0.5*(g[0] + g[1]) - peak))[0:n_genes]
            sep = "\t"
            for (gene_start, gene_end, common_name, orf_name) in nearest:
                decoration = ""
                # presumably the SNP map is keyed by systematic ORF name - TODO confirm
                for snp in gene_snps.get(orf_name, []):
                    if (snp[4] == "TOLERATED") and (decoration == ""): decoration = "="
                    if snp[4] == "DELETERIOUS": decoration = "*" # deleterious overrides tolerated
                label = orf_name if str(common_name) == "None" else common_name # prefer the common name
                ofh.write("%s%s%s%s"%(sep, decoration, label, decoration))
                sep = ","
            ofh.write("\n")



""" Read SNPs for each gene
@return map gene->list of SNPs, each member a tuple (chromosome, location, SNP type, aa change, SIFT decision, SIFT score). If SNP is synonymous, last 3 items are "-" """
def read_gene_snps():
    """ Read the SNPs of each gene from seq/ref/thinned_mutations_withsift.tab
    @return map gene -> list of SNPs, each a list [chromosome, location (int), SNP type, aa change, SIFT decision, SIFT score] """
    result = {}
    with open("%s/seq/ref/thinned_mutations_withsift.tab"%DATA_DIR, 'r') as ifh: # closed even if a line is malformed
        for l in ifh:
            # every line must hold exactly these 8 whitespace-separated fields
            chrm, loc, gene, common_name, snp_type, aa_change, sift, sift_score = l.strip().split()
            result.setdefault(gene, []).append([chrm, int(loc), snp_type, aa_change, sift, sift_score])
    return result



""" Read all genes in a chromosome
@return map chrm->[genes], each item in the list a tuple (start, end, common_name, orf_name) """
def read_chr_genes():
    """ Read all genes of each chromosome from ref/scer.gff
    @return map chrm -> [genes], each item a list [start, end, common_name, orf_name]; a name is None if the
            gene line does not carry it. chrm is the GFF chromosome column without its first three characters """
    gene_hash = {'mitochondrion':[]}

    # example line:
    # chr01   SGD     gene    538     792     .       +       .       ID=YAL068W-A;Name=YAL068W-A;Ontology_term=GO:0003674,...;orf_classification=Dubious
    with open("%s/ref/scer.gff"%DATA_DIR,'r') as ifh:
        for line in ifh: # for each line
            d = line.strip().split("\t")
            if (len(d) < 9) or (d[2] != "gene"): continue # that encodes a gene
            chrm, start, end, meta = d[0], d[3], d[4], d[8] # get the data

            common_name, sgd_orf = None, None # init common name and sgd orf name
            for param in meta.split(';'): # go through parameter list
                var, has_value, val = param.partition('=')
                if not has_value: continue # tolerate empty entries, e.g. from a trailing ';'
                if var == 'gene':   common_name = val # update common and orf names if the data are there
                elif var == 'Name': sgd_orf = val

            # finally, store the info on the gene
            gene_hash.setdefault(chrm[3:], []).append([int(start), int(end), common_name, sgd_orf])

    return gene_hash # return gene locations and names for each chromosome



""" Get a list of genes that are ordered by their distance to the peak
@param chr_genes list of genes on the chromosome (start, end, name1, name2)
@param peak_center location to measure distances to
@return [gene name] where for any two consecutive gene names in list, the first gene is closer to peak_center than the next. The name is common name if available, otherwise systematic ORF name """
def get_sorted_region_genes(chr_genes, peak_center, region_start=0, region_end=1e9):
    """ Get the names of the genes in a region, ordered by their distance to the peak
    @param chr_genes list of genes on the chromosome, each (start, end, common name, ORF name)
    @param peak_center location to measure distances to
    @param region_start, region_end genes entirely outside [region_start, region_end] are left out
    @return array of gene names, closest gene midpoint first. The name is the common name if it is not None,
            otherwise the ORF name """
    genes, dists = [], []

    for (start, end, common_name, orf_name) in chr_genes: # for each gene on chromosome
        if end < region_start or start > region_end: continue # no overlap (gene totally to the left or right)
        dists.append(abs(0.5*(start + end) - peak_center)) # distance of gene midpoint to peak center
        genes.append(common_name if common_name is not None else orf_name)
    order = NP.argsort(dists) # sort the genes by their closeness
    return NP.array(genes)[order]



""" Return the type of the QTL (SGA related, known peak, or other)
@param chrm chromosome of the QTL
@param peak location of the QTL
@param gene name name of the gene QTLs were mapped for
@param chr_genes list of genes on the QTL chromosome (start, end, name1, name2)
@return type of the QTL - one of ? (no known info), IRA2/MKT1/HAP1 (one of known peaks), SGA-xxxx (selection marker related), or rDNA (a chrXII region) """
def get_qtl_type(chrm, peak, gene_name, chr_genes):
    """ Classify a QTL by its location
    @param chrm chromosome of the QTL
    @param peak location of the QTL
    @param gene_name name of the gene the QTLs were mapped for
    @param chr_genes list of genes on the QTL chromosome, each (start, end, name1, name2)
    @return label of the first known region the peak falls into, "SGA-selfGFP" if the peak is within 10000 of the
            midpoint of the mapped gene itself, "?" otherwise """
    # (chromosome, region center, half width, label); checked in this order
    known_regions = (
        ("XII", 470000, 25000, "rDNA"),
        ("V", 32000, 10000, "SGA-CAN1"),
        ("II", 235000, 10000, "SGA-HTA2"),
        ("III", 200000, 30000, "SGA-MAT"),
        ("XIV", 468000, 10000, "MKT1"),
        ("XV", 176000, 10000, "IRA2"),
        ("XII", 649000, 10000, "HAP1"),
    )
    for region_chrm, center, half_width, label in known_regions:
        if chrm == region_chrm and abs(peak - center) < half_width:
            return label

    for gene in chr_genes:
        start, end, name1, name2 = gene
        is_mapped_gene = (gene_name == name1) or (gene_name == name2)
        if is_mapped_gene and abs(peak - 0.5*(start+end)) < 10000:
            return "SGA-selfGFP"
    return "?"



def get_locus_overlap(l1,l2):
    """ Find the loci shared by two lists of loci
    @param l1, l2 sequences (lists or arrays) of hashable loci
    @return (shared loci in the order they occur in l1, [indexes into l1, indexes into l2]); a locus that is
            repeated in l2 is reported with the index of its first occurrence """
    first_in_l2 = {}
    for j, l in enumerate(l2): # index l2 once instead of scanning it for every shared locus
        first_in_l2.setdefault(l, j)

    res_loc = []
    I = [[],[]]
    for i,l in enumerate(l1):
        if l in first_in_l2:
            res_loc.append(l)
            I[0].append(i)
            I[1].append(first_in_l2[l])
    return res_loc, I


def get_cumulative_chrm_lengths():
    """ Get the offset of each chromosome in a concatenated genome coordinate
    @return map chromosome name (value of the sequencing chromosome name map) -> sum of the lengths of all
            chromosomes sorted before it, plus key "all" -> sum of all lengths. The length of a chromosome is
            taken as the largest entry of the first item stored for it in the X1_sR1 diffs pickle """
    chrnames = get_seq_chr_names()
    diffs = cl("%s/seq/diffs/X1_sR1_rec-0.90.pickle"%DATA_DIR)
    offsets = {}
    running = 0
    for seq_chrm in sorted(diffs.keys()):
        offsets[chrnames[seq_chrm]] = running # this chromosome starts where the previous ones end
        running = running + max(diffs[seq_chrm][0])
    offsets["all"] = running
    return offsets




def get_all_mrna_qtls(smoother=25000):
    """ Count the mRNA QTLs of mrna/smith_linkage_results_thinned_distant.tab along the concatenated genome
    @param smoother bin width; every QTL is counted in all bins touched by a window of this width centered on it
    @return array of QTL counts per bin of the cumulative genome coordinate """
    ROMAN_NAMES = "I II III IV V VI VII VIII IX X XI XII XIII XIV XV XVI".split()
    with open("%s/mrna/smith_linkage_results_thinned_distant.tab"%DATA_DIR, 'r') as ifh:
        qtls = [l.strip().split("\t") for l in ifh.readlines()[1:]] # skip header
    chrlens = get_cumulative_chrm_lengths()
    vals = NP.zeros(int(max(chrlens.values())//smoother) + 1)
    half = smoother//2
    for (orf, condition, intchrm, lod, peak, loc, by, rm) in qtls:
        cloc = chrlens[ROMAN_NAMES[int(intchrm)-1]] + int(loc) # cumulative coordinate of the QTL
        # clamp at bin 0: a window starting before the genome start would otherwise give a negative
        # first index, i.e. an empty slice, and the QTL would not be counted at all
        first = max(0, int((cloc - half)//smoother))
        last = int((cloc + half)//smoother)
        vals[first:(last + 1)] += 1
    return vals



def get_orf_mrna_qtls():
    """ Read the mRNA QTLs of each ORF from mrna/smith_linkage_results_thinned_distant_sortedorfs.tab
    @return map orf -> list of QTLs, each [chromosome (roman numeral), location, LOD, rm column - by column] """
    ROMAN_NAMES = "I II III IV V VI VII VIII IX X XI XII XIII XIV XV XVI".split()
    with open("%s/mrna/smith_linkage_results_thinned_distant_sortedorfs.tab"%DATA_DIR, 'r') as ifh:
        qtls = [l.strip().split("\t") for l in ifh.readlines()[1:]] # skip header
    result = {}
    for (orf, condition, intchrm, lod, peak, loc, by, rm) in qtls:
        chrm = ROMAN_NAMES[int(intchrm)-1]
        result.setdefault(orf, []).append([chrm, int(loc), float(lod), float(rm)-float(by)]) # RM allele effect
    return result



def get_average_selection_signal(orf, chrm, loc, rec_cutoff=0.90):
    """ Average the stored sequencing signal nearest to a locus over the sample sets of an ORF
    @param orf ORF whose "GFP_low" samples select the sample sets to average over
    @param chrm chromosome of the locus, as a value of the sequencing chromosome name map
    @param loc location on the chromosome
    @param rec_cutoff cutoff used in the name of the diffs pickle files to read
    @return [mean of item 1, mean of squared item 2] of the per-chromosome diffs data at the stored location
            nearest to loc; [nan, nan] if no sample set qualifies. Presumably item 1 is the allele frequency
            change and item 2 its standard deviation - TODO confirm """
    chrnames = get_seq_chr_names()

    meta = get_all_seq_metadata()
    I = NP.where((meta[:,4] == orf) & (meta[:,3] == "GFP_low"))[0]
    sets = NP.unique(meta[I,2])
    seq_chrname = [c for c in chrnames if chrnames[c] == chrm][0] # same for all sets, so look up once
    signals = []

    for s in sets:
        if "a" in s: continue # sets with an "a" in their ID are left out
        d = cl("%s/seq/diffs/%s_rec-%.2f.pickle"%(DATA_DIR, s, rec_cutoff))
        i = NP.argmin(abs(d[seq_chrname][0] - loc)) # stored location closest to loc
        signals.append([d[seq_chrname][1][i], d[seq_chrname][2][i]**2])
    if len(signals) == 0: return [NP.nan, NP.nan] # nothing to average
    return list(NP.array(signals).mean(axis=0))


def create_expr_data(filename):
    """ Parse mrna/pbiol_segregants.tab and cache the result
    @param filename file to store the parsed data to (via cdm)
    @return (meta, expr, expr_inds): first two columns of every line, the remaining columns as floats with
            empty fields read as nan, and the header entries of the remaining columns """
    with open("%s/mrna/pbiol_segregants.tab"%DATA_DIR, 'r') as ifh:
        expr_inds = next(ifh).strip("\n").split("\t")[2:]
        # replace empty fields before building the array, so "nan" cannot be cut short by a narrow string dtype
        rows = [[v if v != "" else "nan" for v in l.strip("\n").split("\t")] for l in ifh]
    data = NP.array(rows)
    meta = NP.array(data[:,0:2])
    expr = NP.array(data[:,2:], float)
    cdm((meta, expr, expr_inds), filename)
    return meta, expr, expr_inds


def read_mrna_data(condition="glucose"):
    """ Read genotypes and expression levels of the segregants measured in one condition
    @param condition growth condition; individuals are selected by comparing its first letter to the last
           letter of their name
    @return (snps of the selected individuals, snploc, map orf -> expression of the selected individuals).
            The expression data are cached in mrna/expr_<condition>.pickle, created on first use """
    snploc = cl("%s/mrna/snploc.pickle"%DATA_DIR)
    snps = cl("%s/mrna/snps_segregants.pickle"%DATA_DIR)
    inds = cl("%s/mrna/individuals_segregants.pickle"%DATA_DIR)
    # first two space-separated words of the last "||"-separated field, lowercased
    snp_inds = NP.array([x.strip().split("||")[-1].strip().lower().split(" ")[0:2] for x in inds])
    Isnp = [i for i, ind in enumerate(snp_inds) if ind[0].lower()[-1] == condition[0]]
    snps = snps[Isnp]

    exprfile = "%s/mrna/expr_%s.pickle"%(DATA_DIR, condition)
    meta, expr, expr_inds = cl(exprfile) if os.path.exists(exprfile) else create_expr_data(exprfile)
    expr_inds = [x.strip().split(" ")[0].lower() for x in expr_inds]
    Iexpr = [expr_inds.index(i) for i in snp_inds[Isnp,0]] # expression columns in the order of the genotypes
    orfs = [x.split("||")[-2].strip() for x in meta[:,1]]
    expr_map = {}
    for o, orf in enumerate(orfs):
        expr_map[orf] = expr[o][Iexpr]
    return snps, snploc, expr_map

def nanmean(x):
    """ Mean of the values of x that are not NaN
    @param x numeric array or sequence
    @return sum of the non-NaN values divided by their number """
    n_valid = NP.sum(~NP.isnan(x)) # counts over all dimensions; the builtin sum only reduces the first one
    return NP.nansum(x)/n_valid


def mlpred(X,Z, debug=False):
    """ Least squares prediction of Z from the rows of X
    @param X array with one row per predictor and one column per observation
    @param Z observed values, one per column of X
    @param debug whether to print the pseudoinverse, the weights and the inputs
    @return the weighted sum of the rows of X, with weights obtained from the pseudoinverse of X """
    phiI = NP.linalg.pinv(X)
    wml = NP.dot(phiI.T,Z).T # one weight per row of X
    if debug:
        # single-argument print calls give the same output under Python 2 and 3
        for label, value in (("F:", phiI), ("W:", wml), ("X:", X), ("Z:", Z)):
            print("%s %s"%(label, value))
    return NP.dot(wml, X)


def ml_weight(y, x, cov_geno, geno, geno_loc):
    """ Least squares weights of a predictor, a constant, and optional genotype covariates for explaining y
    @param y observed values, one per individual
    @param x predictor, one value per individual
    @param cov_geno list of (chrm, loc, lod) covariate loci, or None for no covariates
    @param geno genotypes, indexed as geno[individual, marker, 0]
    @param geno_loc array with one row (chromosome, location) per marker
    @return weights in the order: x, constant, covariates in the order given """
    x = [x, NP.ones(geno.shape[0])] # add row of ones to correct for mean
    if cov_geno is None: cov_geno = []
    for (chrm, loc, lod) in cov_geno:
        Ichrm = NP.where(geno_loc[:,0] == chrm)[0]
        i = NP.argmin(abs(geno_loc[Ichrm,1] - loc)) # marker of the chromosome closest to the covariate locus
        x.append(geno[:,Ichrm[i], 0])
    phiI = NP.linalg.pinv(NP.array(x))
    return NP.dot(phiI.T,y).T


def subtract_covariates(y, covs, geno, geno_loc, debug=False):
    """ Remove the least squares fit of a constant and the genotypes at the covariate loci from y
    @param y observed values, one per individual
    @param covs list of (chrm, loc, lod) covariate loci
    @param geno genotypes, indexed as geno[individual, marker, 0]
    @param geno_loc array with one row (chromosome, location) per marker
    @param debug whether to print the fit, the residual, and its correlation with y
    @return y minus the prediction of mlpred """
    x = [NP.ones(geno.shape[0])]
    for (chrm, loc, lod) in covs:
        Ichrm = NP.where(geno_loc[:,0] == chrm)[0]
        i = NP.argmin(abs(geno_loc[Ichrm,1] - loc)) # marker of the chromosome closest to the covariate locus
        x.append(geno[:,Ichrm[i], 0])
    p = mlpred(NP.array(x), y, debug=debug) # fit once; the debug output uses the same prediction
    if debug:
        print("Original: %s"%(y,))
        print("Pred: %s"%(p,))
        print("Residual %s"%(y-p,))
        print("Correlation: %s"%(NP.corrcoef(y-p, y)[0,1],))
    return y - p


def get_mrna_effect_signal(orf, chrm, loc, covariates=None, mrna_data=None, by=True):
    """ Effect of a single locus on the mRNA level of an ORF
    @param orf, covariates, mrna_data, by passed on to get_mrna_effect_signal_multipeak
    @param chrm, loc the locus, passed on as a one-element list of loci
    @return result of get_mrna_effect_signal_multipeak for this locus """
    single_locus = [(chrm, loc)]
    return get_mrna_effect_signal_multipeak(orf, single_locus, covariates, mrna_data, by=by)


def get_mrna_effect_signal_multipeak(orf, locs, covariates=None, mrna_data=None, andor=0, by=True):
    """ Effect of carrying the BY allele (genotype 0) at a set of loci on the mRNA level of an ORF
    @param orf ORF to get the effect for
    @param locs list of (chromosome as roman numeral, location) loci
    @param covariates list of (chrm, loc, lod) loci to use as covariates in the fit, or None
    @param mrna_data (genotypes, marker locations, map orf -> expression); read with read_mrna_data if None
    @param andor 0: an individual counts as BY if it has genotype 0 at all loci; 1: if at any of the loci
    @param by whether to return the BY allele effect; otherwise its negation is returned as the RM allele effect
    @return weight of the BY indicator in the least squares fit of ml_weight; nan if orf has no expression data """
    ROMAN = "0 I II III IV V VI VII VIII IX X XI XII XIII XIV XV XVI".split()
    if mrna_data is None: mrna_data = read_mrna_data()
    geno, geno_loc, expr = mrna_data

    if orf not in expr: return NP.nan

    # start from the neutral element of the combination used below
    if andor == 1: is_by = NP.zeros(len(geno), bool)
    else: is_by = NP.ones(len(geno), bool)

    for chrm, loc in locs:
        Ichrm = NP.where(geno_loc[:,0] == ROMAN.index(chrm))[0]
        i = NP.argmin(abs(geno_loc[Ichrm,1] - loc)) # marker of the chromosome closest to the locus
        by_here = (geno[:,Ichrm[i], 0] == 0) # BY
        if andor == 0: is_by = is_by & by_here
        else: is_by = is_by | by_here

    wml = ml_weight(expr[orf], is_by, covariates, geno, geno_loc) # weight of the BY indicator is the BY allele effect
    if by: return wml[0]
    return -wml[0] # the RM allele effect is the opposite of the BY one


def get_mrna_signals(peaks=None, mrna_data=None, andor=1, plate_orfs=None, correct_qtls=True):
    """ Effect of a set of known peaks on the mRNA level of each given ORF
    @param peaks names of the peaks to combine, any of "HAP1", "MKT1", "IRA2"; None means ["HAP1"]
    @param mrna_data (genotypes, marker locations, map orf -> expression); read with read_mrna_data if None
    @param andor how to combine the peaks, see get_mrna_effect_signal_multipeak
    @param plate_orfs ORFs to get the signal for. Must be given: the None default cannot be iterated
    @param correct_qtls whether to use the known QTLs of an ORF (from get_glucose_smith_qtls_peakavoid) as covariates
    @return array with one signal per ORF of plate_orfs; nan for ORFs without expression data """
    if peaks is None: peaks = ["HAP1"] # fresh list per call instead of a shared mutable default
    if mrna_data is None: mrna_data = read_mrna_data()
    expr = mrna_data[2]
    orf_qtls = get_glucose_smith_qtls_peakavoid(avoid_peaks=peaks, lod_cutoff=5)
    peak_locs = {"HAP1":("XII",665000), "MKT1":("XIV", 486861), "IRA2":("XV", 170945)}
    peak_loc = [peak_locs[peak] for peak in peaks]
    mrna_signals = []

    for o in plate_orfs:
        if o not in expr:
            mrna_signals.append(NP.nan)
            continue
        qtls = orf_qtls[o] if (correct_qtls and (o in orf_qtls)) else None
        mrna_signals.append(get_mrna_effect_signal_multipeak(o, peak_loc, covariates=qtls, mrna_data=mrna_data, andor=andor))

    return NP.array(mrna_signals)



def get_all_mrna_effect_signal(chrm, loc, mrna_data=None):
    """ Difference in mean mRNA level between the two genotypes at a locus, for every ORF
    @param chrm chromosome of the locus, compared as is to the first column of the marker locations
    @param loc location on the chromosome
    @param mrna_data (genotypes, marker locations, map orf -> expression); read with read_mrna_data if None
    @return map orf -> NaN-ignoring mean over genotype 1 (RM) individuals minus that over genotype 0 (BY) ones """
    if mrna_data is None: mrna_data = read_mrna_data()
    geno, geno_loc, expr = mrna_data
    Ichrm = NP.where(geno_loc[:,0] == chrm)[0]
    i = NP.argmin(abs(geno_loc[Ichrm,1] - loc)) # marker of the chromosome closest to the locus
    I0 = NP.where(geno[:,Ichrm[i], 0] == 0)[0] # I0 = "BY"
    I1 = NP.where(geno[:,Ichrm[i], 0] == 1)[0] # I1 = "RM"
    return dict((orf, nanmean(expr[orf][I1]) - nanmean(expr[orf][I0])) for orf in expr) # RM - BY
