from parts2014_gfpvar.tools.io import *
from parts2014_gfpvar.analysis.qtl.util import *
from parts2014_gfpvar.analysis.array.util import *
from parts2014_gfpvar.analysis.array.infer import *
import glob
import pdb

# Roman-numeral chromosome names, indexed by chromosome number.
# Index 0 holds the placeholder "0" so that chromosome n sits at index n.
ROMAN_CHR_NAMES = [
    "0",
    "I", "II", "III", "IV",
    "V", "VI", "VII", "VIII",
    "IX", "X", "XI", "XII",
    "XIII", "XIV", "XV", "XVI",
]


class cArrayQtl:
    """ One array QTL call, parsed from a single tab-delimited line of a QTL file.
    Expected columns, in order: Sample, QTL_set, Chrm, Peak_loc, Peak_signal, Peak_width, Peak_start, Peak_end, Peak_genes
    (Peak_genes is a space-separated list and may be empty).
    """
    def __init__(self, l):
        """ Parse one QTL line
        @param l tab-delimited line with at least the first eight columns; a trailing newline is allowed
        """
        # Only strip the line terminator: a QTL without genes ends in "\t\n", and a plain strip()
        # would also remove that tab, dropping the (empty) gene column altogether
        d = l.rstrip("\r\n").split("\t")
        self.sample = d[0]
        self.qtl_set = d[1]
        self.chrm = d[2] # kept as a string
        self.peak_loc = int(d[3])
        self.peak_score = float(d[4])
        self.peak_width = int(d[5])
        self.peak_start = int(d[6])
        self.peak_end = int(d[7])
        self.peak_genes = d[8].split() if len(d) > 8 else [] # tolerate a missing gene column


    def __str__(self):
        """ @return the QTL as a tab-delimited line (no trailing newline), columns ordered as parsed by __init__ """
        return "%s\t%s\t%s\t%d\t%.2f\t%d\t%d\t%d\t%s"%(self.sample, self.qtl_set, self.chrm, self.peak_loc, self.peak_score,
                                                        self.peak_width, self.peak_start, self.peak_end, " ".join(self.peak_genes))


