#!/usr/bin/env python
import pandas as pd
import argparse

_DESCRIPTION = """

This script computes a correlation matrix of all pairwise correlation coefficients
from a tab-delimited dataframe. If input is N x P, correlation
matrix is P x P by default, unless --transpose indicated, in which case output
becomes N x N.

"""


def _build_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser with the required ``--df`` / ``--corr``
        arguments and the optional ``--standardize``, ``--norm``,
        ``--no_header`` and ``--transpose`` flags.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description=_DESCRIPTION,
    )

    # argparse prints its built-in optional group before any custom group.
    # Temporarily popping it and re-appending it after the "required" group
    # makes --help list the required arguments first.  This relies on the
    # private ``_action_groups`` attribute; it is kept so the help layout
    # stays the same.
    optional = parser._action_groups.pop()
    required = parser.add_argument_group('required arguments')

    ##################################################

    required.add_argument("--df", help="required, path to input dataframe", required=True)
    required.add_argument("--corr", help="required, path to output correlation matrix", required=True)

    ##################################################
    # optional args:

    optional.add_argument("--standardize", help="optional, if indicated, standardize each column (feature)", action='store_true')
    optional.add_argument("--norm", help="optional, if indicated, normalize all columns to sum to same value", action='store_true')
    optional.add_argument("--no_header", help="optional, if indicated, dataframe does not have header", action='store_true')
    optional.add_argument("--transpose", help="""optional, if indicated, transpose compute correlations
across observations instead of across features""", action='store_true')

    ##################################################
    parser._action_groups.append(optional)
    return parser


def _standardize(df):
    """Return ``df`` with each column shifted to mean 0 and scaled by its std dev."""
    return (df - df.mean(axis=0)) / df.std(axis=0)


def _normalize_column_sums(df):
    """Return ``df`` with every column rescaled to the smallest column sum.

    Each column is multiplied by ``min(column sums) / (its own sum)``.
    A column whose sum is zero therefore yields inf/NaN values, exactly as a
    plain division would; no special handling is applied.
    """
    # numpy is only needed for this optional step, so it is imported lazily.
    import numpy as np

    col_sums = df.sum()
    # compute normalization factors
    norm_factors = (np.array(col_sums) / float(col_sums.min())) ** (-1)
    # normalize; the factors broadcast across rows, one factor per column
    return pd.DataFrame(
        np.array(df) * norm_factors,
        index=list(df.index),
        columns=list(df.columns),
    )


def main(argv=None):
    """Read the input table, compute the correlation matrix and write it out.

    Args:
        argv: list of command-line tokens to parse.  ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    The input is read as tab-separated with its first column used as the row
    index.  The output is tab-separated as well; with ``--no_header`` it is
    written without row labels and without a header line.
    """
    args = _build_parser().parse_args(argv)

    header = None if args.no_header else 'infer'
    df = pd.read_csv(args.df, sep='\t', index_col=0, header=header)

    if args.transpose:
        # DataFrame.corr() correlates columns, so transposing first gives
        # correlations across observations (rows of the input).
        df = df.transpose()

    if args.standardize:
        # subtract mean and divide by std dev.
        df = _standardize(df)

    if args.norm:
        df = _normalize_column_sums(df)

    # compute correlation
    corr = df.corr()

    if args.no_header:
        # Input had no labels, so emit a bare matrix without labels either.
        corr.to_csv(args.corr, sep='\t', index=False, header=False)
    else:
        corr.to_csv(args.corr, sep='\t')


if __name__ == "__main__":
    main()