from parts2014_gfpvar.tools.io import *
import glob
import pdb

# Cytometry convenience functions:

""" Take filename (e.g /usr/.../Specimen_001_E7_E07.fcs), get 'Plate-P11C_H7' """
def get_well96_from_filename(f):
    """Translate a cytometry file path into its 96-well plate/well label.

    The plate name is taken from the file's parent directory and the
    384-format well from the second-to-last "_"-separated token of the file
    name.  Examples:

        ".../P11/Specimen_001_P13_P13.fcs"   -> "Plate-P11C_H7"
        ".../P10AB/Specimen_001_E7_E07.fcs"  -> "Plate-P10A_C4"

    f -- "/"-separated path to the .fcs file.

    Returns a string "Plate-<plate><sub>_<row><col>".
    """
    plate = f.split("/")[-2]  # e.g. P10AB
    # A directory name containing "AB" or "CD" (and not "Candidate") carries
    # two sub-plate letters as its last two characters.
    subs_exist = ("AB" in plate or "CD" in plate) and "Candidate" not in plate

    well = f.split("/")[-1].split("_")[-2]  # e.g. N9
    row = ord(well[0]) - ord('A')  # 0-based
    col = int(well[1:])  # 1-based

    if subs_exist:
        plateno = plate[0:-2]
        # Only the row parity picks the sub-plate; two adjacent columns map
        # to the same 96-format column.
        platesub = plate[-2:][row % 2]
    else:
        plateno = plate
        # Row parity and column parity together pick one of the four sub-plates.
        platesub = "ABCD"[2 * (row % 2) + ((col + 1) % 2)]

    # Floor division keeps these integers on Python 2 and Python 3 alike
    # (true division would hand chr() a float).
    row96 = chr(ord('A') + row // 2)
    col96 = (col + 1) // 2
    return "Plate-%s%s_%s%d" % (plateno, platesub, row96, col96)


""" Plate-P11C_H7 -> ('P11', 'P', 13) (mixed_sample);  ('P11CD', 'O', 13) otherwise """
def get_screen_well(plate_well, mixed_sample=True):
    """Convert a 96-well label into (plate name, row letter, column) of the screen plate.

    "Plate-P11C_H7" -> ('P11', 'P', 13) when mixed_sample is true, and
    ('P11CD', 'O', 13) otherwise.

    plate_well   -- label of the form "Plate-<plate><sub>_<row><col>".
    mixed_sample -- if true, the sub-plate letter shifts both row (C/D) and
                    column (B/D); if false, the plate name gets an "AB"/"CD"
                    suffix, only the row is shifted (B/D) and the first of
                    the two columns is returned.
    """
    plate = plate_well.split("-")[1].split("_")[0]  # e.g. P11C
    lower_half = plate.count("C") + plate.count("D")
    second_of_pair = plate.count("B") + plate.count("D")

    if mixed_sample:
        platename = plate[0:-1]
        row_shift, col_shift = lower_half, second_of_pair
    else:
        platename = plate[0:-1] + ["AB", "CD"][lower_half]
        row_shift, col_shift = second_of_pair, 0

    well = plate_well.split("_")[-1]  # e.g. H7
    row = chr(ord('A') + 2 * (ord(well[0]) - ord('A')) + row_shift)
    col = 2 * int(well[1:]) - 1 + col_shift
    return platename, row, col


""" 'haploid-R1' + "Plate-P11C_H7" -> ../haploid-R1/P11CD/Specimen_001_O13_O13.fcs
'haploid-R2' + "Plate-P11C_H7" -> ../haploid-R2/P11/Specimen_001_P13_P13.fcs
"""
def get_plate96str_filenames(set, platestr96, return_both=True, external_hd=False):
    """Build the .fcs file path(s) that hold the data of one 96-well label.

    'haploid-R1' + "Plate-P11C_H7" -> .../haploid-R1/P11CD/Specimen_001_O13_O13.fcs
    'haploid-R2' + "Plate-P11C_H7" -> .../haploid-R2/P11/Specimen_001_P13_P13.fcs

    set         -- name of the screen directory (the parameter name shadows the
                   builtin but is kept for callers passing it by keyword).
    platestr96  -- label of the form "Plate-P<number><sub>_<row><col>".
    return_both -- only used for "haploid-R1": if true, return the file for
                   the computed column and for the next one; if false, only
                   the first of the pair.
    external_hd -- if true, use the /Volumes/LaCie tree instead of DATA_DIR.

    Returns a list of path strings (nothing is checked against the disk).
    """
    if external_hd:
        dir_base = "/Volumes/LaCie/Pilot_screen_BYxRM/%s"%(set)
    else:
        dir_base = "%s/cytometry/Pilot_screen_BYxRM/%s"%(DATA_DIR, set)
    # Second column field is zero-padded to two digits, e.g. O7_O07.
    file_base = "Specimen_001_%s%d_%s%02d.fcs"

    is_r1 = (set == "haploid-R1")
    plate96 = platestr96.split("-")[1].split("_")[0][1:]  # e.g. 11C
    plate = plate96[0:-1]
    if is_r1:
        # R1 directories are named after the AB / CD plate halves.
        plate = plate + ["AB", "CD"][plate96.count("C") + plate96.count("D")]

    _, row, col = get_screen_well(platestr96, mixed_sample=not is_r1)
    cols = [col]
    if is_r1:
        # Floor division keeps column numbers integral on Python 2 and 3.
        if return_both:
            cols.append(2*(col//2) + 1 + (col%2))  # the adjacent column
        else:
            cols = [2*((col-1)//2) + 1]  # just the first of the two

    return ["%s/P%s/%s"%(dir_base, plate, file_base%(row, c, row, c)) for c in cols]
    

""" Plate-P11C_H7 -> 11, 16, 13 """     
def get_array_well(plate_well):
    """Convert a 96-well label into 1-based (plate, row, column) indexes.

    "Plate-P11C_H7" -> (11, 16, 13).  Sub-plates C and D add one to the row,
    sub-plates B and D add one to the column.
    """
    plate_token = plate_well.split("-")[1].split("_")[0]  # e.g. P11C
    plate_number = int(plate_token[1:-1])
    quadrant = plate_token[-1]
    well_token = plate_well.split("_")[-1]  # e.g. H7

    row_shift = 1 if quadrant in "CD" else 0
    col_shift = 1 if quadrant in "BD" else 0

    row = 2 * (ord(well_token[0]) - ord('A')) + row_shift + 1
    col = 2 * int(well_token[1:]) - 1 + col_shift
    return plate_number, row, col
    

""" 11, 16, 13 -> 11, C, H, 7  . Note all indexes 1-based """     
def get_96well(plate384, row384, col384):
    """Convert 1-based (plate, row, column) indexes into a 96-well position.

    (11, 16, 13) -> (11, 'C', 'H', 7).

    Returns (plate384, sub-plate letter, row letter, 1-based column).
    """
    # Row parity selects the A/B versus C/D half, column parity the letter in it.
    plate_sub = "ABCD"[2*((row384+1)%2) + (col384+1)%2]
    # Floor division: true division would pass a float to chr() and return a
    # float column on Python 3.
    row96 = chr(ord('A') + (row384-1)//2)
    col96 = (col384+1)//2
    return plate384, plate_sub, row96, col96



def get_feature_from_96wellstr(wells, feature=0, skip_nan=True): # 0 = SGD name, 1=common name, 2=gfp level, 3=localisation
    """Look up one annotation field for each 96-well label.

    wells    -- iterable of labels such as "Plate-P11C_H7".
    feature  -- index into the last axis of the array data
                (0 = SGD name, 1 = common name, 2 = gfp level, 3 = localisation).
    skip_nan -- passed through to get_array_data.

    Returns a list with one entry per well, in input order.
    """
    array_data = get_array_data(skip_nan=skip_nan)

    def lookup(well):
        plate, row, col = get_array_well(well)
        return array_data[plate, row, col, feature]

    return [lookup(well) for well in wells]



def get_gfps(wells, log=True, return_namehash=False):
    """Return the GFP level recorded in the array data for each 96-well label.

    wells           -- sequence of labels such as "Plate-P11C_H7".
    log             -- if true, return log2 of the levels.
    return_namehash -- if true, also return a dict mapping each well label to
                       the space-joined first three annotation fields.

    Returns an array of levels, or (array, dict) when return_namehash is set.
    """
    array_data = get_array_data()
    gfps = -SP.ones(len(wells))
    names = {}

    for idx, well in enumerate(wells):
        plate, row, col = get_array_well(well)
        level = float(array_data[plate, row, col, 2])
        gfps[idx] = SP.log2(level) if log else level
        names[well] = " ".join(map(str, array_data[plate, row, col, 0:3]))

    if return_namehash:
        return gfps, names
    return gfps


def get_gene_locations():
    """Return a dict mapping gene name to its (plate, row, column) position.

    E.g. "YCR001" -> (2, 12, 1); 384 format, all indexes 1-based.  Index 0 of
    each of the first three axes of the array data is never visited.  If a
    name occurs more than once, the last position visited wins.
    """
    array_data = get_array_data(skip_nan=False)
    shape = array_data.shape
    positions = ((plate, row, col)
                 for plate in range(1, shape[0])
                 for row in range(1, shape[1])
                 for col in range(1, shape[2]))
    return dict((array_data[plate, row, col][0], (plate, row, col))
                for plate, row, col in positions)


def get_glucose_smith_targets(largeonly=False, all_sep_distant=False, all_peaks=False, all_trans=False):
    """Select ORFs from ref/trans_peaks.txt according to the requested criteria.

    Only the first row per ORF with a LOD score above 6 is considered; once an
    ORF has been stored, later rows for it are ignored.

    largeonly       -- keep ORFs in a glucose peak ("glu<N>") with N >= 9.
    all_peaks       -- keep ORFs in any glucose peak.
    all_sep_distant -- keep ORFs whose peak field is "not in distant peak".
    all_trans       -- keep ORFs whose peak field (after its first three
                       characters) does not contain "local"; these are stored
                       as (lod, peak) unless one of the other criteria also
                       matches the row, in which case the plain LOD replaces it.

    Returns a dict ORF -> LOD score (or -> (LOD, peak name), see all_trans).
    """
    d = SP.loadtxt("%s/ref/trans_peaks.txt"%DATA_DIR, delimiter="\t", dtype=object, skiprows=1)
    res = {}
    for record in d:
        orf, peak = record[0], record[2]
        if orf in res:
            continue
        lod = float(record[1])
        if not lod > 6:  # LOD score not large enough
            continue

        if all_trans and peak[3:].count("local") == 0:
            res[orf] = lod, peak
        if peak[0:3] == "glu":  # glucose peak
            if largeonly:
                # The original wrote int(peak[3:] > 8), which compares a
                # string with an int and so never filtered on the peak number.
                try:
                    peak_number = int(peak[3:])
                except ValueError:
                    peak_number = None  # no numeric suffix: not a numbered peak
                if peak_number is not None and peak_number > 8:  # peak 9+
                    res[orf] = lod
            if all_peaks:  # any glucose peak
                res[orf] = lod
        if all_sep_distant and (peak == "not in distant peak"):
            res[orf] = lod
    return res



def get_glucose_smith_target_orfs(peak="HAP1", lod_cutoff=4.):
    """List the ORFs assigned to one named peak in ref/trans_peaks.txt.

    peak       -- one of "HAP1", "IRA2", "MKT1", "AMN1", "IRA1"; any other
                  value raises KeyError before the table is read.
    lod_cutoff -- only rows with a LOD score strictly above this are kept.

    Returns the matching ORF names in file order.
    """
    glucose_peak_by_gene = {"HAP1":"glu9", "IRA2":"glu12", "MKT1":"glu11", "AMN1":"glu3", "IRA1":"glu2"}
    peak_name = glucose_peak_by_gene[peak]
    table = SP.loadtxt("%s/ref/trans_peaks.txt"%DATA_DIR, delimiter="\t", dtype=object, skiprows=1)
    return [orf
            for orf, lod, orf_peak_name in table
            if (orf_peak_name == peak_name) and (float(lod) > lod_cutoff)]


def get_glucose_smith_qtls_peakavoid(avoid_peaks=["HAP1", "MKT1", "HAP1"], lod_cutoff=4.): 
    """Collect linkages per ORF, leaving out those that fall in the given peaks.

    avoid_peaks -- names from {"HAP1", "IRA2", "MKT1"} whose peaks are skipped;
                   any other name raises KeyError.  The list is only read.
                   NOTE(review): the default lists "HAP1" twice - possibly one
                   was meant to be "IRA2"; confirm with the author.
    lod_cutoff  -- linkages with a LOD score below this are dropped.

    Returns a dict ORF -> list of [chromosome, marker, LOD] read from
    mrna/smith_linkage_results_thinned_all.tab.
    """
    peak_names = {"HAP1":"glu9", "IRA2":"glu12", "MKT1":"glu11"}
    avoid_names = [peak_names[p] for p in avoid_peaks]
    table = SP.loadtxt("%s/mrna/smith_linkage_results_thinned_all.tab"%DATA_DIR, delimiter="\t", dtype=object, skiprows=1)

    res = {}
    for orf, linkage_type, chrm, lod, peak_name, marker, ef1, ef2 in table:
        wanted = (peak_name not in avoid_names) and not (float(lod) < lod_cutoff)
        if wanted:
            res.setdefault(orf, []).append([int(chrm), int(marker), float(lod)])
    return res



def print_candidates():
    """Print the first annotation field of every well listed in the sorting file.

    Reads integer (plate, row, column) triples from the hard-coded round-1
    sorting file (one header line is skipped) and prints, one per line, entry 0
    of the array data at that position.
    """
    d = get_array_data(skip_nan=False)
    vals = SP.loadtxt("/Users/leopold/doc/projects/gfpvar/2012-05-28_sorting_round1.txt", dtype=int, skiprows=1)
    for entry in vals:
        plate, row, col = entry[0], entry[1], entry[2]
        # Single-argument print(...) produces the same output on Python 2 and 3.
        print(d[plate][row][col][0])


def print_sorted(return_values=False):
    """Print the array annotation of every well listed in tmp_sorted_genes.txt.

    Each non-blank line of the hard-coded input file is whitespace-separated:
      - three fields: label, sub-plate and well, e.g. "x 11C A2";
      - otherwise:    label followed by integer plate, row and column.

    For each line a tab-separated record (label, plate, row, column,
    annotation) is printed.

    return_values -- if true, return the list of
                     [label, plate, row, col, annotation]; otherwise None.
    """
    genes = get_array_data(skip_nan=False)
    values = []

    # open() in a with-block works on Python 2 and 3 and closes the handle,
    # which the former file(...) call never did.
    with open("/Users/leopold/doc/projects/gfpvar/tmp_sorted_genes.txt", "r") as ifh:
        for l in ifh:
            d = l.strip().split()
            if not d:
                continue  # blank line: nothing to look up
            if len(d) == 3: # plateX, plate rowcol, e.g 11C A2
                p,r,c = get_array_well("Plate-P%s_%s"%(d[1],d[2]))
            else: # index, plate, row, col, e.g. 11 1 2
                p,r,c = map(int, d[1:])
            # One formatted string gives the same bytes as the old
            # `print a, b`: Python 2 adds no space after an item ending in a tab.
            print("%s\t%d\t%d\t%d\t%s"%(d[0], p,r,c, genes[p,r,c]))
            values.append([d[0], p,r,c, genes[p,r,c]])
    if return_values:
        return values
        


def get_set_overlap_index(sets):
    """For the items shared by all collections, give their index in each one.

    sets -- iterable of iterables of hashable items.

    Returns one list per input collection.  Position k of every list refers to
    the same shared item, so result[i][k] is the index of the first occurrence
    of that item in the i-th collection.  The order of the shared items is
    arbitrary.  An empty input yields an empty list.
    """
    # A real list (not a lazy map object) so it can be indexed on Python 3 too.
    sets = [list(s) for s in sets]
    if not sets:
        return []

    overlap = set(sets[0])
    for s in sets[1:]:
        overlap &= set(s)  # keep only items present in every collection
    overlap = list(overlap)

    result = []  # indexes of the shared items in the individual collections
    for s in sets:
        # First-occurrence position of every item, built in one pass, instead
        # of one list.index scan per shared item.
        first_index = {}
        for position, item in enumerate(s):
            first_index.setdefault(item, position)
        result.append([first_index[item] for item in overlap])
    return result



def read_halflife():
    """Read the half-life table into a dict ORF -> corrected half-life.

    Lines of ref/halflife_pnas_0605420103.txt without a "Y" are skipped.  The
    first four tab-separated fields are read as ORF, common name, raw value
    and corrected value.  A corrected value of exactly "300" is stored as NaN
    (presumably a cap / missing marker in the table - confirm against the
    data source).
    """
    res = {}
    # open() in a with-block runs on Python 2 and 3 and closes the handle,
    # which the former file(...) call never did.
    with open("%s/ref/halflife_pnas_0605420103.txt"%DATA_DIR, 'r') as ifh:
        for l in ifh:
            if "Y" not in l:
                continue
            orf, _common, _raw, corrected = l.strip("\n").strip("\r").split("\t")[0:4]
            if corrected == "300":
                corrected = SP.nan
            res[orf] = float(corrected)
    return res
