#! /usr/bin/env python
"""
Python modules required:
1> pandas
2> numpy
3> matplotlib
4> argparse
5> seaborn


Input Files:
1> 'IlluminaChrCord.txt' which contains the chromosome coordinates of the genes
2> 'GenesFPKM.table' which contains RNA-seq expression values (FPKM) of the genes

Output File:
1>ExpChrDist.pdf

The output file shows, for contiguous non-overlapping windows along each chromosome,
the percentage of genes that fall into each mean-expression (FPKM) class.
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import argparse
import seaborn as sns
#######################  parse command-line arguments ###########################
parser = argparse.ArgumentParser(description='Maps data distribution on the PacBio chromosome given the illumina gene', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# NOTE(review): --in_fol is not read by the code visible in this script and its
# help text looks copied from another tool; kept for command-line compatibility.
parser.add_argument('-i', '--in_fol', default='./', help='folder that has all pairwise Paml results')
parser.add_argument('-c', '--IlluminaChrCord', default='IlluminaChrCord.txt', help='File that has chrosmosome cordinates of the genes')
# type=int: without it a value given on the command line arrives as a string,
# and the window arithmetic (2*WindowSize, WindowNo*WindowSize) breaks.
parser.add_argument('-w', '--WindowSize', type=int, default=500000, help='Size of the non-ovelapping chromosome windows')
parser.add_argument('-O', '--OutFol', default='./', help='Folder in which ExpChrDist.pdf is written')
parser.add_argument('-E', '--ExpFile', default='GenesFPKM.table', help='file with PS312 RNA seq expression values')
# Options are used below as a plain dict keyed by the long option name.
args = vars(parser.parse_args())
#XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX


# Gene coordinate table: tab-separated, no header; columns 0, 1, 2, 3 and 6 are
# read as gene id, chromosome, start, end and age class.
ChrCorDf = pd.read_table(args['IlluminaChrCord'], sep='\t', header=None, usecols=[0,1,2,3,6], names=['Gene', 'Chr', 'Start', 'end', 'AgeClass'])

# Generate non-overlapping windows: each gene is assigned to the WindowSize-bp
# window containing its midpoint, and WindowNo is that window's start coordinate.
# int() tolerates a window size given as a string; // keeps the arithmetic in
# exact integers on both Python 2 and Python 3.
WindowSize = int(args['WindowSize'])
ChrCorDf['WindowNo'] = ((ChrCorDf.Start + ChrCorDf.end) // (2 * WindowSize) * WindowSize).astype(int)

# Percentage of genes of each age class per (chromosome, window).
AgeClassDistDf = pd.crosstab([ChrCorDf.Chr, ChrCorDf.WindowNo], ChrCorDf.AgeClass)
AgeClassDistDf = AgeClassDistDf.divide(AgeClassDistDf.sum(axis=1) / 100.0, axis=0)
AgeClassDistDf = AgeClassDistDf.reset_index(level=0)
# Only the first six chromosomes, in the order crosstab returns them, are kept.
Chromosomes = AgeClassDistDf.Chr.unique()[0:6]




# From here on, gene coordinates are looked up by PS312 gene id.
ChrCorDf = ChrCorDf.set_index('Gene')

# Expression table (tab-separated, indexed by its 'Gene' column) with the
# expression value of every gene in all samples; missing values count as 0.
ExpDf = pd.read_csv(args['ExpFile'], sep='\t', index_col='Gene').fillna(0)

# Append the per-gene mean over all sample columns as 'S_mean'.
ExpDf = ExpDf.assign(S_mean=ExpDf.mean(axis=1))

########################## Mean Expression Chromosome Map ####################
# Only the per-gene mean expression column ('S_mean') is mapped.
Cols = 'S_mean'

# Expression classes in stacking order (bottom to top of each bar) and the
# colour used for each one.
ClassOrder = ['>= 10', '1 - 10', '0 - 1', '0']
ClassColors = ['#DC90BE', '#4DABB3', '#9EC687', '#D4D8D0']

# Bin each gene's mean FPKM into a class. np.select takes the first condition
# that matches, so the bins are:
#   mean FPKM >= 10      -> '>= 10'
#   1 < mean FPKM < 10   -> '1 - 10'
#   0 < mean FPKM <= 1   -> '0 - 1'
#   mean FPKM <= 0       -> '0'
# Classifying the numeric values directly avoids comparing a half-relabelled
# column of mixed strings and floats, which only worked under Python 2.
ExpValues = ExpDf[Cols]
ExpClassDf = pd.DataFrame(
    {Cols: np.select([(ExpValues >= 10).values,
                      (ExpValues > 1).values,
                      (ExpValues > 0).values],
                     ['>= 10', '1 - 10', '0 - 1'],
                     default='0')},
    index=ExpDf.index)

# Attach each gene's class to its (chromosome, window) through the gene-id index.
# Genes missing from either table get NaN in some column and are dropped by crosstab.
ExpChrDf = pd.concat([ChrCorDf.Chr, ChrCorDf.WindowNo, ExpClassDf[Cols]], axis=1)
ExpClassDistDf = pd.crosstab([ExpChrDf.Chr, ExpChrDf.WindowNo], ExpChrDf[Cols])
# Fix the column set and order so each colour always belongs to the same class,
# even when a class does not occur anywhere.
ExpClassDistDf = ExpClassDistDf.reindex(columns=ClassOrder, fill_value=0)
# Counts -> percentage of genes per window.
ExpClassDistDf = ExpClassDistDf.divide(ExpClassDistDf.sum(axis=1) / 100.0, axis=0)
ExpClassDistDf = ExpClassDistDf.reset_index(level=0)

PdfName = '%s/ExpChrDist.pdf' % args['OutFol']
plt.close('all')
# Apply the seaborn theme once, before any axes are created.
sns.set(context='poster')
plt.rcParams['pdf.fonttype'] = 'truetype'
plt.figure(figsize=(20, 18))
# One stacked-bar panel per chromosome, one bar per window.
for n, Chr in enumerate(Chromosomes, start=1):
    # Progress output: panel number and total number of panels.
    print('%d %d' % (n, len(Chromosomes)))
    ChrDf = ExpClassDistDf[ExpClassDistDf.Chr == Chr].drop('Chr', axis=1)
    ax = plt.subplot(len(Chromosomes), 1, n)
    ChrDf.plot(kind='bar', stacked=True, ax=ax, color=ClassColors, legend=False, ylim=(0, 100), width=0.6)
    ax.set_xlabel('')
    # Hide every tick and tick label; each panel is identified by its y label only.
    ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
    ax.set_ylabel(Chr)
# The legend is attached to the last panel and placed to its right.
plt.legend(title="mean FPKM", loc='center left', bbox_to_anchor=(1, 0.5))
plt.savefig(PdfName)