import csv
import pysam
import matplotlib.pyplot as plt
import re
# 解析CIGAR字符串
def parse_cigar_string(cigar_str):
    """Split a CIGAR string into ``(length, op)`` tuples.

    Only well-formed ``<digits><op>`` tokens whose op is one of
    ``MIDNSHP=X`` are returned, in order; any other characters are ignored.

    Args:
        cigar_str: CIGAR text such as ``"10M2I5D"``.

    Returns:
        list of ``(int length, str op)`` tuples.
    """
    token_re = re.compile(r'(\d+)([MIDNSHP=X])')
    return [(int(count), code) for count, code in token_re.findall(cigar_str)]
def read_cigar_file(file_path):
    """Return the full text of *file_path* with surrounding whitespace removed.

    The file is expected to contain a single CIGAR string (e.g. ``10M2I5D``).
    """
    with open(file_path, 'r') as handle:
        contents = handle.read()
    return contents.strip()
def get_alignment_end_position(file_path):
    """Return where an alignment ends, choosing the parser by file extension.

    * ``.sam``: reads the file with pysam and returns the largest
      ``reference_end`` among primary (non-secondary, non-supplementary)
      records that have one. Unmapped reads report ``None`` and are skipped.
      Returns an ``int``.
    * ``.paf``: returns the largest value in column 9 (0-based index 8, the
      target end) over all non-blank lines. Returns an ``int``.
    * anything else: treats the file as one CIGAR string and returns a
      ``(ref_pos, query_pos)`` tuple. ``ref_pos`` is the number of reference
      bases consumed (M/D/N/=/X); ``query_pos`` is the number of query bases
      consumed (M/I/=/X; soft clips are not counted).

    NOTE: the SAM/PAF branches return a single int while the CIGAR branch
    returns a tuple, so callers must handle both shapes.

    Raises:
        ValueError: if a SAM/PAF file contains no usable alignment records.
    """
    if file_path.endswith('.sam'):
        end_positions = []
        # The context manager guarantees the SAM handle is closed.
        with pysam.AlignmentFile(file_path, "r") as samfile:
            for read in samfile.fetch():
                if read.is_secondary or read.is_supplementary:
                    continue
                end = read.reference_end
                # Unmapped records have no reference_end; max() cannot
                # compare None with ints.
                if end is not None:
                    end_positions.append(end)
        if not end_positions:
            raise ValueError(f"no primary mapped alignments found in {file_path}")
        return max(end_positions)

    if file_path.endswith('.paf'):
        end_positions = []
        with open(file_path, 'r') as file:
            for line in file:
                stripped = line.strip()
                if not stripped:
                    continue  # tolerate blank lines, e.g. a trailing newline
                fields = stripped.split('\t')
                end_positions.append(int(fields[8]))  # target (reference) end
        if not end_positions:
            raise ValueError(f"no alignment records found in {file_path}")
        return max(end_positions)

    # Plain CIGAR file: walk the operations and count consumed bases.
    operations = parse_cigar_string(read_cigar_file(file_path))
    ref_consuming = ('M', 'D', 'N', '=', 'X')  # N (ref skip) consumes reference
    query_consuming = ('M', 'I', '=', 'X')
    ref_pos = sum(length for length, op in operations if op in ref_consuming)
    query_pos = sum(length for length, op in operations if op in query_consuming)
    return ref_pos, query_pos
# 读取csv文件，获取每段的位置和长度
def read_position_csv(csv_file):
    """Load per-segment coordinates from a CSV file with a header row.

    Required columns: ``Segment``, ``Start_Seq1``, ``Length_Seq1``,
    ``Start_Seq2`` and ``Length_Seq2``. Extra columns are ignored.

    Args:
        csv_file: path to the CSV file.

    Returns:
        list of dicts, one per data row. ``Segment`` is kept as a string and
        the four coordinate columns are converted to ``int``.

    Raises:
        KeyError: if a required column is missing from the header.
        ValueError: if a coordinate value is not an integer literal.
    """
    int_columns = ('Start_Seq1', 'Length_Seq1', 'Start_Seq2', 'Length_Seq2')
    # The csv module documentation requires newline='' so that quoted fields
    # containing line breaks and '\r\n' line endings are parsed correctly.
    with open(csv_file, mode='r', newline='') as file:
        reader = csv.DictReader(file)
        return [
            {'Segment': row['Segment'], **{col: int(row[col]) for col in int_columns}}
            for row in reader
        ]

# Convert pysam integer CIGAR tuples into (length, op) pairs
def parse_cigar(cigar_tuples):
    """Convert pysam-style ``(op_code, length)`` pairs into ``(length, op_char)`` tuples.

    Integer op codes 0-8 map to the CIGAR letters ``M I D N S H P = X``.
    This includes N (3) and P (6), so the output uses the same alphabet as
    ``parse_cigar_string``. Codes outside 0-8 are skipped.

    Args:
        cigar_tuples: iterable of ``(op_code, length)`` pairs, e.g.
            ``AlignedSegment.cigartuples`` from pysam.

    Returns:
        list of ``(length, op_char)`` tuples in input order.
    """
    op_chars = 'MIDNSHP=X'  # position in this string == integer op code
    return [
        (length, op_chars[op])
        for op, length in cigar_tuples
        if 0 <= op < len(op_chars)
    ]

