import pysam
import matplotlib.pyplot as plt
import re


def parse_cigar(cigar_tuples):
    """Convert pysam-style CIGAR tuples into ``(length, op_char)`` pairs.

    Args:
        cigar_tuples: Iterable of ``(op_code, length)`` tuples, as returned by
            ``pysam.AlignedSegment.cigartuples``.

    Returns:
        List of ``(length, op_char)`` tuples using the SAM operation letters
        ``M I D N S H P = X``. This is the same shape ``parse_cigar_string``
        produces. Op codes outside 0-8 are skipped.
    """
    # Position i holds the SAM letter for BAM op code i:
    # 0=M, 1=I, 2=D, 3=N, 4=S, 5=H, 6=P, 7='=', 8=X.
    op_chars = 'MIDNSHP=X'
    operations = []
    for op, length in cigar_tuples:
        if 0 <= op < len(op_chars):
            operations.append((length, op_chars[op]))
    return operations


def get_sequence_length_from_cigar(cigar_str):
    """Return the query and reference lengths spanned by a CIGAR string.

    Args:
        cigar_str: Textual CIGAR such as ``"5S10M2I3D"``. Text that does not
            form a ``<digits><op>`` token is ignored.

    Returns:
        ``(query_length, reference_length)``:
        - ``query_length`` sums the ops that consume the query
          (``M I S = X``).
        - ``reference_length`` sums the ops that consume the reference
          (``M D N = X``).
        ``H`` and ``P`` consume neither.
    """
    pattern = re.compile(r'(\d+)([MIDNSHP=X])')
    query_ops = frozenset('MIS=X')
    # 'N' (skipped reference region) consumes the reference just like 'D'.
    reference_ops = frozenset('MDN=X')
    length1, length2 = 0, 0
    for count, op in pattern.findall(cigar_str):
        num = int(count)
        if op in query_ops:
            length1 += num
        if op in reference_ops:
            length2 += num
    return length1, length2


def read_cigar_file(file_path):
    """Return the text of *file_path* with leading/trailing whitespace removed.

    Intended for files holding a single CIGAR string.
    """
    with open(file_path, 'r') as handle:
        contents = handle.read()
    return contents.strip()


def parse_cigar_string(cigar_str):
    """Split a textual CIGAR string into ``(length, op)`` tuples.

    Each ``<digits><op>`` token becomes ``(int(digits), op)``, where op is one
    of ``M I D N S H P = X``. Any other text is ignored.
    """
    token_re = re.compile(r'(\d+)([MIDNSHP=X])')
    return [(int(count), code) for count, code in token_re.findall(cigar_str)]


def process_paf_file(file_path):
    """Collect alignment endpoints from a PAF file, scaled to megabases.

    For each record, the start point is appended and then the end point. So
    entries ``2k`` and ``2k+1`` of each returned list describe record ``k``.
    Blank lines are skipped.

    Args:
        file_path: Path to a tab-separated PAF file.

    Returns:
        ``(ref_positions, query_positions)``: two flat lists of floats (Mb).
        Reference values come from fields 7/8 (0-based, target start/end).
        Query values come from fields 2/3 (query start/end).

    Raises:
        IndexError: a non-blank line has fewer than 9 tab-separated fields.
        ValueError: a coordinate field is not an integer.
    """
    ref_positions = []
    query_positions = []

    with open(file_path, 'r') as file:
        for line in file:
            if not line.strip():
                continue  # tolerate blank lines instead of failing on fields[7]
            fields = line.strip().split('\t')
            # NOTE(review): the strand column (fields[4]) is ignored. For '-'
            # records the query coordinates are presumably forward-strand, so
            # the segment may need query end->start ordering. Confirm.
            ref_start = int(fields[7])  # reference start
            ref_end = int(fields[8])    # reference end
            query_start = int(fields[2])  # query start
            query_end = int(fields[3])    # query end

            ref_positions.append(ref_start / 1e6)  # Scale to megabases
            query_positions.append(query_start / 1e6)
            ref_positions.append(ref_end / 1e6)
            query_positions.append(query_end / 1e6)

    return ref_positions, query_positions


