#!/usr/bin/env python3
import os
import subprocess
import argparse
import sys
import logging

# Utility to check if a tool is in PATH
def check_tool_in_path(tool_name):
    """Return True if *tool_name* resolves to an executable on the PATH.

    The lookup is done in-process with ``shutil.which`` instead of spawning
    ``which`` through a shell, so the name is never interpreted by a shell
    (no command injection through odd tool names) and the check does not
    depend on a ``which`` binary being installed.

    Args:
        tool_name: Name (or path) of the executable to look up.

    Returns:
        bool: True when an executable was found, False otherwise.
    """
    # Function-scope import: the module's import block does not include shutil.
    import shutil

    return shutil.which(tool_name) is not None

# Set up logging to a file
def setup_logging(log_file):
    """Configure root logging to write to *log_file* and to the console.

    The file output is set up through ``logging.basicConfig`` at INFO level
    with timestamped records; an extra INFO-level ``StreamHandler`` with a
    shorter format is then attached to the root logger.

    Args:
        log_file: Path of the log file handed to ``logging.basicConfig``.
    """
    file_settings = {
        'filename': log_file,
        'level': logging.INFO,
        'format': '%(asctime)s - %(levelname)s - %(message)s',
        'datefmt': '%Y-%m-%d %H:%M:%S',
    }
    logging.basicConfig(**file_settings)

    # Second handler so messages also show up on the console.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s'))

    root_logger = logging.getLogger('')
    root_logger.addHandler(console_handler)

# Log messages with different levels
def log_info(message):
    """Report an informational message.

    The message is forwarded to print_stderr() with the error flag off, so it
    is written to STDERR with the informational prefix; the ``logging`` module
    is not involved.
    """
    is_error = False
    print_stderr(message, error=is_error)

def log_error(message):
    """Report an error message.

    The message is forwarded to print_stderr() with the error flag on, so it
    is written to STDERR with the error prefix; the ``logging`` module is not
    involved.
    """
    print_stderr(message, error=True)

# Print to STDERR
def print_stderr(message, error=True):
    """Write *message* to STDERR behind a severity tag.

    Args:
        message: Text to print.
        error: When truthy the line is tagged ``[ERROR]``, otherwise ``[INFO]``.
    """
    tag = "[ERROR]" if error else "[INFO]"
    print(f"{tag} {message}", file=sys.stderr)

# NOTE: log_info/log_error above write straight to STDERR via print_stderr;
# setup_logging() is defined but not called anywhere in this file.

# Check if required tools are available
def check_required_tools():
    """Verify that every external program the pipeline calls is on the PATH.

    Each required tool is probed with check_tool_in_path(). If any are
    missing, they are reported in a single error message and the process
    exits with status 1; otherwise the function returns normally.
    """
    required = ('devider', 'lofreq', 'minimap2', 'samtools', 'bgzip', 'tabix')

    absent = []
    for name in required:
        if not check_tool_in_path(name):
            absent.append(name)

    if not absent:
        return

    log_error(f"The following tools are missing or not in PATH: {', '.join(absent)}")
    sys.exit(1)

def try_command(command, error_message):
    """Run a shell command and report whether it succeeded, without exiting.

    Args:
        command: Command line, executed through the shell.
        error_message: Message logged as an error when the exit status is non-zero.

    Returns:
        bool: True if the command exited with status 0, False otherwise.
    """
    log_info(f"PIPELINE - Running: {command}")
    succeeded = subprocess.run(command, shell=True).returncode == 0
    if not succeeded:
        log_error(error_message)
    return succeeded
# Run a command and check if it succeeded
def run_command(command, error_message):
    """Run a shell command and abort the whole process if it fails.

    The command is announced and executed by try_command(), which also logs
    *error_message* on a non-zero exit status; this wrapper then terminates
    the process with status 1.

    Args:
        command: Command line, executed through the shell.
        error_message: Message logged as an error when the command fails.
    """
    if not try_command(command, error_message):
        sys.exit(1)

