#!/usr/bin/env python
import pandas as pd
import matplotlib.pyplot as plt
import argparse
import json
import numpy as np

parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
description="""

Use logFC and logCPM dataframes to plot an MA plot. Circle genes/sites of
interest, if desired.

""")

##################################################
# Inputs: either the three separate tables (-m, -a, -f) or a single
# --edgeR_results_file; -m/-a/-f are validated after parsing because their
# requirement depends on whether --edgeR_results_file was supplied.
parser.add_argument("-m", "--M",
                    help="file path to logFC dataframe (required unless --edgeR_results_file is given)")
parser.add_argument("-a", "--A",
                    help="file path to logCPM dataframe (required unless --edgeR_results_file is given)")
parser.add_argument("-f", "--FDR",
                    help="file path to FDR dataframe (required unless --edgeR_results_file is given)")
parser.add_argument("-o", "--outplot", required=True,
                    help="required, file path to write the MA plot to")

###################################################

parser.add_argument("--gene_ids", nargs="+", help="optional, circle the points corresponding to --gene_ids")
parser.add_argument("--edgeR_results_file", help="optional, instead of providing -m -a and -f, provide edgeR results file")

###################################################

args = parser.parse_args()

# Fail fast with a usage message instead of letting pd.read_csv(None) crash
# later with an unrelated-looking error.
if not args.edgeR_results_file:
    missing = [flag for flag, value in (("-m/--M", args.M),
                                        ("-a/--A", args.A),
                                        ("-f/--FDR", args.FDR))
               if not value]
    if missing:
        parser.error("the following arguments are required unless "
                     "--edgeR_results_file is given: " + ", ".join(missing))

# Load the three statistics used by the MA plot: logFC (M), logCPM (A) and FDR.
# They come either from one edgeR results table (Series, one per column) or
# from three separate tab-separated tables indexed by gene/site ID.
if args.edgeR_results_file:
    edgeR_results_file = pd.read_csv(args.edgeR_results_file, sep="\t", index_col=0)
    logFC = edgeR_results_file["logFC"]
    logCPM = edgeR_results_file["logCPM"]
    FDR = edgeR_results_file["FDR"]
else:
    logFC = pd.read_csv(args.M, sep="\t", index_col=0)
    logCPM = pd.read_csv(args.A, sep="\t", index_col=0)
    FDR = pd.read_csv(args.FDR, sep="\t", index_col=0)
    # The plot pairs values positionally, so all three tables must list the
    # same IDs in the same order. An explicit check is used instead of
    # `assert` (stripped under `python -O`), and Index.equals instead of
    # elementwise `==` (which raises an opaque error on length mismatch).
    for table_name, table in (("logCPM", logCPM), ("FDR", FDR)):
        if not logFC.index.equals(table.index):
            raise ValueError(
                "logFC and {} tables do not share the same index "
                "(same IDs in the same order)".format(table_name))

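# NOTE: optional, currently disabled step that relabels the Gencode v22 gene
# IDs with gene names, using a site-specific JSON mapping file. If re-enabled,
# relabel logCPM (and FDR) as well: the --gene_ids lookups below index logCPM
# and logFC by the same label, so relabelling logFC alone would break them.
# with open('/data/reddylab/Reference_Data/Gencode/v22/gencode.v22.gene_id_to_gene_name.json', 'r') as f:
#     gene_id_to_gene_name = json.load(f)

# logFC.index = [gene_id_to_gene_name[x] for x in list(logFC.index)]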

# Draw the MA plot: A (logCPM) on the x axis, M (logFC) on the y axis,
# coloured by FDR, then save it to --outplot.
fig, ax = plt.subplots(figsize=(3, 3))

# Convert once; works for one-column DataFrames (-m/-a/-f inputs) and for
# Series (--edgeR_results_file input) alike.
a_values = np.array(logCPM)
m_values = np.array(logFC)
fdr_values = np.array(FDR)

# FDR bins, drawn from least to most significant so significant points are
# plotted on top. Each entry is (lower, upper, colour, alpha): a point falls in
# the bin when lower < FDR <= upper, with None meaning "unbounded".
# Points with NaN FDR fail every comparison and are not drawn.
fdr_bins = (
    (0.2, None, "black", 0.1),
    (0.1, 0.2, "#3F0000", None),
    (0.05, 0.1, "#7F0000", None),
    (0.01, 0.05, "#BF0000", None),
    (None, 0.01, "#FF0000", None),
)
for lower, upper, colour, alpha in fdr_bins:
    in_bin = np.ones(fdr_values.shape, dtype=bool)
    if lower is not None:
        in_bin &= fdr_values > lower
    if upper is not None:
        in_bin &= fdr_values <= upper
    ax.scatter(a_values[in_bin],
               m_values[in_bin],
               rasterized=True, s=1, alpha=alpha,
               color=colour)

if args.gene_ids:
    for gene_id in args.gene_ids:
        # .loc replaces .ix, which was removed in pandas 1.0.
        ax.scatter(np.array(logCPM.loc[gene_id]),
                   np.array(logFC.loc[gene_id]),
                   s=75, marker="o", facecolors="none",
                   color="green")

# Integer-rounded limits covering all data; nanmin/nanmax so a single NaN
# statistic does not produce NaN axis limits (which matplotlib rejects).
ax.set_xlim((np.floor(np.nanmin(a_values)),
             np.ceil(np.nanmax(a_values))))
ax.set_ylim((np.floor(np.nanmin(m_values)),
             np.ceil(np.nanmax(m_values))))

ax.axhline(0, c="blue")
ax.set_ylabel('M')
ax.set_xlabel('A')

fig.tight_layout()
fig.savefig(args.outplot)