def visualize_alignments(file1, file2, label, fig_type):
    """Plot two alignments on shared axes and save the figure as a PDF.

    Each input may be a SAM file (``.sam``), a PAF file (``.paf``), or a text
    file holding a single CIGAR string (any other extension). Positions are
    plotted in megabases: reference on x, query on y.

    Args:
        file1: Path to the baseline alignment. It is drawn as a solid blue
            line with legend label ``'RaMA'``.
        file2: Path to the alignment being compared. It is drawn as a dashed
            red line.
        label: Legend label for *file2*; also part of the output file name.
        fig_type: ``"cen"`` selects the centromere axis labels; any other
            value selects the template labels. Also part of the file name.

    Side effects:
        Writes ``./plot/RaMA_<label>_<fig_type>.pdf``, creating ``./plot`` if
        needed. Then calls ``plt.show()`` and closes the figure.
    """
    import os  # only needed here, for creating the output directory

    fig, ax = plt.subplots(figsize=(10, 8))

    def process_cigar_operations(operations, ref_start=0, query_start=0):
        """Trace a CIGAR path as x (reference) and y (query) lists in Mb.

        A point is appended after every operation except hard clips. Only a
        soft clip seen before any aligned operation shifts the query start.
        A trailing clip would otherwise draw a spurious vertical tail.
        """
        cx, cy = ref_start, query_start  # Cumulative positions for the CIGAR
        x, y = [cx / 1e6], [cy / 1e6]
        seen_aligned = False
        for length, mode in operations:
            if mode in ('M', 'X', '='):
                cx += length
                cy += length
                seen_aligned = True
            elif mode == 'I':
                cy += length
                seen_aligned = True
            elif mode in ('D', 'N'):
                # Both consume the reference only ('N' = skipped reference region).
                cx += length
                seen_aligned = True
            elif mode == 'S':
                if not seen_aligned:
                    cy += length  # leading clip: the query alignment starts later
            elif mode == 'H':
                continue  # Ignore hard clipping for visualization
            x.append(cx / 1e6)  # Scale by 1e6 for Megabases
            y.append(cy / 1e6)
        return x, y

    def load_paf_segments(file_path):
        """Return one ``([ref_start, ref_end], [query_start, query_end])`` pair (Mb) per PAF record."""
        segments = []
        with open(file_path, 'r') as file:
            for line in file:
                if not line.strip():
                    continue  # skip blank lines instead of failing on fields[7]
                fields = line.strip().split('\t')
                # 0-based PAF fields: 2/3 query start/end, 7/8 target start/end.
                # NOTE(review): strand (fields[4]) is ignored. Reverse-strand hits
                # may need query end->start ordering. Confirm.
                ref_positions = [int(fields[7]) / 1e6, int(fields[8]) / 1e6]
                query_positions = [int(fields[2]) / 1e6, int(fields[3]) / 1e6]
                segments.append((ref_positions, query_positions))
        return segments

    def visualize_file(file_path, label, color, style='-', show_label=True):
        """Draw every alignment in *file_path* on ``ax``.

        Only the first plotted line carries *label* (when *show_label* is
        true), so the legend has a single entry per file.
        """
        line_style = {'linewidth': 3, 'color': color, 'linestyle': style}
        if file_path.endswith('.sam'):
            # The context manager guarantees the SAM handle is closed.
            with pysam.AlignmentFile(file_path, "r") as samfile:
                for read in samfile.fetch():
                    if read.is_secondary or read.is_supplementary:
                        continue  # Skip secondary and supplementary alignments
                    cigar_tuples = read.cigartuples
                    if not cigar_tuples:
                        continue
                    operations = parse_cigar(cigar_tuples)
                    x, y = process_cigar_operations(operations, read.reference_start)
                    ax.plot(x, y, label=label if show_label else "", **line_style)
                    show_label = False  # one legend entry per file, not per read
        elif file_path.endswith('.paf'):
            for ref_positions, query_positions in load_paf_segments(file_path):
                ax.plot(ref_positions, query_positions, label=label if show_label else "", **line_style)
                show_label = False  # Only show label for the first segment
        else:
            cigar_str = read_cigar_file(file_path)
            operations = parse_cigar_string(cigar_str)
            x, y = process_cigar_operations(operations)
            ax.plot(x, y, label=label if show_label else "", **line_style)

    visualize_file(file1, 'RaMA', 'blue', '-', show_label=True)
    visualize_file(file2, label, 'red', '--', show_label=True)

    label_font_size = 20
    tick_font_size = 18
    if fig_type == "cen":
        ax.set_xlabel('Centromere of CHM13 for chrX (MB)', fontsize=label_font_size)
        ax.set_ylabel('Centromere of HG002 for chrX (MB)', fontsize=label_font_size)
    else:
        ax.set_xlabel('Template without Region1 (MB)', fontsize=label_font_size)
        ax.set_ylabel('Template without Region2 (MB)', fontsize=label_font_size)
    ax.tick_params(axis='both', which='major', labelsize=tick_font_size)
    ax.tick_params(axis='both', which='minor', labelsize=tick_font_size)
    ax.legend(loc='best', fontsize=20)
    ax.grid(True)
    # savefig fails if the target directory is missing, so create it first.
    os.makedirs('./plot', exist_ok=True)
    plt.savefig('./plot/RaMA_'+label+'_'+fig_type+'.pdf')
    plt.show()
    plt.close(fig)  # release the figure so repeated calls don't accumulate open figures



    # ax.set_xlabel('Position in Reference (Megabases)', fontsize=16)
    # ax.set_ylabel('Position in Query (Megabases)', fontsize=16)
    # ax.legend(loc='best', fontsize=12)
    # ax.grid(True)
    # plt.title('Alignment Visualization')
    # plt.xlim(left=0)  # Adjust x-axis range if necessary
    # plt.ylim(bottom=0)  # Adjust y-axis range if necessary
    # plt.savefig('alignments.png', dpi=300)
    # plt.show()


