#!/usr/bin/env python
import pandas as pd
import numpy as np
from collections import Counter
import argparse

# Help texts are kept as column-0 module constants so that the text shown by
# --help (rendered verbatim by RawTextHelpFormatter) is not indented.
_DESCRIPTION = """

In each of a list of --input files, compute the proportion of elements
that are of each of all categorical values that appear in the file.
Create an output dataframe with each column corresponding to a name
in --names

"""

_INPUT_HELP = """required, file paths to lists of categorical variables:

/path/to/list1.txt /path/to/list2.txt
which are of the following headerless format:

Distal Intergenic
Distal Intergenic
Promoter (<=1kb)
Promoter (1-2kb)
Intron
Intron
Promoter (<=1kb)
Promoter (<=1kb)
Promoter (<=1kb)
Distal Intergenic
...

where order does not matter and list can be any arbitrary categorical variable.

"""

_NAMES_HELP = """names assigned to each annotation files

"""


def _build_parser():
    """Return the command-line parser for this script.

    ``-i/--input``, ``--names`` and ``--out_counts`` are mandatory;
    ``--order`` and ``--out_proportions`` are optional.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description=_DESCRIPTION,
    )

    # Temporarily detach argparse's default optional group so that the
    # "required arguments" group is listed first in the --help output; the
    # optional group is re-attached once all arguments are declared.
    optional = parser._action_groups.pop()
    required = parser.add_argument_group('required arguments')

    ##################################################
    # required args:

    required.add_argument("-i", "--input", nargs='+', type=str,
                          help=_INPUT_HELP, required=True)
    # --names labels the output columns and is used unconditionally, so it
    # must be enforced by argparse rather than failing later on None.
    required.add_argument("--names", nargs='+', type=str,
                          help=_NAMES_HELP, required=True)
    required.add_argument("--out_counts",
                          help="required, path to variable counts",
                          required=True)

    ##################################################
    # optional args:

    optional.add_argument("--order", nargs='+', type=str,
                          help="Specifies the desired order of the variable values")
    optional.add_argument("--out_proportions",
                          help="optional, path to variable proportions")

    ##################################################
    parser._action_groups.append(optional)
    return parser


def _read_counts(path):
    """Return a Counter of the categorical values listed in ``path``.

    Each line is one value with surrounding whitespace stripped. Lines that
    are empty after stripping are skipped so that blank lines do not show up
    as an empty-string category.
    """
    with open(path, "r") as handle:
        stripped = (line.strip() for line in handle)
        return Counter(value for value in stripped if value)


def _build_counts_table(counters, names, order=None):
    """Build the category-by-file count table.

    Parameters
    ----------
    counters : list of Counter
        One Counter per input file, in the same order as ``names``.
    names : list of str
        Column labels, one per Counter.
    order : list of str, optional
        Row labels (categories) to report, in output order. When omitted or
        empty, all categories seen in any Counter are used, sorted.

    Returns
    -------
    pandas.DataFrame
        Rows are categories, columns are ``names``; a category that a file
        does not contain gets a count of 0.
    """
    if not order:
        order = sorted(set().union(*counters))
    # Counter returns 0 for keys it has never seen, so no membership test
    # is needed for categories missing from a file.
    rows = [[counter[category] for counter in counters] for category in order]
    return pd.DataFrame(rows, index=order, columns=names)


def main(argv=None):
    """Run the script: count categories per file and write the tables.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments; defaults to ``sys.argv[1:]`` when None.

    The counts table is written tab-separated to ``--out_counts``. When
    ``--out_proportions`` is given, each column is divided by its own total
    and written there as well; with ``--order`` the total only covers the
    categories listed in ``--order``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Every input file needs exactly one column name; otherwise inputs would
    # be dropped silently or the lookup by position would fail.
    if len(args.names) != len(args.input):
        parser.error(
            "--names expects one name per --input file "
            "(got {} names for {} files)".format(len(args.names), len(args.input))
        )

    # read categorical variable
    counters = [_read_counts(path) for path in args.input]

    table = _build_counts_table(counters, args.names, args.order)
    table.to_csv(args.out_counts, sep="\t", index=True)

    if args.out_proportions:
        proportions = table / table.sum()
        proportions.to_csv(args.out_proportions, sep="\t", index=True)


if __name__ == "__main__":
    main()
