#!/usr/bin/env python
##################################################
#  binned_count_weighting.py
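#  Weight each row of a binned matrix by symmetric uniform/linear/laplace
#  kernels (one output column per kernel scale) and save the result.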
##################################################
import numpy as np
import argparse
from scipy.stats import laplace

parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description=(
        "Collapse a binned matrix X (e.g. genes x bins) into one weighted sum per\n"
        "row for each kernel scale given in --weights.\n"
        "\n"
        "Each half of X's bin axis is weighted by a uniform, linear or laplace\n"
        "kernel of the given scale (distance in bins, counted outward from the\n"
        "centre); the two halves mirror each other and the weights are\n"
        "normalised to sum to 1. The saved array has one row per row of X and\n"
        "one column per --weights value."
    ))

##################################################
# required args:
parser.add_argument("-X", "--X",
                    help='required, numpy array', required=True)
# choices= makes argparse reject unknown kernels up front, instead of a
# KeyError later when the kernel function is looked up.
parser.add_argument("-w", '--weight_type',
                    choices=['uniform', 'linear', 'laplace'],
                    help='required, <uniform|linear|laplace>', required=True)
# Implicit string concatenation (not backslash continuation) so that
# RawTextHelpFormatter does not print the source indentation.
parser.add_argument("-r", '--weights', type=int, nargs="+",
                    help='required to parameterize --weight_type, range is 1 to 100 '
                         'for all weight types;\n'
                         'one output column is produced per value, '
                         'e.g. --weights 20 or --weights 5 10 20',
                    required=True)
parser.add_argument("--out",
                    help='required, numpy array to which summarized binned '
                         'epigenetic data should be saved',
                    required=True)

##################################################
# optional args:
parser.add_argument("-m", "--masks", nargs='+', type=str,
                    help='optional, list of numpy arrays of binary masks for predictors,\n'
                         'e.g. DHSs or TFBSs may be used to limit the predictor matrices')
parser.add_argument("--exclude_proximal_bins", type=int,
                    help='optional, exclude the first (--exclude_proximal_bins) X proximal\n'
                         'bins. Useful to consider only distal regulatory regions by excluding, e.g.,\n'
                         'the nearest 3 kb (--exclude_proximal_bins 3)')

args = parser.parse_args()

# A scale below 1 gives an all-zero weight vector (NaN after normalisation),
# and a negative exclusion count would become a negative slice that zeroes
# nearly every weight, so reject both here with a usage error.
if min(args.weights) < 1:
    parser.error('--weights values must be >= 1')
if args.exclude_proximal_bins is not None and args.exclude_proximal_bins < 0:
    parser.error('--exclude_proximal_bins must be >= 0')

##################################################
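# Manual test block: uncomment the assignments below (in place of
# parse_args) to run the script with hard-coded arguments.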

# args = type("args", (object,), {})()
# # args.y = 'TPM.protein_coding.npy'
# args.X = 'DNase.t00.in.protein_coding.filtered_by_expression.1000bp_bins_to_100000bp.npy'

# # args.masks = ["/data/reddylab/projects/GGR/data/expression_prediction/iter0/TSS.flank_windows.intersect.DHSs.npy"]
# args.masks = []
# args.weight_type = 'uniform'        
# args.weights = range(1,101)  
# args.comparison = None
# args.exclude_proximal_bins = 3

##################################################

X = np.load(args.X)

##################################################
# import masks and mask X
# Typical masks will be several numpy binary arrays of dimension (#genes) by (#bins, i.e. 100)

# Value used when no masks are given; not referenced later in the visible script.
overall_mask = np.array([1])

if args.masks:
    # Combine all masks by element-wise product, starting from ones shaped
    # like X. NOTE(review): masks whose shape differs from X rely on numpy
    # broadcasting -- confirm that is intended.
    overall_mask = np.ones_like(X)
    for mask_path in args.masks:
        overall_mask = np.multiply(overall_mask, np.load(mask_path))

    # For binary masks this zeroes X wherever any mask is 0.
    X = np.multiply(X, overall_mask)
    # Fraction of all cells ("sequence space") kept by the combined mask.
    # NOTE(review): not used later in the visible script.
    fraction_seq_space_mask = overall_mask.sum() / float(overall_mask.size)

##################################################
# define routines

def uniform_pdf(x, scale):
    """Box kernel: density 1/(2*scale) for |x| <= scale, otherwise 0.

    x     -- signed distance from the centre
    scale -- half-width of the box
    """
    outside = abs(x) > scale
    return 0 if outside else 1 / (scale * 2.0)

def linear_pdf(x, scale):
    """Triangular kernel: 1 at x = 0, falling linearly to 0 at |x| = scale.

    Values are not normalised to unit area; the weighting loop in this
    script divides the resulting weights by their sum.

    x     -- signed distance from the centre
    scale -- distance at which the kernel reaches 0
    """
    distance = abs(x)
    if distance > scale:
        return 0
    return (scale - distance) / float(scale)

def laplace_pdf(x, scale):
    """Laplace density centred at 0 with the given scale (scipy.stats.laplace).

    x     -- signed distance from the centre
    scale -- Laplace scale parameter b
    """
    density = laplace.pdf(x, scale=scale, loc=0)
    return density

# Map each --weight_type value to the kernel that generates its weights.
pdf_function_dict = {
    'uniform': uniform_pdf,
    'linear': linear_pdf,
    'laplace': laplace_pdf,
}

# The half-kernel is mirrored around the middle of X's bin axis (np.r_ below),
# so that axis must split into two equal halves; otherwise the weight vector
# would not match X's width (and a single column would silently give zeros).
n_bins = X.shape[1]
if n_bins % 2:
    raise ValueError("X must have an even number of bins (columns) to be "
                     "weighted symmetrically; got %d" % n_bins)
# Floor division: '/' yields a float on Python 3 and range() would reject it.
half_width = n_bins // 2
# Loop-invariant: the kernel function does not depend on the scale.
pdf = pdf_function_dict[args.weight_type]
# Distances of one half's bins from the centre, 1 for the innermost bin.
distances = range(1, half_width + 1)

X_weighted_list = []
for scale in args.weights:
    weight_array = np.array([pdf(col, scale) for col in distances])
    if args.exclude_proximal_bins:
        # Zero the innermost --exclude_proximal_bins bins of each half.
        weight_array[0:args.exclude_proximal_bins] = 0

    # Mirror the half-kernel: outermost columns get w_half, the two central
    # columns get w_1.
    weight_array = np.r_[weight_array[::-1], weight_array]
    # Normalise to unit sum. NOTE(review): if every weight is zero (e.g.
    # linear with scale=1, or an exclusion covering a whole half) this
    # division yields NaN for that scale's output column.
    weight_array = weight_array / weight_array.sum()
    X_weighted = np.multiply(X, weight_array)
    X_weighted_list.append(X_weighted.sum(axis=1))

# Stack so the result has one row per row of X and one column per scale.
X_weighted_all = np.array(X_weighted_list)
X_weighted_all = X_weighted_all.T
np.save(args.out, X_weighted_all)