# Alignment outputs per tool. "cen_" paths belong to the chrX centromere run
# and "tem_" paths to the template run. Most tools produce a plain CIGAR text
# file; minimap2 produces SAM and wfmash produces PAF.
cen_RaMA = './result/cen_X/RaMA/cigar.txt'
cen_UniAligner = './result/cen_X/UniAligner/cigar.txt'
cen_WFA2 = './result/cen_X/WFA2/cigar.txt'
cen_WFA10 = './result/cen_X/WFA10/cigar.txt'
cen_WFA50 = './result/cen_X/WFA50/cigar.txt'
cen_minimap2 = './result/cen_X/minimap2/output.sam'
cen_wfmash = './result/cen_X/wfmash/result.paf'

tem_RaMA = './result/template/RaMA/cigar.txt'
tem_UniAligner = './result/template/UniAligner/cigar.txt'
tem_WFA2 = './result/template/WFA2/cigar.txt'
tem_WFA10 = './result/template/WFA10/cigar.txt'
tem_WFA50 = './result/template/WFA50/cigar.txt'
tem_minimap2 = './result/template/minimap2/output.sam'
tem_wfmash = './result/template/wfmash/result.paf'

# Comparisons rendered in order, each as (baseline, other, legend label, figure type).
# Commented-out entries are alternatives; uncomment one to render it too.
_COMPARISONS = [
    # (cen_RaMA, cen_UniAligner, 'UniAligner', "cen"),
    # (tem_RaMA, tem_UniAligner, 'UniAligner', "tem"),
    (cen_RaMA, cen_WFA2, 'WFA(2)', "cen"),
    (tem_RaMA, tem_WFA2, 'WFA(2)', "tem"),
    # (cen_RaMA, cen_WFA10, 'WFA(10)', "cen"),
    # (tem_RaMA, tem_WFA10, 'WFA(10)', "tem"),
    # (cen_RaMA, cen_WFA50, 'WFA(50)', "cen"),
    # (tem_RaMA, tem_WFA50, 'WFA(50)', "tem"),
    # (cen_RaMA, cen_minimap2, 'minimap2', "cen"),
    # (tem_RaMA, tem_minimap2, 'minimap2', "tem"),
    # (cen_RaMA, cen_wfmash, 'wfmash', "cen"),
    # (tem_RaMA, tem_wfmash, 'wfmash', "tem"),
]

for _baseline, _other, _legend_label, _figure_kind in _COMPARISONS:
    visualize_alignments(_baseline, _other, _legend_label, _figure_kind)