#!/usr/bin/env python
##################################################
#  expression_correlation.py
#
#  @ /data/reddylab/projects/GGR/analyses/group_general/top_level_scripts/simple_linear_regression_expression_prediction.py
#
##################################################
import numpy as np
import pandas as pd
import argparse
from collections import defaultdict
from GGR import utils
import statsmodels.api as sm
from scipy.stats import laplace,pearsonr,norm

# Command-line interface.
#
# Changes relative to the previous version:
#   * the description was truncated mid-sentence; it now states what the
#     script does,
#   * --weight_type is validated by argparse, so a typo is rejected at once
#     instead of raising a KeyError after all input arrays have been loaded,
#   * the -o help text described "saved OLS models"; the script writes a
#     tab-delimited results table to that path.
# --comparison is deliberately left unrestricted because any unrecognised
# value means "use the matrices as is".
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="""
expression_correlation.py takes a response variable (-y) and one or more
predictor matrices (-Xs), one per timepoint and listed in time-course order.
For each scale in --weights the predictor bins are collapsed into a weighted
sum, and the Pearson correlation (with a 95 percent confidence interval) or
the mutual information between that signal and the response is reported.
""")

##################################################
# required args:
parser.add_argument("-y", "--input_response", type=str,
                    help='required, numpy array of response variable',
                    dest="y", action="store", required=True)
parser.add_argument("-Xs", "--input_predictors", type=str, nargs='+',
                    help='required, list of numpy arrays in natural order of time course',
                    dest="Xs", required=True)
parser.add_argument("-w", '--weight_type', dest='weight_type', type=str,
                    choices=['uniform', 'linear', 'laplace'],
                    help='required, <uniform|linear|laplace>',
                    action='store', required=True)
parser.add_argument("-r", '--weights', dest='weights', nargs='+', type=int,
                    help='required to parameterize --weight_type, range is 1 to 100 for all weight types \
                    e.g. --weights $(echo {1..100})',
                    required=True)
# NOTE: the default is never used because the option is required; both are
# kept so the accepted command lines stay exactly as before.
parser.add_argument("-o", "--out", dest='out', type=str, default='results.txt', action="store",
                    help='required, path of the tab-delimited results table to write', required=True)

##################################################
# optional args:

parser.add_argument("-m", "--masks", dest='masks', nargs='+', type=str, default=None,
                    help='optional, list of numpy array of binary masks for predictors,\
                    e.g. DHSs or TFBSs may be used to limit the predictor matrices')
parser.add_argument('--comparison',
                    help='if <log2|log2FC_t00|x_absmax> , log2 of X and y are taken after X is masked \
                    and summed by weights, if not indicated then matrices are taken as is',
                    dest='comparison', action='store', type=str)
parser.add_argument("--WAD",
                    help='if --comparison is log2FC_t00 then log2FC can be weighted \
                    using --WAD, which stands for weighted average difference wherein \
                    logFC is weighted by average log signal intensity as in PMID: 18578891',
                    dest='WAD', action='store_true')
parser.add_argument("--logFC_y",
                    help='if --logFC_y, then y is already in logFC form',
                    dest='logFC_y', action='store_true')
parser.add_argument("--mutual_information",
                    help='if --mutual_information, then compute MI instead of correlation',
                    action='store_true')
parser.add_argument("--mutual_information_CI",
                    help="""if --mutual_information_CI, then compute MI confidence intervals
                    using bootstrap resampling""",
                    action='store_true')

args = parser.parse_args()

##################################################
# test block  

