#!/usr/bin/env python
"""
Usage:
    train_model (--true-pos STR) (--false-pos STR) (--type STR) [--out STR] [--njobs INT] [--verbose]

Description:
    Train a model to be saved and used to filter VCFs.

Arguments:
    --true-pos STR          Path to true-positive VCF from VCFeval or comma-separated list of paths
    --false-pos STR         Path to false-positive VCF from VCFeval or comma-separated list of paths
    --type STR              SNP or INDEL

Options:
    -o, --out <STR>                 Outfile name for writing model [default: (type).filter.pickle.dat]
    -n, --njobs <INT>               Number of threads to run in parallel [default: 2]
    -h, --help                      Show this help message and exit.
    -v, --version                   Show version and exit.
    --verbose                       Log output

Examples:
    train_model --true-pos <path/to/tp/vcf(s)> --false-pos <path/to/fp/vcf(s)> --type [SNP, INDEL] --njobs 20
"""

import os
# Cap the MKL / NumExpr / OpenMP thread-count environment variables at 1.
# These are set before the numeric imports below.
# NOTE(review): presumably this keeps the native libraries from spawning their
# own threads inside each worker process, so --njobs alone controls
# parallelism -- confirm before reordering these lines.
os.environ["MKL_NUM_THREADS"] = "1" 
os.environ["NUMEXPR_NUM_THREADS"] = "1" 
os.environ["OMP_NUM_THREADS"] = "1" 

from docopt import docopt
import extremevariantfilter as evf
from multiprocessing import Pool
from contextlib import closing
import pandas as pd
import numpy as np
import pickle
import warnings


def get_options():
    """Gets and validates command line arguments

    The usage text in the module docstring is parsed with docopt. The
    ``--verbose`` flag is accepted by the parser but is not consumed here
    and is not returned to the caller.

    Returns
    -------
    tp_vcf : str
        path to true positive vcf (or comma-separated list of paths)
    fp_vcf : str
        path to false positive vcf (or comma-separated list of paths)
    poly : str
        type of polymorphism, upper-cased. 'SNP' or 'INDEL'
    njobs : int
        number of threads used for execution
    outname : str
        name for the output model

    Raises
    ------
    ValueError
        if the value given to --njobs cannot be converted to an int
    """

    args = docopt(__doc__, version='1.0')

    # Read training data
    tp_vcf = args['--true-pos']
    fp_vcf = args['--false-pos']
    poly = args['--type'].upper()
    njobs = int(args['--njobs'])

    # The documented default for --out is a literal placeholder; when the
    # user did not override it, substitute the actual polymorphism type.
    outname = args['--out']
    if outname == "(type).filter.pickle.dat":
        outname = poly + '.filter.pickle.dat'

    # Validation of the polymorphism type is delegated to the evf package;
    # how an invalid value is reported is defined there.
    evf.Check_Type(poly)

    return tp_vcf, fp_vcf, poly, njobs, outname


def main():
    """Main function for train_model

    Trains a model based on command line arguments (per the original
    description, an XGBClassifier built by ``evf.Build_Model``).
    This model is then pickled and stored for future use.

    Steps:
    1. Parse the command line and resolve the list of input VCFs.
    2. Run ``evf.Get_Training_Tables`` on every VCF in a process pool;
       each result is unpacked as an ``(X, Y)`` pair.
    3. Concatenate all X parts and all Y parts and fit the model on them.
    4. Pickle the fitted model to the requested output file.
    """

    warnings.filterwarnings('ignore', category=DeprecationWarning)

    tp_vcf, fp_vcf, poly, njobs, outname = get_options()
    all_vcf = evf.Check_VCF_Paths(tp_vcf, fp_vcf)

    # pool.map blocks until every input has been processed, so the workers
    # can be stopped right after it returns; closing() then closes the pool.
    with closing(Pool(processes=njobs)) as pool:
        results = pool.map(evf.Get_Training_Tables, all_vcf)
        pool.terminate()

    X, Y = zip(*results)
    X_all = np.concatenate(X)
    Y_all = np.concatenate(Y)

    # model[0] is used as the display name, model[1] is the estimator.
    model = evf.Build_Model(poly, njobs)
    # .format must be applied to the string, not to the result of print():
    # on Python 3 print() returns None and the old form raised
    # AttributeError before the model was ever fitted or saved.
    print("Training {} on {}".format(model[0], all_vcf))
    model[1].fit(X_all, Y_all)
    with open(outname, "wb") as out:
        pickle.dump(model[1], out)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