""" Call QTLs for a given dataset. All regions of length at least length_threshold and difference of at least diff_threshold will be called as QTLs.
@param dataset name of dataset to call qtls for
@param sample_pattern file pattern whose matches will be analysed
@param diff_threshold minimum absolute value difference between sample and reference to be considered for a QTL
@param length_threshold minimum basepair length of a genome segment that must be at least diff_threshold different from reference to be called a QTL
@effects creates QTL call files for selection in high and low GFP populations, as well as the difference between the two.
"""
def create_qtls(dataset="2013-02-22_round1", sample_pattern="*", diff_threshold=0.2, length_threshold=30000, save_signal=True):
    """ Call QTLs for every sample file of a dataset that matches sample_pattern.
    Three QTL sets are called per sample: "plus_diploid" and "minus_diploid" from the inferred signal of the plus
    and minus sample respectively, and "plus_minus" from the difference of the two inferred signals.
    @param dataset name of dataset to call qtls for
    @param sample_pattern file pattern whose matches will be analysed
    @param diff_threshold signal magnitude that a probe must exceed to be part of a peak
    @param length_threshold basepair length that a peak must exceed to be called
    @param save_signal if True, also write per-chromosome signal files and pickles of the inferred signals
    @effects writes <DATA_DIR>/arrays/qtl/<dataset>/<set>_diff/<sample>_qtls.tab for each set,
             <DATA_DIR>/qtl/<setname>_rec-0.90.tab for the plus_minus set, and (if save_signal) pickles
    """
    geneloc, genename = get_gene_chr_hash()
    meta = get_all_seq_metadata()
    sets = ["plus_diploid", "minus_diploid", "plus_minus"] # 3 sets of QTLs will be called - high gfp vs diploid hybrid, low gfp vs diploid hybrid, high gfp vs low gfp
    for s in sets: os.system("mkdir -p %s/arrays/qtl/%s/%s_diff"%(DATA_DIR, dataset, s))

    # for each sample matching the given pattern
    for f in glob.glob("%s/arrays/condensed/%s/*%s*.tab"%(DATA_DIR, dataset, sample_pattern)):
        s1 = f.split("/")[-1][0:-4] # file name without the ".tab" extension
        sample_name = s1.replace("_plus","") # name used in output file names and in the Sample column
        samples = [s1, s1.replace("plus", "minus")] # This should now be something like ["B8_plus", "B8_minus"]
        sids = [get_sample_id(s, dataset) for s in samples] # This should be long sample IDs
        setname = meta[SP.where(meta[:,0] == sids[0])[0][0], 2] # third metadata column of the first row matching the plus sample id
        data = [read_array_file(sample, dataset=dataset) for sample in samples] # read data for the samples
        pickle_data = [{},{},{}]

        handles = [] # every file opened for this sample; closed in the finally clause even if calling fails midway
        try:
            for s in sets: handles.append(open("%s/arrays/qtl/%s/%s_diff/%s_qtls.tab"%(DATA_DIR, dataset, s, sample_name), "w")) # QTL output file for each set
            handles.append(open("%s/qtl/%s_rec-0.90.tab"%(DATA_DIR, setname), 'w')) # QTLs in sequence output format
            ofhs, seq_ofh = handles[:-1], handles[-1]
            for ofh in ofhs: ofh.write("Sample\tQTL_set\tChrm\tPeak_loc\tPeak_signal\tPeak_width\tPeak_start\tPeak_end\tPeak_genes\n")
            seq_ofh.write("Set\tChrm\tPeak\tStart\tEnd\tLength\tAF_peak\tSD_peak\tnumSD_peak\n")
            LOG.debug("Getting QTLs for dataset %s sample %s"%(dataset, sample_name))

            for chrm in sorted(data[0][0]): # for each chromosome
                locs, L = data[0][0][chrm], len(data[0][0][chrm])
                signals = [infer_mean(locs, data[i][1][chrm], return_all=True) for i in range(2)] # the first two measures go with the first two QTL sets. Signals[i] = (mean, var, indicator whether probe is background)
                diff_sd = (signals[0][1] + signals[1][1])**0.5 # sd of the difference: square root of the summed variances
                if save_signal:
                    output_signal_file(dataset, samples, chrm, locs, [data[i][1][chrm] for i in range(2)], [data[i][2][chrm] for i in range(2)], signals)
                    for i in range(2): pickle_data[i][NCBI_CHR_NAMES[chrm]] = (data[i][1][chrm], signals[i][0], SP.ones([L,2])*(0.125/signals[i][1] - 0.5), (signals[i][2][:,2] < 0.4), locs, SP.ones(L))
                    pickle_data[-1][NCBI_CHR_NAMES[chrm]] = (locs, signals[0][0] - signals[1][0], SP.ones(len(signals[0][0]))*diff_sd)
                signals.append([signals[0][0] - signals[1][0]]) # the final one is the difference of the two signals. In a list of one for consistency with previous two, which are lists of signal, variance, and indicators

                for s, set_name in enumerate(sets): # for each QTL set
                    set_signal = signals[s][0]
                    qtl_peaks = call_peaks(set_signal, locs, diff_threshold, length_threshold) # call peaks for the set
                    qtl_peaks = join_neighbours(qtl_peaks, max_dist=10000) # join peaks that are no more than 10 kb apart
                    for (loc_p, sig_p, width_p, start_p, end_p, starti, endi, peaki) in qtl_peaks: # output all peaks
                        ofhs[s].write("%s\t%s\t%s\t%d\t%.2f\t%d\t%d\t%d\t"%(sample_name, set_name, chrm, loc_p, sig_p, width_p, start_p, end_p)) # metadata and location
                        ofhs[s].write(" ".join(get_peak_genes(set_signal[starti:(endi+1)], locs[starti:(endi+1)], geneloc[chrm], genename[chrm])) + "\n") # + list of genes in peak
                        if set_name == "plus_minus": seq_ofh.write("%s\t%s\t%d\t%d\t%d\t%d\t%.3f\t%.3f\t%.1f\n"%(setname, ROMAN_CHR_NAMES[int(chrm)], loc_p, start_p, end_p, width_p, sig_p, diff_sd, abs(sig_p)/diff_sd))
        finally:
            for ofh in handles: ofh.close()

        if save_signal:
            for i in range(2): cdm(pickle_data[i], "%s/arrays/inferred_af/pickles/%s.pickle"%(DATA_DIR, sids[i]))
            cdm(pickle_data[-1],  "%s/seq/diffs/%s_rec-0.90.pickle"%(DATA_DIR, setname))


