#! /usr/bin/env python3

import os
import sys
import argparse

import numpy as np
import pandas as pd

from pwas.shared_utils.util import log, get_parser_file_type, get_parser_directory_type
from pwas.shared_utils.reference_genome import GenomeReader

N_PROCESSED_VARIANTS_PER_PROGRESS_UPDATE = 1000

def is_allele1_ref(genome_reader, chrom, pos, allele1, allele2, mismatch_action):
    '''
    Determine which of a variant's two alleles agrees with the reference genome.
    
    The reference sequence is fetched with genome_reader.read_seq(chrom, pos, end), where end is chosen so that
    the requested span is as long as the longer of the two alleles. Each allele is then tested as a prefix of
    that sequence.
    
    Returns True if allele1 is a prefix of the reference sequence, False if only allele2 is, and np.nan if
    neither is (unless mismatch_action is 'error'). Note that allele1 is tested first, so when both alleles are
    prefixes of the reference (e.g. one allele is a prefix of the other), the result is True.
    
    mismatch_action controls what happens when neither allele matches:
        'error'   - raise a ValueError.
        'warning' - print a warning to the standard output and return np.nan.
        any other value (e.g. 'ignore') - silently return np.nan.
    '''
    
    longest_allele_len = max(len(allele1), len(allele2))
    ref_seq = genome_reader.read_seq(chrom, pos, pos + longest_allele_len - 1)
    
    # The order matters: allele1 takes precedence when both alleles match.
    for candidate_allele, candidate_is_allele1 in ((allele1, True), (allele2, False)):
        if ref_seq.startswith(candidate_allele):
            return candidate_is_allele1
    
    # Neither allele matches. The message is built regardless of the chosen action.
    mismatch_details = (chrom, pos, allele1, allele2, ref_seq)
    mismatch_message = 'Neither of the two alleles of variant chr%s:%d[%s,%s] match the reference genome at that position (%s).' % \
            mismatch_details
    
    if mismatch_action == 'error':
        raise ValueError(mismatch_message)
    
    if mismatch_action == 'warning':
        print('Warning: ' + mismatch_message)
    
    return np.nan

