#!/usr/bin/env python3

# Program metadata. __version__ is handed to the argument parser, printed in
# the start-up log line and passed on to the VCF writer in main().
__version__ = '1.0.2'
__author__ = 'David Heller'

import sys
import os
import re
import pickle
import gzip
import logging
import pysam

from time import strftime, localtime
from pysam import FastaFile

from SVIM_input_parsing import parse_arguments
from SVIM_COLLECT import analyze_alignment_file_coordsorted
from SVIM_COMBINE import pair_candidates, write_final_vcf
from SVIM_plot import plot_sv_lengths


def _configure_logging(options):
    """Attach a file handler and a console handler to the root logger.

    The root logger level is DEBUG when ``options.verbose`` is truthy and INFO
    otherwise. ``options.working_dir`` is created if it does not exist yet and
    receives a log file named ``SVIM_<yymmdd_HHMMSS>.log``.
    """
    log_formatter = logging.Formatter("%(asctime)s [%(levelname)-7.7s]  %(message)s")
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG if options.verbose else logging.INFO)

    # exist_ok avoids the separate existence check (and the race it implies).
    os.makedirs(options.working_dir, exist_ok=True)

    log_path = "{0}/SVIM_{1}.log".format(options.working_dir, strftime("%y%m%d_%H%M%S", localtime()))
    file_handler = logging.FileHandler(log_path, mode="w")
    file_handler.setFormatter(log_formatter)
    root_logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    root_logger.addHandler(console_handler)


def _collect_candidates(aln_file, options, ordinal=None):
    """Validate one opened alignment file and collect its SV candidates.

    The file must declare ``SO:coordinate`` in its ``@HD`` header line and pass
    pysam's ``check_index()``. ``ordinal`` ("first" / "second") is only used to
    word the error messages in diploid mode; leave it as ``None`` for the
    single haploid input.

    Returns the result of ``analyze_alignment_file_coordsorted`` on success, or
    ``None`` after logging an error when validation fails.
    """
    qualifier = "{0} ".format(ordinal) if ordinal else ""
    subject = "The {0}input BAM file".format(qualifier) if ordinal else "Input BAM file"

    # Only the header lookup is guarded: a KeyError raised later, during the
    # analysis itself, must not be misreported as a missing sort order.
    try:
        sort_order = aln_file.header["HD"]["SO"]
    except KeyError:
        logging.error("Is the given {0}input BAM file coordinate-sorted? It does not contain a sorting order in its header line. Exiting..".format(qualifier))
        return None

    if sort_order != "coordinate":
        logging.error("{0} needs to be coordinate-sorted. Exiting..".format(subject))
        return None

    try:
        aln_file.check_index()
    except ValueError:
        logging.error("{0} is missing an index. Please generate with 'samtools index'. Exiting..".format(subject))
        return None
    except AttributeError:
        logging.error("pysam's .check_index raised an Attribute error. Something is wrong with the {0}input BAM file. Exiting..".format(qualifier))
        return None

    return analyze_alignment_file_coordsorted(aln_file, options)


def _candidates_of_type(candidates, sv_type):
    """Return the candidates whose ``type`` attribute equals ``sv_type``, in input order."""
    return [cand for cand in candidates if cand.type == sv_type]