""" Call signal peaks that are at least length_threshold long and diff_threshold strong
@param signal length-L array of signal values
@param locs length-L array of probe locations
@param diff_threshold minimum absolute value of signal to be considered for peak
@param length_threshold minimum length peak to be considered
@return list of peaks [[peak_loc, peak_signal, peak_width, peak_start, peak_end, start_index, end_index, peak_index],...]"""
def call_peaks(signal, locs, diff_threshold, length_threshold):
    """ Call maximal runs of consecutive probes whose absolute signal exceeds diff_threshold and whose span exceeds length_threshold.
    A run that is still open at the last probe is also emitted.
    @param signal length-L array of signal values (must support slicing and abs(), e.g. a numeric array)
    @param locs length-L array of probe locations
    @param diff_threshold a probe is part of a peak if abs(signal) is strictly greater than this
    @param length_threshold a run is called if locs[end] - locs[start] is strictly greater than this
    @return list of peaks [[peak_loc, peak_signal, peak_width, peak_start, peak_end, start_index, end_index, peak_index],...]
    """
    peaks = [] # list of called peaks
    start, end = -1, -1 # first and last index of current run (start < 0 <==> not in a run)
    n = len(locs)

    # iterate one step past the last probe: index n acts as a below-threshold sentinel,
    # so that a run reaching the end of the chromosome is closed and reported as well
    for l in range(n + 1):
        if l < n and abs(signal[l]) > diff_threshold: # signal above threshold: open or extend run
            if start == -1:
                start = l
            end = l
        elif start != -1: # run just ended
            if locs[end] - locs[start] > length_threshold: # if run length is above threshold
                peak = start + SP.argmax(abs(signal[start:(end+1)])) # index of strongest signal (first one if tied)
                peaks.append([locs[peak], signal[peak], locs[end] - locs[start], locs[start], locs[end], start, end, peak])
            start = -1

    return peaks


""" Join adjacent peaks that are at most max_dist apart
@param peaks list of QTL peaks
@require peaks ordered in chromosome order, nonoverlapping
@param max_dist minimum distance at which two peaks will be considered separate
@return list of QTL peaks
"""
def join_neighbours(peaks, max_dist=10000):
    """ Join adjacent peaks whose gap is smaller than max_dist.
    A joined peak spans from the start of the first to the end of the last member, and carries the location,
    signal and index of the member with the largest absolute signal. The input peak lists are not modified.
    @param peaks list of peaks [peak_loc, peak_signal, peak_width, peak_start, peak_end, start_index, end_index, peak_index]
    @require peaks ordered in chromosome order, nonoverlapping
    @param max_dist peaks are joined if next start - previous end is strictly less than this
    @return list of QTL peaks, same layout as the input
    """
    if len(peaks) == 0: return peaks # no peaks = nothing to do

    PEAK, SIGNAL, LENGTH, START, END = 0, 1, 2, 3, 4
    END_INDEX, PEAK_INDEX = 6, 7
    res = [list(peaks[0])] # copy, so that merging does not alter the caller's peak

    for peak in peaks[1:]: # go through rest of the peaks
        last = res[-1]
        if peak[START] - last[END] < max_dist: # new peak starts close enough to the previous peak end
            has_indices = len(last) > PEAK_INDEX and len(peak) > PEAK_INDEX # shorter records carry no probe indices
            last[END] = peak[END] # join the right end
            last[LENGTH] = last[END] - last[START] # update length
            if has_indices: last[END_INDEX] = peak[END_INDEX] # keep the probe index range in step with the locations

            if abs(peak[SIGNAL]) > abs(last[SIGNAL]): # if new signal stronger
                last[PEAK], last[SIGNAL] = peak[PEAK], peak[SIGNAL] # update the signal strength and location
                if has_indices: last[PEAK_INDEX] = peak[PEAK_INDEX] # and the index of the strongest probe
        else:
            res.append(list(peak)) # if not close enough to join peaks, just add a copy of the new peak
    return res


""" Read chr->gene->loc mapping
@return tuple ({chr:[[start, end]*]}, {chr:[name*]}) of gene locations and gene names per chromosome, in matching order
"""
def get_gene_chr_hash():
    """ Group gene locations and names by chromosome.
    @return tuple (locs, names): locs maps str(chromosome) to an integer array of [start, end] rows,
            names maps str(chromosome) to an array of gene names in the same order as the rows of locs
    """
    genes = read_gene_locs(use_common_name=True) # read gene->[chrm, start, end] mapping, using common gene name if possible
    if "YER053C-A" not in genes:
        # this used to drop into pdb, which stalls any non-interactive run; report instead.
        # YER053C-A is the self-linkage gene listed for sample A2 in is_self_qtl - presumably a sanity check of the name mapping
        LOG.debug("Gene YER053C-A not found in gene location table")

    locs, names = {}, {}
    for g in genes: # for each gene
        chrm, start, end = genes[g] # get its location
        key = str(chrm) # chromosomes are keyed by their string form
        locs.setdefault(key, []).append([start, end]) # store location
        names.setdefault(key, []).append(g) # and name

    for c in locs: # for each chromosome
        locs[c] = SP.array(locs[c], int) # store locations of genes as integer array for easy manipulation
        names[c] = SP.array(names[c]) # and names as strings
    return locs, names


