#! /usr/bin/env python3

import sys
import os
import re
import json
import argparse

import numpy as np
import pandas as pd

from pwas.statistics import GeneTest, BinaryTraitGeneScoreTest, ContinuousTraitGeneScoreTest
from pwas.shared_utils.util import log, get_recognized_files_in_dir, get_chunk_slice, get_parser_file_type, get_parser_directory_type, \
        add_parser_task_arguments, determine_parser_task_details, normalize, is_binary_series, find_linearly_independent_columns, \
        resolve_quasi_complete_separation_by_removing_binary_columns
        
# A warning is logged (by validate_n_samples) when the effective number of samples is smaller than the number of covariates
# multiplied by this ratio.
MIN_SAMPLES_COVARIATES_RATIO_WARNING = 10
        
def determine_covariate_cols(args):
    '''
    Collect the names of the covariate columns requested through the parsed command-line arguments.

    The names given directly (args.covariate_cols) are unified with the names loaded from the optional JSON file
    (args.covariate_cols_json_file), and the number of resulting covariates (or a warning, if there are none) is logged.

    @param args: The parsed arguments; only covariate_cols and covariate_cols_json_file are read.
    @return: A sorted list of the unique covariate column names.
    '''

    requested_names = set(args.covariate_cols)
    json_path = args.covariate_cols_json_file

    # The JSON file is optional; when given, its names are added to those provided directly.
    if json_path is not None:
        with open(json_path, 'r') as json_file:
            requested_names.update(json.load(json_file))

    if not requested_names:
        log('Warning: No covariates were provided! Are you sure about that?!')
    else:
        log('%d covariates will be included.' % len(requested_names))

    # Sorting gives a deterministic column order regardless of the set's iteration order.
    return sorted(requested_names)
    
def determine_score_test_class(phenotype_values, phenotype_col = None):
    '''
    Choose the gene score test class that matches the type of the phenotype (binary or continuous).

    @param phenotype_values: The phenotype values (in this script, a pandas Series selected from the dataset).
    @param phenotype_col: An optional name of the phenotype, used only in the log messages. When not provided, the name of
    the given series is used (for a column selected from a DataFrame this is the column name), so the function no longer
    depends on a global `args` variable that exists only when this file is run as a script.
    @return: BinaryTraitGeneScoreTest if is_binary_series determines the phenotype to be binary, otherwise
    ContinuousTraitGeneScoreTest.
    '''

    if phenotype_col is None:
        phenotype_col = getattr(phenotype_values, 'name', None)

    if is_binary_series(phenotype_values):
        # For a binary phenotype, the sum of the values is reported as the number of cases.
        log('The phenotype (%s) was determined to be binary (%d cases).' % (phenotype_col, phenotype_values.sum()))
        return BinaryTraitGeneScoreTest
    else:
        log('The phenotype (%s) was determined to be continuous.' % phenotype_col)
        return ContinuousTraitGeneScoreTest
    
def validate_n_samples(phenotype_values, n_covariates):
    '''
    Log a warning if there are too few samples relative to the number of covariates.

    For a binary phenotype (as determined by is_binary_series) the numbers of cases and controls are logged, and the
    smaller of the two groups is taken as the effective number of samples (cases, in the case of a tie). Otherwise, all
    the samples are counted. The warning is logged when the effective number of samples is smaller than
    n_covariates * MIN_SAMPLES_COVARIATES_RATIO_WARNING.

    @param phenotype_values: The phenotype values of the samples.
    @param n_covariates: The number of covariates that will be used.
    '''

    if not is_binary_series(phenotype_values):
        effective_count, effective_label = len(phenotype_values), 'samples'
    else:
        n_controls = (phenotype_values == 0).sum()
        n_cases = (phenotype_values == 1).sum()
        log('%d cases and %d controls.' % (n_cases, n_controls))
        # The smaller group is the limiting one.
        effective_count, effective_label = (n_cases, 'cases') if n_cases <= n_controls else (n_controls, 'controls')

    min_recommended_count = n_covariates * MIN_SAMPLES_COVARIATES_RATIO_WARNING

    if effective_count < min_recommended_count:
        log('Warning: the number of %s (%d) is small given the number of covariates (%d).' % (effective_label, effective_count, \
                n_covariates))
                