def run_pipeline(ref, reads, output_dir, threads, devider_options, lofreq_options, args):
    """Align reads, call variants and run devider, skipping steps whose output exists.

    Intermediate files are written to ``<output_dir>/pipeline_files``. Every
    step is gated by check_file_exists(), so re-running against an existing
    output directory only executes the steps whose output is still missing.
    A failing external command terminates the process through run_command().

    Args:
        ref: Path to the reference FASTA file.
        reads: Path to the input reads; passed to the shell as one quoted argument.
        output_dir: Directory for the devider output and the intermediates.
        threads: Thread count handed to minimap2, samtools sort, lofreq
            call-parallel and devider.
        devider_options: Space-separated long option names (without the
            leading ``--``) forwarded to devider.
        lofreq_options: Space-separated long option names (without the
            leading ``--``) forwarded to lofreq.
        args: Parsed command-line namespace. ``overwrite`` is required;
            ``single_thread_lofreq`` is optional and defaults to False.
    """
    # Function-scope imports: the module's import block has neither of these.
    import shlex
    import shutil

    if args.overwrite:
        log_info("WARNING: Overwriting existing files. This will remove the output directory if it exists.")
        if os.path.exists(output_dir):
            # Remove in-process instead of `rm -r` through a shell: no quoting
            # problems with unusual paths, and failures are not silently ignored.
            try:
                if os.path.isdir(output_dir) and not os.path.islink(output_dir):
                    shutil.rmtree(output_dir)
                else:
                    os.remove(output_dir)
            except OSError as exc:
                log_error(f"Could not remove existing output directory {output_dir}: {exc}")
                sys.exit(1)

    pipeline_folder = os.path.join(output_dir, "pipeline_files")
    if not os.path.exists(pipeline_folder):
        os.makedirs(pipeline_folder)
        log_info(f"Created output directory: {pipeline_folder}")
    else:
        log_info(f"WARNING: Output directory already exists: {pipeline_folder}.")

    # Define intermediate file paths
    bam_file = os.path.join(pipeline_folder, 'mapping.bam')
    bam_index = bam_file + ".bai"
    vcf_file = os.path.join(pipeline_folder, 'lofreq.vcf')
    vcf_gz_file = f"{vcf_file}.gz"
    vcf_gz_index = vcf_gz_file + ".tbi"

    # Shell-quoted forms of every path that ends up in a command line, so
    # paths containing spaces or shell metacharacters are passed through intact.
    q_ref = shlex.quote(ref)
    q_reads = shlex.quote(reads)
    q_out = shlex.quote(output_dir)
    q_bam = shlex.quote(bam_file)
    q_vcf = shlex.quote(vcf_file)
    q_vcf_gz = shlex.quote(vcf_gz_file)

    # Step 1: Align and sort, unless the BAM file already exists.
    # NOTE: the command is a shell pipeline, so only the exit status of its
    # last stage (samtools sort) is seen by run_command().
    if check_file_exists(bam_file):
        run_command(
            f"minimap2 -a {q_ref} {q_reads} -t {threads} | samtools sort -@ {threads} -o {q_bam}",
            "Failed during alignment or BAM sorting"
        )

    # Step 2: Index the BAM file, unless the index already exists.
    if check_file_exists(bam_index):
        run_command(f"samtools index {q_bam}", "Failed to index BAM file")

    # Step 3: Call variants, unless the compressed VCF already exists.
    if check_file_exists(vcf_gz_file):
        # User-supplied option tokens are forwarded as-is, each prefixed with "--".
        lofreq_option_string = " ".join([f"--{option}" for option in lofreq_options.split()])

        # Honour --single-thread-lofreq: skip the parallel attempt entirely.
        parallel_succeeded = False
        if not getattr(args, "single_thread_lofreq", False):
            parallel_succeeded = try_command(
                f"lofreq call-parallel {lofreq_option_string} --pp-threads {threads} -B -f {q_ref} {q_bam} -o {q_vcf}",
                "LoFreq parallel variant calling failed. Trying single-threaded..."
            )
        if not parallel_succeeded:
            run_command(
                f"lofreq call {lofreq_option_string} -B -f {q_ref} {q_bam} -o {q_vcf}",
                "LoFreq variant calling failed"
            )

    # Step 4: Compress and index the VCF. The two outputs are checked
    # independently: compress only when the .gz is missing, and (re)build the
    # index when it is missing or when the .gz was just regenerated.
    compressed_now = check_file_exists(vcf_gz_file)
    if compressed_now:
        run_command(f"bgzip {q_vcf} -f", "Failed to compress VCF file")
    if compressed_now or check_file_exists(vcf_gz_index):
        run_command(f"tabix {q_vcf_gz} -f", "Failed to index VCF file")

    # Step 5: Run devider, unless its haplotype output already exists.
    if check_file_exists(os.path.join(output_dir, 'haplotypes.fasta')):
        devider_option_string = " ".join([f"--{option}" for option in devider_options.split()])
        run_command(
            f"devider -b {q_bam} -v {q_vcf_gz} -r {q_ref} -t {threads} {devider_option_string} -o {q_out}",
            "devider execution failed"
        )