if __name__ == '__main__':

    # Command-line entry point: reads a CSV of variants, adds a column stating whether the first allele of each
    # variant is the reference allele (according to the provided reference genome), and writes the CSV back.

    parser = argparse.ArgumentParser(description = 'For each variant in a given list, determine which of the two alleles is the reference allele.')
    parser.add_argument('--variants-file', dest = 'variants_file', metavar = '/path/to/variants.csv', \
            type = get_parser_file_type(parser, must_exist = True), required = True, help = 'The input list of variants (CSV format).')
    parser.add_argument('--ref-genome-dir', dest = 'ref_genome_dir', metavar = '/path/to/ref_genome_dir/', \
            type = get_parser_directory_type(parser), required = True, help = 'Path to the reference genome directory, expected to have one ' + \
            '(uncompressed) FASTA file per chromosome. Make sure that the used reference genome corresponds to the given variants.')
    parser.add_argument('--chrom-col', dest = 'chrom_col', metavar = '<csv_col_name>', type = str, default = 'chrom', help = 'The column in the ' + \
            'input CSV file storing the chromosome name of each variant.')
    parser.add_argument('--pos-col', dest = 'pos_col', metavar = '<pos_col_name>', type = str, default = 'pos', help = 'The column in the ' + \
            'input CSV file storing the chromosomal position/coordinate of each variant.')
    parser.add_argument('--allele1-col', dest = 'allele1_col', metavar = '<allele1_col_name>', type = str, default = 'allele1', help = 'The ' + \
            'column in the input CSV file storing the first allele of each variant.')
    parser.add_argument('--allele2-col', dest = 'allele2_col', metavar = '<allele2_col_name>', type = str, default = 'allele2', help = 'The ' + \
            'column in the input CSV file storing the second allele of each variant.')
    parser.add_argument('--result-col', dest = 'result_col', metavar = '<result_col_name>', type = str, default = 'is_allele1_ref', help = 'The ' + \
            'name of the new boolean column to add to the CSV file, that will determine, for each variant, whether the first allele is the reference ' + \
            'allele, or the second allele (default: is_allele1_ref).')
    parser.add_argument('--output-file', dest = 'output_file', metavar = '/path/to/output_variants.csv', type = get_parser_file_type(parser), \
            help = 'Path to the output CSV file, where all the input variants will be stored, with one additional column (by default, will override ' + \
            'the input --variants-file).')
    parser.add_argument('--override', dest = 'override', action = 'store_true', help = 'Whether to allow overriding an existing file.')
    parser.add_argument('--verbose', dest = 'verbose', action = 'store_true', help = 'Whether to print information regarding the progress of this operation.')
    parser.add_argument('--mismatch-action', dest = 'mismatch_action', choices = ['error', 'warning', 'ignore'], default = 'warning', \
            help = 'The action to be taken if neither of the two alleles in a variant match the reference genome (default: warning).')
    parser.add_argument('--unknown-chrom-action', dest = 'unknown_chrom_action', choices = ['error', 'warning', 'ignore'], default = 'warning', \
            help = 'The action to be taken if faced with an unrecognized chromosome, not appearing in the provided reference genome (default: warning).')
    args = parser.parse_args()
    
    # Without an explicit --output-file, the results are written back over the input file.
    output_file = args.variants_file if args.output_file is None else args.output_file
        
    # Refuse to clobber an existing file unless --override was given. Note that this also applies to the default
    # case of writing over the input file.
    if not args.override and os.path.exists(output_file):
        log('Output file already exists (specify --override if you wish to override it): %s' % output_file)
        sys.exit(1)
        
    genome_reader = GenomeReader(args.ref_genome_dir)
    
    # Everything that uses the genome reader runs inside try/finally, so that the reader is closed also when
    # reading the CSV fails or when one of the 'error' actions raises.
    try:
    
        # Chromosome names are forced to strings, so that numeric names (e.g. 1, 2, ...) are not parsed as integers.
        variants = pd.read_csv(args.variants_file, dtype = {args.chrom_col: str})
        chroms = variants[args.chrom_col]
        unrecognized_chroms = {chrom for chrom in chroms.unique() if chrom not in genome_reader}
        
        # A per-variant mask, used both for counting and for skipping variants in the loop below. Using isin (rather
        # than testing each value against the Python set) also covers missing chromosome values (NaN).
        is_chrom_unrecognized = chroms.isin(unrecognized_chroms)
        
        if unrecognized_chroms:
        
            # Missing chromosome values are read as NaN (a float), which can neither be sorted together with strings
            # nor joined, so every value is converted to a string first.
            error_message = 'There are %d variants with %d unrecognized chromosomes: %s.' % (is_chrom_unrecognized.sum(), \
                    len(unrecognized_chroms), ', '.join(sorted(map(str, unrecognized_chroms))))
        
            if args.unknown_chrom_action == 'error':
                raise Exception(error_message)
            elif args.unknown_chrom_action == 'warning':
                print('Warning: ' + error_message)
        
        is_allele1_ref_results = []
        
        # Iterating over only the required columns avoids constructing a full Series object per row (as
        # DataFrame.iterrows does), which is considerably faster on large variant lists.
        variant_fields = zip(is_chrom_unrecognized, chroms, variants[args.pos_col], variants[args.allele1_col], \
                variants[args.allele2_col])
        
        for n_processed, (skip_variant, chrom, pos, allele1, allele2) in enumerate(variant_fields, start = 1):
            
            if skip_variant:
                # No reference sequence to compare against, so the result is left undetermined.
                is_allele1_ref_results.append(np.nan)
            else:
                is_allele1_ref_results.append(is_allele1_ref(genome_reader, chrom, pos, allele1, allele2, args.mismatch_action))
        
            if args.verbose and n_processed % N_PROCESSED_VARIANTS_PER_PROGRESS_UPDATE == 0:
                log('Processed %d of %d variants.' % (n_processed, len(variants)), end = '\r')
        
        variants[args.result_col] = is_allele1_ref_results
        
    finally:
        genome_reader.close()
    
    if args.verbose:
        log('Saving results into %s...' % output_file)
    
    variants.to_csv(output_file, index = False)