#!/usr/bin/env python
import pandas as pd
import numpy as np
import argparse

parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="""

compare_continuous_test_lists_to_bkg_lists.py

Compare each test list of continuous values against each background list with a
statistical test and write a tab-separated table containing, for every list, its
lower quartile, mean, median and upper quartile, plus one p-value column per
background list (rows of background lists get "NA" in those columns).

Each list file holds one numeric value per line; blank lines are ignored.

""",
)

##################################################
# required args

parser.add_argument("--test_lists", nargs="+", help="required, test lists of values", required=True)
parser.add_argument("--test_names", nargs="+", help="required, names of test lists (same order as --test_lists)", required=True)
parser.add_argument("--bkg_lists", nargs="+", help="required, background lists of values", required=True)
parser.add_argument("--bkg_names", nargs="+", help="required, names of background lists (same order as --bkg_lists)", required=True)

parser.add_argument("--out_df", help="required, output dataframe (tab-separated)", required=True)


##################################################
# optional args

parser.add_argument(
    "--test_type",
    choices=["mann_whitney", "ttest_rel", "ttest_ind", "lachenbruch"],
    default="mann_whitney",
    help="optional, statistical test used to compare each test list to each background list:\n"
    "  mann_whitney  two-sided Mann-Whitney U test (default)\n"
    "  ttest_rel     paired t-test (lists must have equal length)\n"
    "  ttest_ind     independent two-sample t-test\n"
    "  lachenbruch   two-part test for zero-inflated data (Lachenbruch 2002)",
)

##################################################
args = parser.parse_args()

# Names are paired with files positionally; zip() would silently drop unmatched
# entries and duplicate names would silently overwrite each other in the dicts
# built later, so reject both up front.
for _files_opt, _names_opt in (("test_lists", "test_names"), ("bkg_lists", "bkg_names")):
    _files = getattr(args, _files_opt)
    _names = getattr(args, _names_opt)
    if len(_files) != len(_names):
        parser.error(
            "--%s and --%s must have the same number of entries (%d vs %d)"
            % (_files_opt, _names_opt, len(_files), len(_names))
        )
    if len(set(_names)) != len(_names):
        parser.error("--%s contains duplicate names" % _names_opt)

# Import only the scipy test needed for the chosen --test_type.
# The lachenbruch two-part test uses mannwhitneyu for its non-zero component,
# so it needs the same import as the plain Mann-Whitney mode.
if args.test_type in ("mann_whitney", "lachenbruch"):
    from scipy.stats import mannwhitneyu
elif args.test_type == "ttest_rel":
    from scipy.stats import ttest_rel
elif args.test_type == "ttest_ind":
    from scipy.stats import ttest_ind

def lachenbruch(arr1, arr2):
    """Two-part test for zero-inflated data (Lachenbruch 2002, "Analysis of data with excess zeros").

    The statistic combines a test on the proportion of zeros (B^2) with a
    Mann-Whitney test on the positive values (W^2, normal approximation,
    no tie correction).  X^2 = B^2 + W^2 is referred to a chi-squared
    distribution whose degrees of freedom equal the number of components
    that are defined:

    * B^2 is undefined when both samples are all zeros or neither contains
      a zero (pooled zero proportion is 0 or 1).
    * W^2 is undefined when either sample has no positive values.

    Parameters
    ----------
    arr1, arr2 : array-like of float
        The two samples.  Values == 0 form the point mass; values > 0 form the
        continuous part.  Other values (negatives, NaN) are counted in the
        sample size but take part in neither the zero count nor the rank test;
        NOTE(review): the test presumably assumes non-negative data.

    Returns
    -------
    float
        The p-value, or 1.0 when neither component can be computed (e.g. both
        samples consist solely of zeros).

    Raises
    ------
    ValueError
        If either sample is empty.
    """
    # Local import so this function works regardless of --test_type.
    from scipy.stats import chi2, mannwhitneyu

    arr1 = np.asarray(arr1, dtype=float)
    arr2 = np.asarray(arr2, dtype=float)
    n1 = float(len(arr1))
    n2 = float(len(arr2))
    if n1 == 0 or n2 == 0:
        raise ValueError("lachenbruch requires two non-empty samples")

    X2 = 0.0
    df = 0

    # Binomial part: difference in the proportion of zero values.
    m1 = int(np.count_nonzero(arr1 == 0))
    m2 = int(np.count_nonzero(arr2 == 0))
    p1 = m1 / n1
    p2 = m2 / n2
    p_pooled = (m1 + m2) / (n1 + n2)
    # Check definedness BEFORE dividing: p_pooled of 0 or 1 makes the
    # denominator 0.0, which raises ZeroDivisionError on Python floats.
    if 0 < p_pooled < 1:
        B2 = (p1 - p2) ** 2 / ((p_pooled * (1 - p_pooled)) * ((n1 + n2) / (n1 * n2)))
        X2 += B2
        df += 1

    # Continuous part: Mann-Whitney U on the positive values only.
    pos1 = arr1[arr1 > 0]
    pos2 = arr2[arr2 > 0]
    k1 = len(pos1)
    k2 = len(pos2)
    if k1 > 0 and k2 > 0:
        U, _ = mannwhitneyu(pos1, pos2, alternative="two-sided")
        m_U = k1 * k2 / 2.0
        sigma_U = np.sqrt(k1 * k2 * (k1 + k2 + 1) / 12.0)
        # W is squared, so it does not matter which sample's U scipy returns.
        W = (U - m_U) / sigma_U
        X2 += W ** 2
        df += 1

    if df == 0:
        return 1.0
    # sf() keeps precision for tiny p-values, where 1 - cdf() rounds to 0.
    return float(chi2.sf(X2, df=df))