""" Get list of genes overlapping a QTL peak and within limit of the strongest signal
@param signal length-L array of signal values in peak
@param probe_locs length-L array of probe locations in peak
@param gene_locs length-K array of [[start,end]] locations of genes in the chromosome
@param gene_names length-K array of gene names; ordered as gene_locs
@param max_attenuation maximum weakening of the signal from the strongest
@return list of gene_names overlapping peak where signal is attenuated by no more than max_attenuation
"""
def get_peak_genes(signal, probe_locs, gene_locs, gene_names, max_attenuation=0.1):
    """ List the genes that overlap the strong-signal window of a peak.
    The window runs from the first to the last probe whose absolute signal is within max_attenuation of the peak maximum.
    @param signal length-L array of signal values in peak
    @param probe_locs length-L array of probe locations in peak
    @param gene_locs K x 2 array of [start, end] gene locations in the chromosome
    @param gene_names length-K array of gene names; ordered as gene_locs
    @param max_attenuation a probe is in the window if max(abs(signal)) - abs(signal) is strictly less than this
    @return list of names of genes overlapping the window
    """
    strength = abs(signal)
    strong = SP.where(strength.max() - strength < max_attenuation)[0] # probes with strong enough signal
    lo, hi = probe_locs[strong[0]], probe_locs[strong[-1]] # window spans first to last such probe

    gene_start, gene_end = gene_locs[:,0], gene_locs[:,1]
    starts_inside = (gene_start >= lo) & (gene_start <= hi) # gene start lies in the window
    ends_inside = (gene_end >= lo) & (gene_end <= hi) # gene end lies in the window
    covers_window = (gene_start <= lo) & (gene_end >= hi) # window lies within the gene

    hits = SP.where(starts_inside | ends_inside | covers_window)[0]
    return list(gene_names[hits]) # return a list of filtered names
    


""" Read QTLs from a tab-delimited file
@param dataset name of the dataset 
@param sample name of sample (e.g. A19)
@param qtl_set name of qtl set
@require qtl_set in ["plus_diploid", "minus_diploid", "plus_minus"]
@return list of cArrayQtl objects, one per QTL line in the file
"""
def read_qtls(sample, qtl_set="plus_minus", dataset="2013-02-22_round1"):
    """ Read the QTL calls of one sample from its tab-delimited QTL file.
    @param sample name of sample (e.g. A19)
    @param qtl_set name of qtl set; selects the <qtl_set>_diff directory
    @param dataset name of the dataset
    @return list of cArrayQtl, one per non-blank line after the header (empty list for an empty or header-only file)
    """
    path = "%s/arrays/qtl/%s/%s_diff/%s_qtls.tab"%(DATA_DIR, dataset, qtl_set, sample)
    with open(path, 'r') as ifh: # closes the handle even if a line fails to parse
        next(ifh, None) # skip the header line; tolerate an empty file
        qtls = [cArrayQtl(l) for l in ifh if l.strip()] # ignore blank lines
    LOG.debug("Read %d QTLs for sample %s (qtl set %s, dataset %s)"%(len(qtls), sample, qtl_set, dataset))
    return qtls


""" Determine whether the QTL is a SGA one or frequent flyer. """
def is_common_qtl(qtl, allowed_delta=150000, return_gene_inclusion=False):
    """ Check whether a QTL peak lies near one of the commonly observed loci.
    @param qtl object with chrm (string), peak_loc and peak_genes attributes
    @param allowed_delta peak must be strictly closer than this to the locus position
    @param return_gene_inclusion if True, for the first locus the QTL is near, return whether the locus gene is in qtl.peak_genes
    @return True if the QTL is near a common locus (or gene inclusion, see above); False if near none
    """
    # (chromosome, position, gene) of the common loci
    commons = [("2", 200000, "HTA2"), ("3", 150000, "MATALPHA1"), ("5", 30000, "CAN1"), ("10", 150000, "URA3")]
    for chrm, position, gene in commons:
        near = (qtl.chrm == chrm) and abs(qtl.peak_loc - position) < allowed_delta
        if not near:
            continue
        # first nearby locus decides the answer
        return (gene in qtl.peak_genes) if return_gene_inclusion else True
    return False