# args = type("args", (object,), {})()
# args.y = '/data/reddylab/projects/GGR/results/expression_prediction/iter0/expression/edgeR.minmax_logFC.protein_coding.npy'
# args.Xs = ['../binned_counts/EP300.in.edgeR.genes_tested.TSS.protein_coding.1000bp_bins_to_100000bp.t00.npy',
#           '../binned_counts/EP300.in.edgeR.genes_tested.TSS.protein_coding.1000bp_bins_to_100000bp.t05.npy',
#           '../binned_counts/EP300.in.edgeR.genes_tested.TSS.protein_coding.1000bp_bins_to_100000bp.t1.npy',
#           '../binned_counts/EP300.in.edgeR.genes_tested.TSS.protein_coding.1000bp_bins_to_100000bp.t2.npy',
#           '../binned_counts/EP300.in.edgeR.genes_tested.TSS.protein_coding.1000bp_bins_to_100000bp.t3.npy',
#           '../binned_counts/EP300.in.edgeR.genes_tested.TSS.protein_coding.1000bp_bins_to_100000bp.t4.npy',
#           '../binned_counts/EP300.in.edgeR.genes_tested.TSS.protein_coding.1000bp_bins_to_100000bp.t5.npy',
#           '../binned_counts/EP300.in.edgeR.genes_tested.TSS.protein_coding.1000bp_bins_to_100000bp.t6.npy',
#           '../binned_counts/EP300.in.edgeR.genes_tested.TSS.protein_coding.1000bp_bins_to_100000bp.t7.npy',
#           '../binned_counts/EP300.in.edgeR.genes_tested.TSS.protein_coding.1000bp_bins_to_100000bp.t8.npy',
#           '../binned_counts/EP300.in.edgeR.genes_tested.TSS.protein_coding.1000bp_bins_to_100000bp.t10.npy',
#           '../binned_counts/EP300.in.edgeR.genes_tested.TSS.protein_coding.1000bp_bins_to_100000bp.t12.npy']
# args.masks = None
# args.weight_type = 'uniform'        
# args.weights = range(1,100)  
# args.comparison = 'x_absmax'
# # args.out = 'EP300.expression.logFC.correlation.uniform.log2FC_t00.txt'
# args.correlation = True
# args.WAD=False
# args.logFC_y=True

##################################################
# import X and y

# Response variable. A 2-D y laid out as
#   y_1,1 y_1,2 .. y_1,t
#   ...
#   y_n,1 y_n,2 .. y_n,t
# is unrolled column by column ('F' order) and stacked into a single column:
#   [[y_1,1], [y_2,1], ..., [y_n,t]]
y = np.vstack(np.load(args.y).flatten('F'))

# One predictor array per timepoint; the files are expected on the command
# line in the natural order of the time course.
Xs = np.array([])
Xs_list = list(map(np.load, args.Xs))

# Stack the timepoints on top of each other, earliest first, into one long array.
X = np.concatenate(Xs_list, axis=0)

# X is of dimension (#genes * number of timepoints) by (#bins, i.e. 100)

##################################################
# import masks and mask X
# Typical masks will be several numpy binary
# arrays of dimension  (#genes) by (#bins, i.e. 100)

# overall_mask = np.array([1])

if args.masks:
    # Begin with an all-ones mask shaped like a single timepoint's predictor
    # array, then fold in every mask file by elementwise product: an entry
    # stays non-zero only if all masks keep it (intersection, not union).
    overall_mask = np.ones_like(Xs_list[0])
    for mask_path in args.masks:
        overall_mask = overall_mask * np.load(mask_path)

    # The same mask applies to every timepoint, so repeat it down the rows
    # once per predictor file before multiplying it into the stacked X.
    n_mask_repeats = len(args.Xs)
    X = X * np.tile(overall_mask, (n_mask_repeats, 1))

    # How much of sequence space is covered by the mask: the sum of the mask
    # entries divided by the number of entries.
    mask_rows = overall_mask.shape[0]
    mask_cols = overall_mask.shape[1]
    fraction_seq_space_mask = overall_mask.sum() / float(mask_rows * mask_cols)


##################################################
# define routines

def uniform_pdf(x, scale):
    """Flat weight: 1 / (2 * scale) inside the window |x| <= scale, 0 outside."""
    outside_window = abs(x) > scale
    return 0 if outside_window else 1 / (2.0 * scale)

def linear_pdf(x, scale):
    """Triangular weight: 1 at x == 0, falling linearly to 0 at |x| == scale.

    Anything farther than scale from the focus gets weight 0.
    """
    distance = abs(x)
    if distance > scale:
        return 0
    return (scale - distance) / float(scale)

def laplace_pdf(x, scale):
    """Double-exponential weight: the Laplace density centred on 0 with the given scale."""
    centred_laplace = laplace(loc=0, scale=scale)
    return centred_laplace.pdf(x)

# Maps each --weight_type value to the function that produces its weights.
pdf_function_dict = {
    'uniform': uniform_pdf,
    'linear': linear_pdf,
    'laplace': laplace_pdf,
}

# if args.out_prefix:
#     args.out_prefix = args.out_prefix + '/' if args.out_prefix[-1] != '/' else args.out_prefix

