#! /usr/bin/env python3

import os
import json
from collections import defaultdict
import argparse

import numpy as np
import pandas as pd

from pwas.shared_utils.util import log, get_parser_file_type, get_parser_directory_type

def read_variants_effects(effects_file):
    """
    Reads the per-variant gene effects from a JSON-lines file.

    The 0-based line number of each line is used as the variant index. Lines that are blank or that consist of exactly
    '{}' (after stripping surrounding whitespace) are skipped, but they still count towards the line numbering, so the
    indices of the remaining lines are unaffected.

    Returns a dict mapping each retained variant index to the object parsed from its line, in file order. The rest of
    this script treats each parsed object as a dict mapping a gene index to a list of (effect description, effect
    score) pairs.
    """
    variants_effects = {}

    with open(effects_file, 'r') as f:
        for variant_index, line in enumerate(f):
            line = line.strip()
            # Blank lines (e.g. a trailing empty line) carry no effects; skipping them avoids a JSON decoding error.
            if line and line != '{}':
                variants_effects[variant_index] = json.loads(line)

    return variants_effects


def read_relevant_variants(variants_file, variant_indices):
    """
    Reads from the variants CSV only the rows of the given variants.

    CSV row 0 is taken to be the headers line, so the variant with index i is read from CSV row i + 1. The returned
    DataFrame is indexed by the given variant indices (which are expected to be in ascending order, matching the order
    in which the rows appear in the file).

    Raises ValueError if the number of rows actually read differs from the number of requested variants (for example
    when the CSV file has fewer rows than the highest requested index).
    """
    variant_indices = list(variant_indices)
    csv_rows_to_read = {0} | {i + 1 for i in variant_indices}
    variants = pd.read_csv(variants_file, skiprows = lambda i: i not in csv_rows_to_read)

    # Fail with an explicit message rather than with the generic length-mismatch error of the index assignment below.
    if len(variants) != len(variant_indices):
        raise ValueError('Expected to read %d variants with effects from the variants CSV file, but found only %d matching rows. ' \
                'Do the variants and effects files describe the same variants?' % (len(variant_indices), len(variants)))

    variants.index = variant_indices
    return variants


def combine_gene_effects(variant_gene_effects):
    """
    Combines all the effects of a single variant on a single gene.

    variant_gene_effects is an iterable of (effect description, effect score) pairs. Only effects with a score smaller
    than 1 are considered. Returns None if no such effect exists. Otherwise returns a pair of i) the descriptions of
    the considered effects joined by ' / ' and ii) the product of their scores.
    """
    relevant_effects = [(description, score) for description, score in variant_gene_effects if score < 1]

    if not relevant_effects:
        return None

    descriptions, scores = zip(*relevant_effects)
    return ' / '.join(descriptions), np.prod(scores)


def build_gene_variants(variants, gene_effects):
    """
    Builds the table of variants affecting a single gene.

    variants is the DataFrame of all relevant variants, indexed by variant index. gene_effects is a list of
    (variant index, combined effect description, combined effect score) tuples, one per variant affecting the gene.
    Returns a DataFrame with one row per such variant (in the given order), holding all the columns of the variant's
    record followed by the 'effects_on_gene' and 'effect_score' columns.

    Rows are selected from the variants DataFrame as a whole (rather than copied one row at a time), so the dtypes of
    the original columns are preserved.
    """
    variant_indices = [variant_index for variant_index, _, _ in gene_effects]
    # Explicit copy, so that adding the two columns below never touches (or warns about) the shared variants table.
    gene_variants = variants.loc[variant_indices].copy()
    gene_variants['effects_on_gene'] = [description for _, description, _ in gene_effects]
    gene_variants['effect_score'] = [score for _, _, score in gene_effects]
    return gene_variants


def main():
    """
    Command-line entry point: pivots the per-variant gene effects into one CSV file per gene, written into the
    requested output directory and named after the gene index.
    """
    parser = argparse.ArgumentParser(description = 'Pivots a list of varaints with per-gene effect scores into a per-gene data structure, '
            'organized in a directory with a seperate CSV file per gene, listing all the variants affecting that gene with their effects '
            'and effect scores.')
    parser.add_argument('--variants-file', dest = 'variants_file', metavar = '/path/to/variants.csv',
            type = get_parser_file_type(parser, must_exist = True), required = True, help = 'The input list of variants (CSV format). The '
            'content of the file is irrelevant for the operation of the of this process. Varaint records will simply be copied into the '
            'per-gene CSV files.')
    parser.add_argument('--effects-file', dest = 'effects_file', metavar = '/path/to/variant_effects.jsonl',
            type = get_parser_file_type(parser, must_exist = True), required = True, help = 'The gene effects determined per variant, '
            'provided in JSON-lines format. Each row in the file is expected to describe the effects of the variants in the corresponding '
            'row in the variants CSV file (in particular, the two files are expected to have the same number of lines, except the headers '
            'line that is only expected in the CSV file, but not in the JSON-lines file). Each row in the file is expected to be a JSON-formatted '
            'dictionary, mapping each gene index (a running integer index arbitrarily assigned to each gene) into the variant\'s list of '
            'effects on the gene, each is a pair of i) effect description (string) and ii) effect score (float, between 0 to 1). The output of '
            'this process will be a CSV file per gene, named after the provided gene index. Within this CSV file, each unique variant '
            'affecting the gene will comprise a distinct row (its effect score, with respect to the gene, will be the multiplication of all '
            'the effect scores assigned to the gene-variant combination). Only effects with scores smaller than 1 are considered.')
    parser.add_argument('--gene-variants-dir', dest = 'gene_variants_dir', metavar = '/path/to/output_variants_per_gene/',
            type = get_parser_directory_type(parser), required = True, help = 'The directory into which the per-gene CSV files will be written.')
    args = parser.parse_args()

    variants_effects = read_variants_effects(args.effects_file)
    log('Read the effects of %d variants.' % len(variants_effects))

    variants = read_relevant_variants(args.variants_file, variants_effects.keys())
    log('Read the %d relevant variants.' % len(variants))

    # Per gene, only lightweight (variant index, description, score) tuples are accumulated; the full variant records
    # are pulled from the variants table once per gene when its file is created.
    effects_per_gene = defaultdict(list)

    for i, (variant_index, variant_effects) in enumerate(variants_effects.items()):

        if i % 10000 == 0:
            log('Processing variant %d/%d...' % (i, len(variants)), end = '\r')

        for gene_index, variant_gene_effects in variant_effects.items():

            combined_effect = combine_gene_effects(variant_gene_effects)

            if combined_effect is not None:
                effect_description, effect_score = combined_effect
                effects_per_gene[gene_index].append((variant_index, effect_description, effect_score))

    log('Processed %d relevant effects in %d unique genes.' % (sum(map(len, effects_per_gene.values())), len(effects_per_gene)))

    for i, (gene_index, gene_effects) in enumerate(effects_per_gene.items()):

        if i % 100 == 0:
            log('Creating gene file %d/%d...' % (i, len(effects_per_gene)), end = '\r')

        gene_variants = build_gene_variants(variants, gene_effects)
        gene_variants.to_csv(os.path.join(args.gene_variants_dir, '%s.csv' % gene_index), index_label = 'variant_index')

    log('Done.')


if __name__ == '__main__':
    main()
            