#! /usr/bin/env python3

import os
import re
import argparse

import numpy as np
import pandas as pd

from pwas.shared_utils.util import log, get_recognized_files_in_dir, get_chunk_slice, get_parser_file_type, get_parser_directory_type, \
        add_parser_task_arguments, determine_parser_task_details
from pwas import GenotypingManager
from pwas.gene_effect_score_aggregation import calc_gene_dominant_effect_scores, calc_gene_recessive_effect_scores

np.seterr(divide = 'ignore')

def get_genes_variant_files(gene_variants_dir):
    '''
    Collects the per-gene variant files found in a directory.

    A file name is parsed successfully only when it has the form "<digits>.csv", in which case the digits are read as the integer gene index.
    Every other file name makes the parser raise a ValueError. How such failures are treated, and the exact structure and order of the
    returned collection, are determined by get_recognized_files_in_dir (not visible here).

    gene_variants_dir: path of the directory holding the per-gene CSV files.
    '''

    gene_file_name_regex = re.compile(r'^(\d+)\.csv$')

    def extract_gene_index(file_name):
        # Single-element unpacking: anything other than exactly one captured group raises a ValueError.
        [digits] = gene_file_name_regex.findall(file_name)
        return int(digits)

    return get_recognized_files_in_dir(gene_variants_dir, extract_gene_index)