def is_known_qtl(qtl, allowed_delta=150000, return_gene_inclusion=False):
    """ Check whether a QTL peak lies near one of the known loci (HAP1, MKT1, IRA2, GPB1).
    @param qtl object with chrm (string), peak_loc and peak_genes attributes
    @param allowed_delta peak must be strictly closer than this to the locus position
    @param return_gene_inclusion if True, for the first locus the QTL is near, return whether the locus gene is in qtl.peak_genes
    @return True if the QTL is near a known locus (or gene inclusion, see above); False if near none
    """
    knowns = (
        ("12", 648000, "HAP1"),
        ("14", 468005, "MKT1"),
        ("15", 175000, "IRA2"),
        ("15", 1033000, "GPB1"),
    )
    # genes of all loci the peak is near, in table order
    nearby_genes = [gene for (chrm, position, gene) in knowns
                    if (qtl.chrm == chrm) and abs(qtl.peak_loc - position) < allowed_delta]
    if not nearby_genes:
        return False
    if return_gene_inclusion:
        return nearby_genes[0] in qtl.peak_genes # only the first matching locus is considered
    return True
    


""" Determine whether QTL is of self linkage """
def is_self_qtl(qtl, allowed_delta=50000, return_gene_inclusion=False):
    """ Determine whether a QTL is the self-linkage QTL of its sample.
    @param qtl object with sample, chrm (string), peak_loc and peak_genes attributes
    @param allowed_delta peak must be strictly closer than this to the sample's self locus
    @param return_gene_inclusion if True, return whether the sample's self gene is in qtl.peak_genes;
           note that the peak location is NOT checked in this mode (unlike is_common_qtl and is_known_qtl)
    @return bool; always False for a sample without an entry in the self-locus table
    """
    # sample -> (chromosome, position, gene) of its self-linkage locus
    self_qtls = { "A2": ("5", 260500, "YER053C-A"),
                  "A3": ("7", 960000, "YHB1"),
                  "A8": ("10", 339000, "TDH1"),
                  "A14": ("14", 52200, "ZIM17"),
                  "A18": ("1", 58000, "GCV3"),
                  "A19": ("4", 601000, "SED1"),
                  "A24": ("14", 64500, "RPL18B"),
                  "B8": ("13", 639000, "GCV2")}
    if qtl.sample not in self_qtls:
        return False # no self locus known for this sample, so it cannot be a self QTL (used to raise KeyError)
    self_c, self_l, self_gene = self_qtls[qtl.sample]
    if return_gene_inclusion: return (self_gene in qtl.peak_genes)
    return (qtl.chrm == self_c and abs(self_l - qtl.peak_loc) < allowed_delta)


def includes_right_gene(qtl):
    """ Report whether the QTL's gene list contains the gene expected for its category.
    Categories are tried in the order common, self, known; the first one the QTL falls into decides.
    @param qtl QTL object accepted by is_common_qtl, is_self_qtl and is_known_qtl
    @return result of the matching check called with return_gene_inclusion=True, or None if the QTL is in no category
    """
    for classify in (is_common_qtl, is_self_qtl, is_known_qtl):
        if classify(qtl):
            return classify(qtl, return_gene_inclusion=True)
    return None

    

def get_orf_array_qtls(orf):
    """ Collect the (qtl set, sample) pairs whose sample-file rows refer to the given ORF.
    Scans every <DATA_DIR>/arrays/samples/*.tab file; the qtl set name is the file name up to its first ".".
    A row matches if the first space-separated token of its 6th tab-delimited column equals orf; the sample
    is the part of the 4th column before the first "_".
    @param orf ORF name to look for
    @return SP.unique of the collected (qtl_set, sample) tuples
    """
    samples = []
    for f in glob.glob("%s/arrays/samples/*.tab"%DATA_DIR):
        qtl_set = f.split("/")[-1].split(".")[0]
        with open(f, 'r') as ifh: # make sure each file is closed before moving on to the next
            for l in ifh:
                d = l.strip().split("\t")
                if d[5].split(" ")[0] == orf:
                    samples.append((qtl_set, d[3].split("_")[0]))
    # NOTE(review): if SP.unique is numpy's unique, calling it without an axis flattens the pairs into a
    # sorted array of distinct strings (set names and sample names mixed) - confirm this is what callers expect
    return SP.unique(samples)