def handle_covariate_separation(phenotype_values, covariates, remove_multicollinear_covariates, resolve_quasi_complete_covariate_separation):
    '''
    Make sure that the covariates are linearly independent and do not introduce quasi-complete separation.

    Each of the two issues is either resolved automatically (if the relevant flag is set) or reported as an error, in
    which case the process exits with status 1.

    @param phenotype_values: The phenotype values of the samples.
    @param covariates: A DataFrame of the covariates (a row per sample, a column per covariate).
    @param remove_multicollinear_covariates: Whether to keep only the columns reported by
    find_linearly_independent_columns when the covariates are multicollinear (otherwise this is an error).
    @param resolve_quasi_complete_covariate_separation: Whether to accept the covariates and samples removed by
    resolve_quasi_complete_separation_by_removing_binary_columns (otherwise any such removal is an error).
    @return: A tuple of (phenotype_values, covariates), in that order, after the potential removals.
    '''

    n_provided_covariates = len(covariates.columns)
    independent_column_indices = find_linearly_independent_columns(covariates.values)

    # Anything other than the full range of column indices means that some of the columns are redundant.
    if list(independent_column_indices) != list(range(n_provided_covariates)):
        if remove_multicollinear_covariates:
            log('The %d provided covariates are multicollinear. Will retain only %d of them.' % (n_provided_covariates, \
                    len(independent_column_indices)))
            covariates = covariates.iloc[:, independent_column_indices]
        else:
            log(('Error: the %d provided covariates are multicollinear. Either provide only a subset of these covariates that are linearly ' + \
                    'independent, or use the flag --remove-multicollinear-covariates.') % n_provided_covariates)
            sys.exit(1)

    # Captured before the helper is called, so that the error message below reports the covariates that actually caused
    # the separation, rather than only those remaining after the removal.
    n_covariates_before_resolving = len(covariates.columns)

    # Only the 1st, 2nd and 4th returned values are used here. NOTE(review): the helper presumably logs the details of
    # what it removed (the error message below refers to them as "detailed above") - confirm against its implementation.
    covariates, phenotype_values, _, removed_covariate_columns, _ = resolve_quasi_complete_separation_by_removing_binary_columns(covariates, \
            phenotype_values)

    # An explicit length check, as the truth value of the returned collection may be ambiguous (e.g. if it is an array).
    if len(removed_covariate_columns) > 0:
        if resolve_quasi_complete_covariate_separation:
            # covariates.shape is (samples, covariates), hence the reversal.
            log('Remained with %d covariates and %d samples.' % covariates.shape[::-1])
        else:
            log(('Error: the %d covariates introduce quasi-complete separation with respect to the phenotype. To allow automatic resolving ' + \
                    'of this issue (as detailed above), you must use the flag --resolve-quasi-complete-covariate-separation.') % \
                    n_covariates_before_resolving)
            sys.exit(1)

    return phenotype_values, covariates
        
def get_gene_effect_score_files(gene_effect_scores_dir):
    '''
    List the per-gene effect score files in the given directory.

    A file is expected to be named "<gene index>.csv". The parsing of the file names is delegated to
    get_recognized_files_in_dir, which is given a function that converts a file name into its integer gene index.

    @param gene_effect_scores_dir: The directory with the gene effect score CSV files.
    @return: Whatever get_recognized_files_in_dir returns (this script uses it as a sequence of
    (gene index, file name) pairs).
    '''

    gene_file_regex = re.compile(r'^(\d+)\.csv$')

    def _gene_index_from_file_name(file_name):
        found_indices = gene_file_regex.findall(file_name)
        # Unpacking exactly one element raises a ValueError for file names not matching the pattern.
        (index_digits,) = found_indices
        return int(index_digits)

    return get_recognized_files_in_dir(gene_effect_scores_dir, _gene_index_from_file_name)