# Check if a file exists
# Disable overwriting for now... 
def check_file_exists(file_path, overwrite = False, warning_message = None):
    """Decide whether the step that produces *file_path* should run.

    Args:
        file_path: Output path of the step being considered.
        overwrite: If truthy, an existing file does not cause the step to be skipped.
        warning_message: Accepted but not used.

    Returns:
        bool: True when the step should run (the path is missing, or it exists
        and *overwrite* is set); False when the path exists and the step
        should be skipped.
    """
    # Nothing on disk yet: the step has to run, and there is nothing to log.
    if not os.path.exists(file_path):
        return True

    if not overwrite:
        log_info(f"{file_path} exists. Skipping this step. Consider running `devider` by itself if you want to re-haplotype.")
        return False

    log_info(f"{file_path} exists but will be overwritten.")
    return True

def main():
    """Parse the command line and run the full pipeline.

    After parsing, the required external tools are checked and run_pipeline()
    is invoked with the parsed values; start and successful completion are
    reported through log_info().
    """
    cli = argparse.ArgumentParser(
        description=(
            "Pipeline wrapper for devider, lofreq, and minimap2. "
            "This pipeline aligns reads to a reference genome, calls variants, and generates haplotypes. "
            "If you already have a BAM file and VCF file, please run devider directly."
        )
    )

    # Required inputs and output location
    cli.add_argument('-r', '--ref', help='Reference FASTA file', required=True)
    cli.add_argument('-i', '--reads', help='Input reads', required=True)
    cli.add_argument('-o', '--output-dir', help='Output directory', required=True)

    # Tuning and pass-through options
    cli.add_argument('-t', '--threads', type=int, default=3, help='Number of threads')
    cli.add_argument(
        '--devider-options',
        default="",
        help='Options to pass to devider, long-form only and separated by spaces (e.g. "min-coverage=10 output-reads")',
    )
    cli.add_argument(
        '--overwrite',
        action='store_true',
        help='Overwrite existing files if present. WARNING: will remove the directory if it exists.',
    )
    cli.add_argument(
        '--lofreq-options',
        default="",
        help='Options to pass to LoFreq, long-form only and separated by spaces',
    )
    cli.add_argument(
        '--single-thread-lofreq',
        action='store_true',
        help='Run LoFreq with a single thread',
    )

    options = cli.parse_args()

    log_info("Starting pipeline")

    # Fail early if any external program is unavailable.
    check_required_tools()

    run_pipeline(
        options.ref,
        options.reads,
        options.output_dir,
        options.threads,
        options.devider_options,
        options.lofreq_options,
        options,
    )

    log_info("Pipeline finished successfully")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
