#! /usr/bin/env python3

import os
import argparse

import pandas as pd

from pwas.shared_utils.util import get_parser_file_type, get_parser_directory_type, multipletests_with_nulls

def read_gene_results(gene_results_file_path, gene_index):
    """Read one per-gene results CSV file as a Series named after the gene.

    The file is read without a header row; its first column becomes the index
    of the returned Series and its single remaining column provides the values.

    Args:
        gene_results_file_path: Path of the per-gene CSV file.
        gene_index: Index of the gene; used as the name of the returned Series.

    Returns:
        A pandas Series named `gene_index`.
    """
    gene_results = pd.read_csv(gene_results_file_path, header = None, index_col = 0)
    # The `squeeze` keyword of read_csv was deprecated in pandas 1.4 and removed in 2.0. Squeezing only the
    # column axis is the equivalent, and still yields a Series (not a scalar) for a single-row file.
    return gene_results.squeeze('columns').rename(gene_index)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description = 'Combines the per-gene PWAS results (obtained by the pwas_test_genes command), runs FDR, and sorts ' +
            'the results.')
    parser.add_argument('--genes-file', dest = 'genes_file', metavar = '/path/to/gene_dataset.csv', type = get_parser_file_type(parser, must_exist = True),
            required = True, help = 'Specification of all the relevant genes. Each row in the CSV file is a gene. The first column is expected to be ' +
            'the index of each gene (which is used for the names of the per-gene CSV files in the PWAS results directory).')
    parser.add_argument('--per-gene-pwas-results-dir', dest = 'per_gene_pwas_results_dir', metavar = '/path/to/per_gene_pwas_results/',
            type = get_parser_directory_type(parser), required = True, help = 'The directory with the per-gene results (a CSV file per gene, with ' +
            'the relevant summary statistics).')
    parser.add_argument('--results-file', dest = 'results_file', metavar = '/path/to/pwas_results.csv', type = get_parser_file_type(parser),
            required = True, help = 'The CSV file to write the combined PWAS results into.')
    args = parser.parse_args()

    genes = pd.read_csv(args.genes_file, index_col = 0)
    per_gene_results = []

    # Genes without a results file are skipped, so partial results can still be combined.
    for gene_index in genes.index:

        gene_results_file_path = os.path.join(args.per_gene_pwas_results_dir, '%d.csv' % gene_index)

        if os.path.exists(gene_results_file_path):
            per_gene_results.append(read_gene_results(gene_results_file_path, gene_index))

    # Each collected Series becomes one row, labeled by its gene index.
    per_gene_results = pd.DataFrame(per_gene_results)
    print('Collected PWAS results for %d of %d genes.' % (len(per_gene_results), len(genes)))

    # Column-wise concatenation aligns on the gene index; genes without results get null statistics.
    results = pd.concat([genes, per_gene_results], axis = 1)
    results.index.name = 'gene_index'

    if 'combined_pval' not in results.columns:
        # Without this guard the sort below fails with a bare KeyError (e.g. when no per-gene file was found).
        raise SystemExit('No combined_pval values were found in the per-gene results under %s; nothing to combine.' % \
                args.per_gene_pwas_results_dir)

    results.sort_values('combined_pval', inplace = True)
    # NOTE(review): multipletests_with_nulls is presumably a null-tolerant multiple-testing correction that
    # returns (significance flags, q-values) - confirm against pwas.shared_utils.util.
    results['fdr_significance'], results['fdr_qval'] = multipletests_with_nulls(results['combined_pval'], method = 'fdr_bh')
    print('There are %d significant genes (after FDR).' % results['fdr_significance'].sum())

    results.to_csv(args.results_file)