if __name__ == '__main__':

    # Script entry point: for every gene assigned to this task, read the gene's variants, fetch their genotypes from the relevant genotyping
    # source, aggregate them into per-sample dominant and recessive gene effect scores, and write the result to "<gene index>.csv" in the
    # output directory.

    parser = argparse.ArgumentParser(description = 'Calculates the gene effect scores across a cohort of samples, which is a core component of PWAS.')
    parser.add_argument('--genotyping-spec-file', dest = 'genotyping_spec_file', metavar = '/path/to/genotyping_spec.csv', \
            type = get_parser_file_type(parser, must_exist = True), required = True, help = 'Path to a CSV file specifying the genotyping sources ' + \
            '(which are referenced in the CSV files of the variants). The file is expected to have the following two headers: "name" (the name of ' + \
            'each source) and "format" (its format; currently supporting "plink" and "bgen"). Rows of PLINK format are expected to have ' + \
            '"bed_file_path", "bim_file_path" and "fam_file_path" headers. Rows of BGEN format are expected to have "bgen_file_path", ' + \
            '"bgi_file_path" and "sample_file_path".')
    parser.add_argument('--gene-variants-dir', dest = 'gene_variants_dir', metavar = '/path/to/variants_per_gene/', \
            type = get_parser_directory_type(parser), required = True, help = 'A directory with per-gene variants and effect scores (the variants ' + \
            'of each gene in a separate CSV file). Each CSV file is expected to have the following headers: "genotype_source_index" (the index of ' + \
            'the source within --genotyping-spec-file that contains the genotypes of this gene), "genotype_source_variant_index" (the index of ' + \
            'the variant within that source), "effect_score" (the effect score of the variant on the gene, expected to be a number between 0 to ' + \
            '1, not including 1) and, optionally, an "is_allele1_ref" column (see --is-allele1-ref-col).')
    parser.add_argument('--gene-effect-scores-dir', dest = 'gene_effect_scores_dir', metavar = '/path/to/output_gene_effect_scores/', \
            type = get_parser_directory_type(parser), required = True, help = 'The directory in which the gene effect scores will be written (a ' + \
            'separate CSV file for each gene).')
    parser.add_argument('--is-allele1-ref-col', dest = 'is_allele1_ref_col', metavar = '<is_allele1_ref_col>', type = str, default = None, \
            help = 'An optional boolean column in the variant CSV files storing whether the first allele (true) or the second allele (false) of each ' + \
            'variant is the reference allele. If not provided, the first allele will always be assumed to be the reference allele.')
    parser.add_argument('--D-u', dest = 'D_u', metavar = '[number between 0 to 1]', type = float, default = 1.0, help = 'The μ parameter for the ' + \
            'dominant model (default 1.0).')
    parser.add_argument('--D-p', dest = 'D_p', metavar = '[number between 1 to inf]', type = float, default = 1.25, help = 'The p parameter for ' + \
            'the dominant model (default 1.25).') 
    parser.add_argument('--R-u', dest = 'R_u', metavar = '[number between 0 to 1]', type = float, default = 0.5, help = 'The μ parameter for the ' + \
            'recessive model (default 0.5).')
    parser.add_argument('--R-p', dest = 'R_p', metavar = '[number between 1 to inf]', type = float, default = np.inf, help = 'The p parameter for ' + \
            'the recessive model (default inf).')
    parser.add_argument('--R-q', dest = 'R_q', metavar = '[number between 1 to inf]', type = float, default = 3.0, help = 'The q parameter for ' + \
            'the recessive model (default 3.0).') 
    parser.add_argument('--dtype', dest = 'dtype', metavar = '<numpy dtype>', type = str, default = 'float32', help = 'NumPy\'s dtype for the ' + \
            'output gene effect scores (default float32).')
    parser.add_argument('--temp-dir', dest = 'temp_dir', metavar = '/path/to/tmp/', type = get_parser_directory_type(parser), default = '/tmp', \
            help = 'A directory in which to create temporary links, if using the PLINK format (default /tmp).')
    # Adds the task-related arguments that determine_parser_task_details reads below (defined in pwas.shared_utils.util).
    add_parser_task_arguments(parser)
    args = parser.parse_args()
        
    task_index, total_tasks = determine_parser_task_details(args)
    
    if total_tasks != 1:
        log('Running task %d/%d.' % (task_index, total_tasks))
    
    # This task only processes the [gene_start, gene_end) slice of the recognized gene files.
    all_gene_variant_files = get_genes_variant_files(args.gene_variants_dir)
    n_total_genes = len(all_gene_variant_files)
    gene_start, gene_end = get_chunk_slice(n_total_genes, total_tasks, task_index)
    log('Will run on genes %d-%d/%d.' % (gene_start, gene_end, n_total_genes))
        
    genotyping_spec_df = pd.read_csv(args.genotyping_spec_file)
    genotyping_manager = GenotypingManager(genotyping_spec_df, args.temp_dir)
    
    # Count from 1 so that the progress message runs from "1/N" to "N/N" (rather than "0/N" to "N-1/N").
    for i, (gene_index, gene_variants_file_name) in enumerate(all_gene_variant_files[gene_start:gene_end], start = 1):
    
        gene_variants_file_path = os.path.join(args.gene_variants_dir, gene_variants_file_name)
        gene_variants = pd.read_csv(gene_variants_file_path)
        log('Analyzing gene %d/%d (#%d; %d variants)...' % (i, gene_end - gene_start, gene_index, len(gene_variants)))
        
        # All the variants of a gene must come from one genotyping source, as a single reader is used for the whole gene. A file with no
        # variants, or with variants from several sources, is reported explicitly instead of failing on a bare tuple unpacking.
        genotype_source_indices = gene_variants['genotype_source_index'].unique()
        
        if len(genotype_source_indices) != 1:
            raise ValueError('Expected the variants in %s to reference exactly one genotyping source, but found %d distinct ' \
                    '"genotype_source_index" value(s): [%s].' % (gene_variants_file_path, len(genotype_source_indices), \
                    ', '.join(map(str, genotype_source_indices))))
        
        genotype_source_index, = genotype_source_indices
        genotyping_reader = genotyping_manager.get_genotyping_reader(genotype_source_index)
        sample_ids = np.array(genotyping_reader.get_sample_ids())
        
        # One entry per variant, stacked along a new leading axis (so axis 0 indexes the variants of the gene).
        variant_genotypes = np.stack([genotyping_reader.get_variant_probs(genotype_source_variant_index) for genotype_source_variant_index in \
                gene_variants['genotype_source_variant_index']], axis = 0)
        effect_scores = gene_variants['effect_score'].values
        # Without an explicit column, the first allele of every variant is taken to be the reference allele.
        allele1_refs = np.ones(len(gene_variants), dtype = bool) if args.is_allele1_ref_col is None else gene_variants[args.is_allele1_ref_col].values
        
        pd.DataFrame({
            'sample_id': sample_ids,
            'dominant': calc_gene_dominant_effect_scores(variant_genotypes, effect_scores, allele1_refs, args.D_u, args.D_p, dtype = args.dtype),
            'recessive': calc_gene_recessive_effect_scores(variant_genotypes, effect_scores, allele1_refs, args.R_u, args.R_p, args.R_q, dtype = args.dtype),
        }).to_csv(os.path.join(args.gene_effect_scores_dir, '%d.csv' % gene_index), index = False)

    log('Done.')