#!/usr/bin/env python
"""
Usage:
    apply_filter (--vcf STR) (--snp-model STR) (--indel-model STR) [--verbose]

Description:
    Apply models generated by train_model to a VCF.

Arguments:
    --vcf STR                     VCF to be filtered
    --snp-model STR               Model for applying to SNPs
    --indel-model STR             Model for applying to InDels

Options:
    -h, --help                      Show this help message and exit.
    -v, --version                   Show version and exit.
    --verbose                       Log output

Examples:
    apply_filter --vcf <input.vcf> --snp-model <snp.pickle.dat> --indel-model <indel.pickle.dat>
"""

import os
# Pin the MKL / numexpr / OpenMP thread pools to one thread each. This must
# stay above the numpy/pandas imports below, since those libraries
# presumably read these variables when they are first loaded.
os.environ.update({
    "MKL_NUM_THREADS": "1",
    "NUMEXPR_NUM_THREADS": "1",
    "OMP_NUM_THREADS": "1",
})

from docopt import docopt
import extremevariantfilter as evf
import pandas as pd
import numpy as np
import warnings
import pickle


def get_options():
    """Parse command line arguments and check the input VCF.

    The usage string is the module docstring (parsed by docopt).

    Returns
    -------
    vcf : str
        path to vcf to be filtered
    snp_mod : str
        path to the pickled XGBClassifier trained on SNP data
    ind_mod : str
        path to the pickled XGBClassifier trained on INDEL data
    """

    args = docopt(__doc__, version='1.0')
    # NOTE(review): --verbose is accepted by the usage string but nothing in
    # this script reads it, so it currently has no effect.

    # VCF to be filtered (not training data); evf.Check_VCF is run on it
    # before anything else -- presumably it rejects unusable input, confirm
    # in extremevariantfilter.
    vcf = args['--vcf']
    evf.Check_VCF(vcf)

    snp_mod = args['--snp-model']
    ind_mod = args['--indel-model']

    return vcf, snp_mod, ind_mod


# INFO annotations used as classifier features, in the column order passed on.
_INFO_FEATURES = ['QD', 'MQ', 'FS', 'MQRankSum', 'ReadPosRankSum', 'SOR']


def _load_model(path):
    """Unpickle and return the model stored at *path*.

    NOTE(review): pickle.load can execute arbitrary code embedded in the
    file. Only pass model files produced by train_model from a trusted
    source.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


def _build_features(vcf):
    """Build the per-variant feature table consumed by evf.Predict_Var.

    Every frame here is labelled with ``vcf.index``. pandas aligns concat
    and column assignment by label, so this keeps each feature row paired
    with its own variant. That holds even when the VCF frame does not
    carry a default 0..n-1 index.
    """
    # Missing INFO keys for a record become 0.
    info = pd.DataFrame(list(vcf['INFO'].apply(evf.Split_Info)),
                        index=vcf.index).fillna(0.)
    info = info[_INFO_FEATURES]

    calls = evf.Get_Calls_Info(vcf)
    # Re-label positionally onto the variant index. This raises ValueError
    # if the call table does not have exactly one row per variant, instead
    # of silently producing misaligned/NaN rows in the concat below.
    calls.index = vcf.index

    features = pd.concat([info, calls], axis=1)
    features['Is_SNP'] = vcf.apply(evf.Check_SNP, axis=1)
    return features


def main():
    """Main function for apply_filter

    Loads the pickled SNP and InDel XGBClassifiers and applies them to every
    variant in the VCF given on the command line. The per-variant result of
    evf.Add_Filter is written into the FILTER column, and the updated VCF is
    written out with evf.Write_VCF.
    """
    warnings.filterwarnings('ignore', category=DeprecationWarning)

    vcf_path, snp_mod, ind_mod = get_options()
    header = evf.Get_Header(vcf_path)
    vcf = evf.Open_VCF(vcf_path)

    snp_mdl = _load_model(snp_mod)
    ind_mdl = _load_model(ind_mod)

    features = _build_features(vcf)
    features['Predict'] = features.apply(evf.Predict_Var,
                                         axis=1,
                                         args=(snp_mdl, ind_mdl))
    # features shares vcf.index, so this assignment lines up row-for-row.
    vcf['FILTER'] = features.apply(evf.Add_Filter, axis=1)

    evf.Write_VCF(vcf, header, evf.Get_Name(vcf_path))


# Run the filter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