def _load_value_lists(names, paths):
    """Return {name: float array} reading one value per line from each path; blank lines are skipped."""
    loaded = {}
    for label, path in zip(names, paths):
        with open(path, "r") as handle:
            tokens = (raw.strip() for raw in handle)
            loaded[label] = np.array([float(tok) for tok in tokens if tok])
    return loaded


test_lists = _load_value_lists(args.test_names, args.test_lists)
bkg_lists = _load_value_lists(args.bkg_names, args.bkg_lists)

# Rows: test lists first, then background lists, each group in sorted-name order.
row_order = [(name, test_lists[name]) for name in sorted(test_lists)]
row_order += [(name, bkg_lists[name]) for name in sorted(bkg_lists)]

res_df = pd.DataFrame(index=[name for name, _ in row_order])

# Column name -> summary function; values are stored pre-formatted to two decimals.
summary_fns = [
    ("lower_quartile", lambda values: np.percentile(values, 25)),
    ("mean", np.mean),
    ("median", np.median),
    ("upper_quartile", lambda values: np.percentile(values, 75)),
]
for column, summarize in summary_fns:
    res_df[column] = ["%0.2f" % summarize(values) for _, values in row_order]

def _p_value(test_l, bkg_l, test_type):
    """Return the p-value comparing test_l against bkg_l with the named test.

    Raises ValueError for an unrecognized test_type rather than silently
    reusing the p-value from a previous iteration.
    """
    # Local import keeps this independent of which names were imported at the top of the script.
    from scipy.stats import mannwhitneyu, ttest_ind, ttest_rel

    if test_type == "mann_whitney":
        # Explicit two-sided alternative: older scipy defaulted to a one-sided p-value.
        _, p = mannwhitneyu(test_l, bkg_l, alternative="two-sided")
    elif test_type == "ttest_rel":
        _, p = ttest_rel(test_l, bkg_l)
    elif test_type == "ttest_ind":
        _, p = ttest_ind(test_l, bkg_l)
    elif test_type == "lachenbruch":
        p = lachenbruch(test_l, bkg_l)
    else:
        raise ValueError("unknown --test_type %r" % test_type)
    return p


# One p-value column per background list.  Only test-list rows are compared;
# background-list rows are filled with "NA".
for bkg_name in sorted(bkg_lists):
    bkg_l = bkg_lists[bkg_name]
    p_vals = [
        "%0.2e" % _p_value(test_lists[test_name], bkg_l, args.test_type)
        for test_name in sorted(test_lists)
    ]
    res_df['p-value vs. %s' % (bkg_name)] = p_vals + ["NA"] * len(bkg_lists)

res_df.to_csv(args.out_df, sep="\t", index=True)