import pylab as PL
import pdb
import glob
from parts2014_gfpvar.tools.common import *
from parts2014_gfpvar.analysis.array.util import *
from parts2014_gfpvar.analysis.array.infer import *


""" Plot per-chromosome raw signal for GFP-high (plus) and GFP-low (minus) samples, and the difference between the two """
def plot_array_data(dataset="2013-02-22_round1", sample_pattern="*_plus", plot_window=570000, save=True, threshold=0.5, chrs=None, cols="rb"):

    # for each sample matching the given pattern
    for f in glob.glob("%s/arrays/condensed/%s/*%s*.tab"%(DATA_DIR, dataset, sample_pattern)):
        s1 = f.split("/")[-1][0:-4] # get the + and - samples
        samples = [s1, s1.replace("plus", "minus")]
        data = [read_array_file(sample, dataset=dataset) for sample in samples] # read data for the samples                                                         
        if chrs is None: chrs = sorted(data[0][0]) # for each chromosome,
            
        for chrm in chrs:
            if save and os.path.exists("%s/arrays/plots/%s/combined_%s_chr%s_%d.png"%(DATA_DIR, dataset, samples[0].split("_")[0], chrm, 0)): # skip already processed plots
                continue
            LOG.debug("Plotting chromosome %s of sample pair %s,%s in %d kb windows"%(chrm, samples[0], samples[1], plot_window/1000))
            locs = data[0][0][chrm]
            m1, sd1, z1 = infer_mean(locs, data[0][1][chrm], threshold=threshold, return_all=True) # mu is length-L, sd is float, z is length-Lx3
            m2, sd2, z2 = infer_mean(locs, data[1][1][chrm], threshold=threshold, return_all=True)
            mu, sd, z = [m1,m2], [sd1,sd2], [z1,z2]
            
            for plot_i in range(1 + int(max(locs/plot_window))): # for each plot required to be made. This splits up per-chromosome plots into chunks to avoid huge (>600kb) ones.
                PL.figure(figsize=(17,9))
                x0, x1 = plot_window*plot_i-20000, plot_window*(plot_i+1) + 20000 # chunks of window size kb
                PL.subplot(211)
                plot_sample_signal(locs, data, mu, z, x0, x1, threshold, chrm, cols=cols) # plot raw signal for samples
                PL.subplot(212)
                plot_signal_diff(locs, data, mu, sd, x0, x1, chrm) # and the difference
                PL.title("Sample %s; Chromosome %s; %d-%dkb; top=BY-like"%(s1,chrm, plot_i*plot_window/1000, (plot_i+1)*plot_window/1000))

                if save:
                    os.system("mkdir -p %s/arrays/plots/%s"%(DATA_DIR, dataset))
                    PL.savefig("%s/arrays/plots/%s/combined_%s_chr%s_%d.png"%(DATA_DIR, dataset, samples[0].split("_")[0], chrm, plot_i), dpi=300)
                else:
                    PL.show()

""" Plot raw data as markers and average as a solid line """
def plot_sample_signal(locs, data, mu, z, x0, x1, threshold, chrm, cols="rb"):
    PL.plot([min(locs), max(locs)], [0,0], 'k--', linewidth=2, alpha=0.5)
    alphas = [0.1,0.2,0.4]
    markersizes = [12,12,15]
    for sample in range(2):
        for cluster in range(3):
            Ics = SP.where(z[sample][:,cluster] > threshold)[0]
            PL.plot(locs[Ics], data[sample][1][chrm][Ics], cols[sample] + ".", markersize=markersizes[cluster], alpha=alphas[cluster])
            PL.plot(locs, mu[sample], cols[sample] + "-", linewidth=3)
    PL.xlim(x0, x1)
    PL.ylim(-1,1)
    if sum((abs(data[0][1][chrm]) > 0.5) + (abs(data[1][1][chrm]) > 0.5)) < 20:
        PL.ylim(-0.5,0.5)
    PL.ylabel("log(BY signal) - log(RM signal)")


""" Plot difference in means, and confidence interval in it """
def plot_signal_diff(locs, data, mu, sds, x0, x1, chrm):
    m, sd = mu[0] - mu[1], (sds[0] + sds[1])**0.5
    PL.plot(locs, m, 'g-', linewidth=4)
    x_ci = SP.array(list(locs) + list(locs)[::-1])
    y_ci = SP.array(list(m) + list(m)[::-1])
    sds = SP.array([sd]*len(m) + [-sd]*len(m))
    PL.fill(x_ci, y_ci + sds, 'g', alpha=0.2)
    PL.fill(x_ci, y_ci + 2*sds, 'g', alpha=0.2) 
    PL.plot([min(locs), max(locs)],[0,0], 'k-', linewidth=3, alpha=0.2)
    PL.ylabel("signal(high GFP) - signal(low GFP)")
    PL.xlim(x0, x1)
    PL.ylim(-1,1)
    if sum((abs(data[0][1][chrm]) > 0.5) + (abs(data[1][1][chrm]) > 0.5)) < 20:
        PL.ylim(-0.5,0.5)
        

""" Plot the signal comparison across two replicate datasets
@param sample sample ID (e.g. A14_plus)
@param datasets list of two datasets to compare (e.g. ["2013-03-07_round1-hyb3-1","2013-03-07_round1-hyb3-2"])
@require len(datasets) >= 2
@param chrm chromosome to plot e.g. "2"
@param lim y-limit
@param alpha parameter for plotting, lower=more transparent"""
def plot_comparison(sample, datasets, chrm, lim=(-0.5,0.5), alpha=0.3):
    data = [read_array_file(sample,datasets[i])[1] for i in range(len(datasets))] # read data
    
    PL.figure(figsize=(10,10)) # plot
    PL.plot(data[0][chrm], data[1][chrm], ".", markersize=15, alpha=alpha)
    x0,x1,y0,y1 = -0.1, 0.1, -0.1, 0.1
    PL.plot([x0,x1,x1,x0,x0], [y0,y0,y1,y1,y0], 'r-')
    
    I = SP.where((abs(data[0][chrm]) < 0.1) & (abs(data[1][chrm]) < 0.1))[0] # title and label
    sd = SP.mean([data[i][chrm][I].var() for i in range(len(data))])**0.5
    PL.title("%s - chr %s; r^2=%.2f; noise sd=%.3f"%(sample, chrm, SP.corrcoef(data[0][chrm], data[1][chrm])[0,1], sd))
    PL.xlim(lim), PL.ylim(lim)
    PL.xlabel("%s"%(datasets[0]))
    PL.ylabel("%s"%(datasets[1]))
    PL.show()