if __name__ == '__main__':

    # Script entry point: parses the command-line arguments, loads the dataset, prepares the phenotype and the
    # covariates, and then runs the statistical test of each gene assigned to this task, writing a result CSV per gene.

    parser = argparse.ArgumentParser(description = 'Run the PWAS statistical tests per gene (looking for associations between a phenotype to gene ' + \
            'effect scores, while potentially accounting for covariates).')
    parser.add_argument('--dataset-file', dest = 'dataset_file', metavar = '/path/to/dataset.csv/', type = get_parser_file_type(parser, must_exist = True), \
            required = True, help = 'Path to the dataset CSV file of the cohort (a row per sample, with eid, phenotype and covariate columns).')
    parser.add_argument('--gene-effect-scores-dir', dest = 'gene_effect_scores_dir', metavar = '/path/to/gene_effect_scores/', \
            type = get_parser_directory_type(parser), required = True, help = 'The directory with the gene effect scores (with a separate CSV file ' + \
            'per gene).')
    parser.add_argument('--per-gene-pwas-results-dir', dest = 'per_gene_pwas_results_dir', metavar = '/path/to/output_per_gene_pwas_results/', \
            type = get_parser_directory_type(parser), required = True, help = 'The directory in which the results will be written (a CSV file per ' + \
            'gene, with the relevant summary statistics).')
    parser.add_argument('--sample-id-col', dest = 'sample_id_col', metavar = '<COL_NAME>', type = str, required = True, \
            help = 'The name of the column (within the dataset CSV file) listing the identifier of each sample (which is expected to match to the ' + \
            '"sample_id" column in the gene effect score CSV files).')
    parser.add_argument('--phenotype-col', dest = 'phenotype_col', metavar = '<COL_NAME>', type = str, required = True, \
            help = 'The name of the column (within the dataset CSV file) for the phenotype to be tested. The phenotype can be either continuous or ' + \
            'binary (will be automatically detected if the values are only 0s and 1s).')
    # NOTE(review): the metavar below suggests a comma-separated list, but with nargs = '+' the column names are taken as
    # separate (whitespace-separated) command-line arguments, and no splitting by commas is performed in this file.
    parser.add_argument('--covariate-cols', dest = 'covariate_cols', metavar = '<COL_NAME1>,<COL_NAME2>,...', type = str, nargs = '+', default = [], \
            help = 'Columns (within the dataset CSV file) to be considered as covariates. See also --covariate-cols-json-file.')
    parser.add_argument('--covariate-cols-json-file', dest = 'covariate_cols_json_file', metavar = '/path/to/covariate_columns.json', \
            type = get_parser_file_type(parser, must_exist = True), default = None, help = 'An optional JSON file with a list of column names to ' + \
            'consider as covariates. If both --covariate-cols and --covariate-cols-json-file are provided, will take the union of the provided column names.')
    parser.add_argument('--remove-multicollinear-covariates', dest = 'remove_multicollinear_covariates', action = 'store_true', help = 'If the provided ' + \
            'covariates are multicollinear, this will automatically remove some of them to retain independence of the remaining ones.')
    parser.add_argument('--resolve-quasi-complete-covariate-separation', dest = 'resolve_quasi_complete_covariate_separation', action = 'store_true', \
            help = 'If the phenotype is binary and quasi-complete separates the covariates (i.e. if the covariates are multicollinear among cases), ' + \
            'this will automatically remove binary covariates (and their associated samples) to resolve this. For details, see the help of the ' + \
            'function shared_utils.util.resolve_quasi_complete_separation_by_removing_binary_columns.')
    add_parser_task_arguments(parser)
    args = parser.parse_args()
            
    task_index, total_tasks = determine_parser_task_details(args)
    covariate_cols = determine_covariate_cols(args)
        
    # The dataset is indexed by the sample IDs, which are later matched against the "sample_id" column of each gene file.
    dataset = pd.read_csv(args.dataset_file, index_col = args.sample_id_col)
    log('Loaded the dataset (%s): %d samples x %d columns.' % (args.dataset_file, len(dataset), len(dataset.columns)))
    # Only samples with a value for the phenotype and for every covariate are kept.
    dataset.dropna(subset = [args.phenotype_col] + covariate_cols, inplace = True)
    log('Filtered the dataset into %d samples with the phenotype and all covariates.' % len(dataset))
    
    phenotype_values = dataset[args.phenotype_col]
    covariates = dataset[covariate_cols]
    score_test_class = determine_score_test_class(phenotype_values)
    # May remove covariates and samples (or exit with an error), so both objects are replaced by the returned ones.
    phenotype_values, covariates = handle_covariate_separation(phenotype_values, covariates, args.remove_multicollinear_covariates, \
            args.resolve_quasi_complete_covariate_separation)
    covariates = covariates.apply(normalize)
    validate_n_samples(phenotype_values, len(covariates.columns))
    
    if total_tasks != 1:
        log('Running task %d/%d.' % (task_index, total_tasks))
        
    # Each task handles only its own chunk of the genes.
    all_gene_effect_score_files = get_gene_effect_score_files(args.gene_effect_scores_dir)
    n_total_genes = len(all_gene_effect_score_files)
    gene_start, gene_end = get_chunk_slice(n_total_genes, total_tasks, task_index)
    log('Will run on genes %d-%d/%d.' % (gene_start, gene_end, n_total_genes))
    
    # The genes of this task are counted from 1, so that the progress log runs from 1/n to n/n (rather than from 0/n to
    # (n-1)/n, as it would with the default zero-based counting).
    for i, (gene_index, gene_effect_score_file_name) in enumerate(all_gene_effect_score_files[gene_start:gene_end], start = 1):
    
        gene_effect_scores = pd.read_csv(os.path.join(args.gene_effect_scores_dir, gene_effect_score_file_name))
        # A set allows a fast membership test; the dataset's order of the samples is retained.
        gene_sample_ids = set(gene_effect_scores['sample_id'])
        gene_relevant_sample_ids = np.array([sample_id for sample_id in phenotype_values.index if sample_id in gene_sample_ids])
        log('Analyzing gene %d/%d (#%d; %d of %d samples are in the dataset)...' % (i, gene_end - gene_start, gene_index, \
                len(gene_relevant_sample_ids), len(phenotype_values)))
                
        # The phenotype, covariates and effect scores are all selected by the same sample IDs, in the same order.
        gene_relevant_phenotype_values = phenotype_values.loc[gene_relevant_sample_ids]
        gene_relevant_covariates = covariates.loc[gene_relevant_sample_ids]
        gene_relevant_effect_scores = gene_effect_scores.set_index('sample_id').loc[gene_relevant_sample_ids]
                
        gene_results = GeneTest(gene_relevant_phenotype_values, gene_relevant_covariates, gene_relevant_effect_scores['dominant'], \
                gene_relevant_effect_scores['recessive'], score_test_class).run()
        # The results of each gene are written into a file named after the gene's index.
        gene_results.to_csv(os.path.join(args.per_gene_pwas_results_dir, '%d.csv' % gene_index), header = False)

    log('Done.')