def _run_pipeline(options):
    """Run COLLECT, (diploid only) PAIR and OUTPUT for the selected mode.

    Expects ``options.sub`` to be 'haploid' or 'diploid'. Every pysam handle
    opened here is closed again before returning, including on early exits
    and on exceptions.
    """
    open_handles = []  # pysam handles to close on the way out
    try:
        logging.info("****************** STEP 1: COLLECT ******************")

        is_haploid = options.sub == 'haploid'
        if is_haploid:
            logging.info("MODE: haploid")
            logging.info("INPUT: {0}".format(os.path.abspath(options.bam_file)))
            aln_file1 = pysam.AlignmentFile(options.bam_file)
            open_handles.append(aln_file1)
            sv_candidates = _collect_candidates(aln_file1, options)
            if sv_candidates is None:
                return
        else:
            logging.info("MODE: diploid")
            logging.info("INPUT1: {0}".format(os.path.abspath(options.bam_file1)))
            logging.info("INPUT2: {0}".format(os.path.abspath(options.bam_file2)))
            aln_file1 = pysam.AlignmentFile(options.bam_file1)
            open_handles.append(aln_file1)
            aln_file2 = pysam.AlignmentFile(options.bam_file2)
            open_handles.append(aln_file2)
            sv_candidates1 = _collect_candidates(aln_file1, options, "first")
            if sv_candidates1 is None:
                return
            sv_candidates2 = _collect_candidates(aln_file2, options, "second")
            if sv_candidates2 is None:
                return

        # Open reference genome sequence file
        try:
            reference = FastaFile(options.genome)
        except ValueError:
            logging.error("The given reference genome is missing an index file ({path}.fai). Sequence alleles cannot be retrieved.".format(path=options.genome))
            return
        except IOError:
            logging.error("The given reference genome is missing ({path}). Sequence alleles cannot be retrieved.".format(path=options.genome))
            return
        open_handles.append(reference)

        if is_haploid:
            final_candidates = sv_candidates
            logging.info("****************** STEP 2: OUTPUT ******************")
        else:
            logging.info("****************** STEP 2: PAIR ******************")
            final_candidates = pair_candidates(sv_candidates1, sv_candidates2, reference, aln_file1, options)
            logging.info("****************** STEP 3: OUTPUT ******************")

        deletion_candidates = _candidates_of_type(final_candidates, "DEL")
        insertion_candidates = _candidates_of_type(final_candidates, "INS")
        inversion_candidates = _candidates_of_type(final_candidates, "INV")
        tandem_duplication_candidates = _candidates_of_type(final_candidates, "DUP_TAN")
        breakend_candidates = _candidates_of_type(final_candidates, "BND")
        interspersed_duplication_candidates = _candidates_of_type(final_candidates, "DUP_INT")

        logging.info("Found {0} deletion candidates.".format(len(deletion_candidates)))
        logging.info("Found {0} inversion candidates.".format(len(inversion_candidates)))
        logging.info("Found {0} insertion candidates.".format(len(insertion_candidates)))
        logging.info("Found {0} tandem duplication candidates.".format(len(tandem_duplication_candidates)))
        logging.info("Found {0} interspersed duplication candidates.".format(len(interspersed_duplication_candidates)))
        logging.info("Found {0} breakend candidates.".format(len(breakend_candidates)))

        # Write SV candidates
        logging.info("Write SV candidates..")
        types_to_output = [entry.strip() for entry in options.types.split(",")]
        # Contig names and lengths are always taken from the first alignment file.
        write_final_vcf(interspersed_duplication_candidates,
                        inversion_candidates,
                        tandem_duplication_candidates,
                        deletion_candidates,
                        insertion_candidates,
                        breakend_candidates,
                        __version__,
                        aln_file1.references,
                        aln_file1.lengths,
                        types_to_output,
                        reference,
                        options)
        logging.info("Draw plots..")
        plot_sv_lengths(deletion_candidates, inversion_candidates, interspersed_duplication_candidates, tandem_duplication_candidates, insertion_candidates, options)
        logging.info("Done.")
    finally:
        for handle in reversed(open_handles):
            handle.close()


def main():
    """Command-line entry point: parse options, set up logging, run the pipeline.

    Prints a hint and returns if no supported mode ('haploid' or 'diploid') was
    selected. Input problems (unsorted or unindexed BAM, missing reference or
    reference index) are logged as errors and end the run early. Always
    returns ``None``.
    """
    # Fetch command-line options
    options = parse_arguments(program_version=__version__)

    # Covers both a missing mode and any unexpected value.
    if options.sub not in ('haploid', 'diploid'):
        print("Please choose one of the two modes ('haploid' or 'diploid'). See --help for more information.")
        return

    _configure_logging(options)

    logging.info("****************** Start SVIM-asm, version {0} ******************".format(__version__))
    logging.info("CMD: python3 {0}".format(" ".join(sys.argv)))
    logging.info("WORKING DIR: {0}".format(os.path.abspath(options.working_dir)))
    for arg in vars(options):
        logging.info("PARAMETER: {0}, VALUE: {1}".format(arg, getattr(options, arg)))

    _run_pipeline(options)

if __name__ == "__main__":
    try:
        sys.exit(main())
    except Exception as e:
        # Record the full traceback through the configured log handlers, then
        # exit non-zero so shells and workflow managers can detect the failure.
        # SystemExit raised by sys.exit() above is not an Exception subclass
        # and therefore passes through untouched.
        logging.error(e, exc_info=True)
        sys.exit(1)