# Convert CIGAR operations into plot coordinates (megabases)
def process_cigar_operations(operations, ref_start=0, query_start=0):
    """Turn CIGAR operations into a path of (reference, query) coordinates.

    The path starts at ``(ref_start, query_start)``. Each operation advances
    the cursors as follows:

    * ``M``, ``=``, ``X``: both reference and query (diagonal step).
    * ``D``, ``N``: reference only (horizontal step).
    * ``I``: query only (vertical step).
    * anything else (``S``, ``H``, ``P``, ...): no movement, but a point is
      still emitted.

    Args:
        operations: iterable of ``(length, op_char)`` tuples.
        ref_start: starting reference coordinate in bases.
        query_start: starting query coordinate in bases.

    Returns:
        ``(x, y)``: two lists of length ``len(operations) + 1`` holding
        coordinates in megabases (positions divided by 1e6).
    """
    scale = 1e6  # bases -> megabases
    cx, cy = ref_start, query_start
    x, y = [cx / scale], [cy / scale]

    for length, mode in operations:
        if mode in ('M', 'X', '='):
            cx += length
            cy += length
        elif mode == 'I':
            cy += length
        elif mode in ('D', 'N'):  # N skips reference bases exactly like D
            cx += length
        x.append(cx / scale)
        y.append(cy / scale)
    return x, y

# 可视化比对路径
def visualize_alignment_with_segments(file1, file2, positions, label):
    """Plot two alignment paths on one figure together with segment boundaries.

    *file1* is drawn as "RaMA" (solid blue) and *file2* as "UniAligner"
    (dashed red). Each file may be a ``.sam`` file, in which case every
    primary alignment is drawn from its ``reference_start``. Otherwise it is
    a text file holding one CIGAR string, drawn from the origin.

    Grey dashed lines mark:
    * each segment's ``Start_Seq1`` (vertical) and ``Start_Seq2`` (horizontal);
    * the end of *file1*'s alignment.

    The figure is saved to ``./plot/RaMA_UniAligner_with_segments2.pdf``
    (the directory is created if missing) and then shown.

    Args:
        file1: RaMA alignment (``.sam`` or CIGAR text file).
        file2: UniAligner alignment (``.sam`` or CIGAR text file).
        positions: iterable of dicts with integer ``Start_Seq1`` and
            ``Start_Seq2`` keys (as produced by ``read_position_csv``).
        label: currently unused; kept for interface compatibility.
    """
    import os

    _, ax = plt.subplots(figsize=(10, 8))

    def visualize_file(file_path, label, color, style='-', show_label=True):
        """Draw the alignment path(s) stored in *file_path* onto ``ax``."""
        line_label = label if show_label else ""
        if file_path.endswith('.sam'):
            # The context manager closes the SAM handle after drawing.
            with pysam.AlignmentFile(file_path, "r") as samfile:
                for read in samfile.fetch():
                    if read.is_secondary or read.is_supplementary:
                        continue  # Skip secondary and supplementary alignments
                    cigar_tuples = read.cigartuples
                    if not cigar_tuples:
                        continue
                    operations = parse_cigar(cigar_tuples)
                    x, y = process_cigar_operations(operations, read.reference_start)
                    ax.plot(x, y, label=line_label, linewidth=3, color=color, linestyle=style)
                    # Label only the first line so the legend has one entry per file.
                    line_label = ""
        else:
            operations = parse_cigar_string(read_cigar_file(file_path))
            x, y = process_cigar_operations(operations)
            ax.plot(x, y, label=line_label, linewidth=3, color=color, linestyle=style)

    visualize_file(file1, 'RaMA', 'blue', '-', show_label=True)
    visualize_file(file2, 'UniAligner', 'red', '--', show_label=True)

    # Segment boundaries (converted to megabases, like the alignment paths).
    for segment in positions:
        ax.axvline(x=segment['Start_Seq1'] / 1e6, color='gray', linestyle='--')
        ax.axhline(y=segment['Start_Seq2'] / 1e6, color='gray', linestyle='--')
    ax.grid(False)

    # End of file1's alignment. Use the parameter, not a module-level global,
    # so the function works for any caller.
    end = get_alignment_end_position(file1)
    if isinstance(end, tuple):
        ref_end, query_end = end
    else:
        # SAM/PAF inputs yield only a reference end position.
        ref_end, query_end = end, None
    ax.axvline(x=ref_end / 1e6, color='gray', linestyle='--')  # reference end of RaMA
    if query_end is not None:
        ax.axhline(y=query_end / 1e6, color='gray', linestyle='--')  # query end of RaMA

    ax.set_xlabel('Reference Position (MB)', fontsize=16)
    ax.set_ylabel('Query Position (MB)', fontsize=16)
    ax.legend(loc='best', fontsize=14)
    ax.set_title("Comparison of RaMA and UniAligner on Mixed Sequences", fontsize=18)

    out_path = './plot/RaMA_UniAligner_with_segments2.pdf'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path, dpi=300)
    plt.show()

if __name__ == "__main__":
    # Segment coordinates used to draw the grey boundary lines.
    csv_file_path = './csv/position.csv'
    positions = read_position_csv(csv_file_path)

    # CIGAR files produced by RaMA and UniAligner.
    rama_cigar = './result/RaMA2/cigar.txt'
    unialigner_cigar = './result/UniAligner2/cigar.txt'

    # Plot both alignments together with the segment boundaries.
    visualize_alignment_with_segments(rama_cigar, unialigner_cigar, positions, 'RaMA_vs_UniAligner')