##################################################

def compute_MI_from_c_xy(c_xy):
    """Return the mutual information score for a precomputed contingency table.

    c_xy is handed to sklearn's mutual_info_score as its `contingency`
    argument; no label vectors are supplied.
    """
    # Imported here so that sklearn is only needed when MI is requested.
    from sklearn.metrics import mutual_info_score
    return mutual_info_score(None, None, contingency=c_xy)

def compute_MI(x,y):
    """Return the mutual information of x and y estimated from a 100 x 100 joint histogram."""
    # histogram2d returns (counts, x_edges, y_edges); only the counts are used.
    joint_counts, _x_edges, _y_edges = np.histogram2d(x, y, bins=100)
    return compute_MI_from_c_xy(joint_counts)


##################################################
# for each weight scale compute the Pearson correlation (or mutual information) between X and y

def _log2_plus_one(values):
    """Return log2(values + 1), with NaN/inf entries replaced by np.nan_to_num."""
    return np.nan_to_num(np.log2(values + 1))


def _log2FC_from_t00(log2_values, n_rows, n_timepoints):
    """Return the log2 fold change of every later timepoint relative to t00.

    log2_values is a 1-D array made of n_timepoints consecutive blocks of
    n_rows entries each (the first block is t00).  The t00 block is
    subtracted from each of the remaining n_timepoints - 1 blocks.
    """
    t00 = log2_values[:n_rows]
    return log2_values[n_rows:] - np.tile(t00, n_timepoints - 1)


def _WAD_weights(log2_values, n_rows, n_timepoints):
    """Return per-row WAD weights: the min-max scaled mean log2 signal of each row.

    log2_values has the same block layout as in _log2FC_from_t00.  The
    reshape uses order='F' so that column j holds timepoint block j and each
    row therefore collects one row's values across all timepoints.
    """
    row_means = log2_values.reshape((n_rows, n_timepoints), order='F').mean(axis=1)
    return (row_means - row_means.min()) / (row_means.max() - row_means.min())


# Per-scale results; every key collects one value per entry of --weights.
results_dict = defaultdict(list)
n_timepoints = len(args.Xs)

