#!/usr/bin/env python3
import pysam
import subprocess
import argparse
import os
import logging

# Configure the root logger once at import time: INFO level and above,
# each record rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

def parse_args(argv=None):
    """Parse the command-line arguments for the tagging script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the arguments from ``sys.argv``.

    Returns:
        argparse.Namespace with ``input_bam`` (positional) and ``ids_file``
        (``-i`` / ``--ids_file``, defaulting to ``devider_output/ids.txt``).
    """
    parser = argparse.ArgumentParser(description="Tag BAM reads based on accessions from a given IDs file.")
    parser.add_argument('input_bam', type=str, help="Input BAM file.")
    parser.add_argument(
        '-i', '--ids_file',
        type=str,
        default="devider_output/ids.txt",
        # %(default)s is expanded by argparse, so the help text can never
        # drift away from the real default value again.
        help="IDs file with accessions and tags. Default: %(default)s",
    )
    return parser.parse_args(argv)

def check_files(input_bam, ids_file):
    """Verify that both input paths point at existing regular files.

    The BAM path is checked first, then the IDs path. The first path that
    is not a file is reported through ``logging.error`` and aborts the check.

    Args:
        input_bam: Path to the input BAM file.
        ids_file: Path to the IDs file.

    Raises:
        FileNotFoundError: If either path is not an existing file.
    """
    # One row per required input: (path, message to log, message to raise).
    required_inputs = (
        (
            input_bam,
            f"Input BAM file '{input_bam}' does not exist.",
            f"Input BAM file '{input_bam}' not found.",
        ),
        (
            ids_file,
            f"IDs file '{ids_file}' does not exist.",
            f"IDs file '{ids_file}' not found.",
        ),
    )
    for path, log_message, error_message in required_inputs:
        if os.path.isfile(path):
            continue
        logging.error(log_message)
        raise FileNotFoundError(error_message)

def read_ids_file(ids_file):
    """Build a mapping from read accession to tag from the IDs file.

    Every line is split on whitespace. A usable line has at least three
    fields, and its second and third fields each contain a ``:``; the text
    between the first and second ``:`` of each is taken as its value. The
    tag is ``<third-field value>_<second-field value>`` and is assigned to
    every accession listed from the fourth field onward. If an accession
    appears on several lines, the last line wins.

    Lines with fewer than three fields, or whose second or third field has
    no ``:`` separator, are logged as malformed and skipped.

    Args:
        ids_file: Path to the IDs text file.

    Returns:
        dict mapping each accession string to its tag string.

    Raises:
        OSError: If the file cannot be opened or read (logged, then re-raised).
        UnicodeError: If the file cannot be decoded (logged, then re-raised).
    """
    accession_to_hp = {}
    try:
        with open(ids_file, 'r') as f:
            logging.info(f"Reading IDs from '{ids_file}'...")
            for line_number, line in enumerate(f, start=1):
                fields = line.split()
                if len(fields) < 3:
                    logging.warning(f"Line {line_number} in IDs file is malformed: {line.strip()}")
                    continue

                range_parts = fields[1].split(':')
                hap_parts = fields[2].split(':')
                # A field without ':' used to raise a bare IndexError and
                # abort the whole run; treat it like any other malformed line.
                if len(range_parts) < 2 or len(hap_parts) < 2:
                    logging.warning(f"Line {line_number} in IDs file is malformed: {line.strip()}")
                    continue

                tag = hap_parts[1] + '_' + range_parts[1]
                for accession in fields[3:]:
                    accession_to_hp[accession] = tag
    except (OSError, UnicodeError) as e:
        logging.error(f"Failed to read the IDs file: {e}")
        raise

    return accession_to_hp

def tag_bam_reads(input_bam, output_bam, accession_to_hp):
    """Copy every read from ``input_bam`` to ``output_bam``, adding HP tags.

    A read gets its ``HP`` tag set to the mapped value when its query name
    is a key of ``accession_to_hp`` and that value is not the literal
    ``'unassigned'``. All reads, tagged or not, are written to the output,
    which reuses the input file's header. Totals are logged at the end.

    Args:
        input_bam: Path of the BAM file to read.
        output_bam: Path of the BAM file to write.
        accession_to_hp: Mapping from read query name to tag value.

    Raises:
        Exception: Any error raised while reading or writing is logged and
            re-raised unchanged.
    """
    try:
        with pysam.AlignmentFile(input_bam, "rb") as reader:
            # The writer is opened second because it needs the reader's header.
            with pysam.AlignmentFile(output_bam, "wb", header=reader.header) as writer:
                logging.info(f"Tagging reads in '{input_bam}' and writing to '{output_bam}'...")
                read_count = 0
                tagged_count = 0
                for record in reader:
                    name = record.query_name
                    read_count += 1
                    if name in accession_to_hp and accession_to_hp[name] != 'unassigned':
                        record.set_tag('HP', accession_to_hp[name])
                        tagged_count += 1
                    writer.write(record)

                logging.info(f"Processed {read_count} reads, tagged {tagged_count} reads.")
    except Exception as e:
        logging.error(f"Error processing BAM files: {e}")
        raise

def index_bam(output_bam):
    """Index ``output_bam`` by running the external ``samtools index`` command.

    Args:
        output_bam: Path of the BAM file to index.

    Raises:
        subprocess.CalledProcessError: If samtools exits with a non-zero
            status (logged, then re-raised).
    """
    logging.info(f"Indexing output BAM file '{output_bam}'...")
    command = ["samtools", "index", output_bam]
    try:
        # check=True turns a non-zero exit status into CalledProcessError.
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        logging.error(f"Failed to index BAM file: {e}")
        raise
    else:
        logging.info(f"BAM indexing completed successfully.")

def main():
    """Run the full pipeline: parse arguments, validate, tag, then index.

    The output path is the input BAM path with ``.tagged.bam`` appended.
    Any exception raised by a step propagates to the caller.
    """
    args = parse_args()
    tagged_path = args.input_bam + ".tagged.bam"

    # Fail fast before any reading or writing happens.
    check_files(args.input_bam, args.ids_file)

    # Accession -> tag lookup built from the IDs file.
    tags_by_accession = read_ids_file(args.ids_file)

    # Write the tagged copy, then index it.
    tag_bam_reads(args.input_bam, tagged_path, tags_by_accession)
    index_bam(tagged_path)

if __name__ == "__main__":
    # Top-level boundary: report any failure from main() and exit with
    # status 1 so that calling shells and pipelines see the error.
    try:
        main()
    except Exception as e:
        logging.error(f"Script failed: {e}")
        # Raise SystemExit directly instead of calling the bare `exit()`:
        # that name is a helper installed by the `site` module and is not
        # available when Python runs with -S or in some embedded builds.
        raise SystemExit(1)
    else:
        logging.info("Script completed successfully.")