for scale in args.weights:

    print('Weight type = %s, scale = %s, transforming X and y...' % (args.weight_type, scale))

    # Build one weight per column of X: evaluate the chosen pdf on half of the
    # columns, mirror it around the centre and normalise so the weights sum
    # to 1.  Floor division keeps the column count an integer on Python 2 and 3.
    pdf = pdf_function_dict[args.weight_type]
    half_weights = np.array([pdf(col, scale) for col in range(X.shape[1] // 2)])
    weight_array = np.concatenate((half_weights[::-1], half_weights))
    weight_array = weight_array / weight_array.sum()

    # Collapse every row of X to a single weighted sum.
    X_weighted_summed = np.multiply(X, weight_array).sum(axis=1)

    if args.comparison == 'log2':

        X_for_corr = _log2_plus_one(X_weighted_summed)
        y_for_corr = np.log2(y.flatten() + 1)

    elif args.comparison == 'log2FC_t00':

        # how many rows are there per timepoint? (generally = # genes)
        n_rows = X_weighted_summed.shape[0] // n_timepoints
        X_log2 = _log2_plus_one(X_weighted_summed)
        X_for_corr = _log2FC_from_t00(X_log2, n_rows, n_timepoints)

        if args.logFC_y:
            # y is already a log fold change; use it as given.
            y_for_corr = y.flatten()
        else:
            # y_log2 used to be referenced without ever being assigned, which
            # raised a NameError; derive it like the 'log2' comparison does.
            y_log2 = np.log2(y.flatten() + 1)
            y_for_corr = _log2FC_from_t00(y_log2, n_rows, n_timepoints)

        if args.WAD:
            # weight log2FC from t00 by WAD weights (equation in PMID: 18578891)
            X_WAD_weights = _WAD_weights(X_log2, n_rows, n_timepoints)
            X_for_corr = X_for_corr * np.tile(X_WAD_weights, n_timepoints - 1)
            if not args.logFC_y:
                # Only computed when y was given as raw signal; a y supplied
                # as logFC is never WAD-weighted.
                y_WAD_weights = _WAD_weights(y_log2, n_rows, n_timepoints)
                y_for_corr = y_for_corr * np.tile(y_WAD_weights, n_timepoints - 1)

    elif args.comparison == 'x_absmax':

        # how many rows are there per timepoint? (generally = # genes)
        n_rows = X_weighted_summed.shape[0] // n_timepoints
        X_log2 = _log2_plus_one(X_weighted_summed)
        # one row per gene, one column per timepoint after t00
        X_log2FC = _log2FC_from_t00(X_log2, n_rows, n_timepoints)
        X_log2FC = X_log2FC.reshape(n_rows, n_timepoints - 1, order='F')

        # keep, for every row, the signed fold change with the largest
        # absolute value (argmax of abs, then pick that entry)
        max_indices = np.abs(X_log2FC).argmax(axis=1)
        X_for_corr = X_log2FC[np.arange(n_rows), max_indices]
        # for x_absmax, y is already assumed to be in logFC format
        y_for_corr = y.flatten()

    else:
        X_for_corr = X_weighted_summed.copy()
        # y is a column vector at this point; pearsonr and histogram2d need
        # 1-D input, so flatten it like the other branches do.
        y_for_corr = y.flatten()

    if args.mutual_information:
        print('Weight type = %s, scale = %s, computing MI...' % (args.weight_type, scale))
        mi = compute_MI(X_for_corr, y_for_corr)
        print('MI = %0.3e' % (mi))
        results_dict['MI'].append(mi)
        if args.mutual_information_CI:
            # NOTE(review): this bootstrap resamples rows of the 100 x 100
            # contingency table (i.e. X bins), not individual observations.
            c_xy = np.histogram2d(X_for_corr, y_for_corr, bins=100)[0]
            nperms = 1000
            mi_perms = []
            for perm in range(nperms):
                idx_perm = np.random.choice(len(c_xy), size=len(c_xy), replace=True)
                mi_perms.append(compute_MI_from_c_xy(c_xy[idx_perm, :]))

            # 2.5th and 97.5th percentile of the bootstrap distribution
            mi_sorted = sorted(mi_perms)
            MI_low = mi_sorted[int(len(mi_sorted) * 0.025)]
            MI_high = mi_sorted[int(len(mi_sorted) * 0.975)]
            results_dict['MI_low'].append(MI_low)
            results_dict['MI_high'].append(MI_high)
    else:
        print('Weight type = %s, scale = %s, computing correlation and CI of R...' % (args.weight_type, scale))
        r, _p_value = pearsonr(X_for_corr, y_for_corr)
        # Fisher's r-to-z transform: build a 95% interval on the z scale,
        # then map it back to the r scale with tanh.
        z = np.arctanh(r)
        sigma = 1.0 / ((len(X_for_corr) - 3) ** 0.5)
        cint = z + np.array([-1, 1]) * sigma * norm.ppf((1 + 0.95) / 2)
        r_low, r_high = np.tanh(cint)
        print('r = %0.3e, r low = %0.3e, r high = %0.3e' % (r, r_low, r_high))
        results_dict['r'].append(r)
        results_dict['r_low'].append(r_low)
        results_dict['r_high'].append(r_high)

    results_dict['weight_type'].append(args.weight_type)
    results_dict['scale'].append(scale)
    results_dict['WAD'].append(1 if args.WAD else 0)

    if args.masks:
        results_dict['fraction_of_sequence_space'].append(fraction_seq_space_mask)

# Assemble the per-scale results and write them as a tab-delimited table.
results = pd.DataFrame(results_dict)

output_cols = ['weight_type', 'scale', 'WAD']
if args.mutual_information:
    output_cols += ['MI']
    # The MI confidence bounds are only computed when MI itself is requested;
    # asking for them on their own used to select columns that do not exist
    # and raised a KeyError after all the work had been done.
    if args.mutual_information_CI:
        output_cols += ['MI_low', 'MI_high']
else:
    output_cols += ['r', 'r_low', 'r_high']

# Present only when masks were supplied.
if 'fraction_of_sequence_space' in results_dict:
    output_cols += ['fraction_of_sequence_space']

results[output_cols].to_csv(args.out, sep='\t', index=False, header=True)

# if args.WAD:
#     est.save('%s_%s_%s_%s_WAD_OLS_model.p'%(args.out_prefix, args.weight_type, args.comparison, scale))
# else:
#     est.save('%s_%s_%s_%s_OLS_model.p'%(args.out_prefix, args.weight_type, args.comparison, scale))
