from Bio import SeqIO
import os
import pandas as pd
import re

# Input data locations. All paths are relative, so they resolve against the
# current working directory of the process running this script.
# The two directories below are combined with "chr{N}_cen.fasta" file names
# in the per-chromosome loop later in this file.
chm13_path = "../../data/human_genome/chm13/chm13_cen/"
chm1_path = "../../data/human_genome/chm1/chm1_cen/"
# Default UniAligner result files. Both names are reassigned to
# chromosome-specific paths inside the per-chromosome loop later in this file.
cigar_file_path = "./result/UniAligner/cigar.txt"
tsv_file_path = "./result/UniAligner/shortest_matches.tsv"

# Column-oriented accumulator for the summary statistics: every key maps to
# its own list, and segment_and_cigarize_by_rare_match appends one value per
# key on each call. Key order is the column order of the final CSV.
# NOTE: the same name is assigned again, with identical keys, further down
# in this file.
total_stats = {
    column: []
    for column in (
        "chromosome",
        "total length",
        "reliable length",
        "rare match length",
        "reliable ratio (%)",
        "rare match ratio (%)",
        "large indel",
        "large indel count",
        "reliable large indel",
        "reliable large indel count",
        "large indel ratio (%)",
        "reliable large indel ratio (%)",
    )
}

# Read the original sequence from a FASTA file.
def load_fasta_sequence(fasta_path):
    """Return the sequence of the first record in a FASTA file as a string.

    Only the first record is consumed; any later records in the file are
    ignored. Returns None when the file contains no records.
    """
    with open(fasta_path, "r") as handle:
        first_record = next(iter(SeqIO.parse(handle, "fasta")), None)
        if first_record is None:
            return None
        return str(first_record.seq)
        
# Parse a CIGAR string and build the gapped (aligned) sequences.
def align_sequences_using_cigar(ref_seq, query_seq, cigar_str):
    """Expand a CIGAR string into a pair of equal-length gapped sequences.

    Operations are handled according to which sequence they consume:

    * ``M``, ``=``, ``X`` -- consume both; bases are copied from both.
    * ``I`` -- consumes the query only; the reference gets ``-`` padding.
    * ``D``, ``N`` -- consume the reference only; the query gets ``-``
      padding (a reference skip is rendered like a deletion).
    * ``S`` -- consumes the query only; the clipped bases are not part of
      the alignment, so nothing is emitted but the query position advances.
    * ``H``, ``P`` -- consume neither sequence and emit nothing.

    Args:
        ref_seq: Ungapped reference sequence.
        query_seq: Ungapped query sequence.
        cigar_str: CIGAR string; text that does not match
            ``<digits><op>`` is ignored.

    Returns:
        A ``(ref_aligned, query_aligned)`` tuple of strings.
    """
    ref_aligned = []
    query_aligned = []
    ref_pos, query_pos = 0, 0

    # finditer avoids materialising every (length, op) pair up front, which
    # matters for very long CIGAR strings.
    for token in re.finditer(r'(\d+)([MIDNSHP=X])', cigar_str):
        length = int(token.group(1))
        op = token.group(2)
        if op in ("M", "=", "X"):  # match or mismatch
            ref_aligned.append(ref_seq[ref_pos:ref_pos + length])
            query_aligned.append(query_seq[query_pos:query_pos + length])
            ref_pos += length
            query_pos += length
        elif op == "I":  # insertion relative to the reference
            ref_aligned.append("-" * length)
            query_aligned.append(query_seq[query_pos:query_pos + length])
            query_pos += length
        elif op in ("D", "N"):  # deletion / skipped reference region
            ref_aligned.append(ref_seq[ref_pos:ref_pos + length])
            query_aligned.append("-" * length)
            ref_pos += length
        elif op == "S":  # soft clip
            # Without this the query would fall out of register with every
            # operation that follows the clip.
            query_pos += length
        # "H" and "P" consume neither sequence: nothing to do.

    return "".join(ref_aligned), "".join(query_aligned)

# Build the CIGAR string for an interval, merging runs of the same operation.
def generate_cigar_for_interval(ref_segment, query_segment):
    """Encode a gapped alignment interval as an M/I/D CIGAR string.

    The two segments are walked column by column (stopping at the shorter
    one). A column with ``-`` in the reference is an insertion ("I"), a
    column with ``-`` only in the query is a deletion ("D"), and every other
    column -- identical or substituted bases alike -- is "M". Consecutive
    columns with the same operation are merged into one ``<count><op>`` item.
    Returns an empty string when there are no columns.

    NOTE: a later definition of the same name in this file replaces this
    one at import time.
    """
    pieces = []
    current_op = None
    run_length = 0

    for ref_base, query_base in zip(ref_segment, query_segment):
        if ref_base == "-":
            column_op = "I"
        elif query_base == "-":
            column_op = "D"
        else:
            # Matches and substitutions are both reported as "M".
            column_op = "M"

        if column_op == current_op:
            run_length += 1
            continue

        # Operation changed: emit the finished run and start a new one.
        if current_op is not None:
            pieces.append(f"{run_length}{current_op}")
        current_op, run_length = column_op, 1

    if current_op is not None:
        pieces.append(f"{run_length}{current_op}")

    return "".join(pieces)


# Column-oriented accumulator for the summary statistics (re-initialised
# here with the same keys as the earlier definition in this file). Every key
# maps to its own list; segment_and_cigarize_by_rare_match appends one value
# per key on each call, and the lists become the columns of the final CSV.
total_stats = {
    column: []
    for column in (
        "chromosome",
        "total length",
        "reliable length",
        "rare match length",
        "reliable ratio (%)",
        "rare match ratio (%)",
        "large indel",
        "large indel count",
        "reliable large indel",
        "reliable large indel count",
        "large indel ratio (%)",
        "reliable large indel ratio (%)",
    )
}

def generate_cigar_for_interval(ref_segment, query_segment):
    """Encode a gapped alignment interval as an extended (=/X/I/D) CIGAR.

    The two segments are walked column by column (stopping at the shorter
    one) and each column is classified as:

    * "I" -- the reference holds ``-`` (this includes gap-vs-gap columns),
    * "D" -- only the query holds ``-``,
    * "=" -- both hold the same base,
    * "X" -- both hold a base and the bases differ.

    Consecutive columns with the same operation are merged into a single
    ``<count><op>`` item. Returns an empty string when there are no columns.
    """
    def classify(ref_base, query_base):
        # Gap checks come first so a "-" is never reported as a match.
        if ref_base == "-":
            return "I"
        if query_base == "-":
            return "D"
        return "=" if ref_base == query_base else "X"

    runs = []  # mutable [op, count] pairs, in alignment order
    for ref_base, query_base in zip(ref_segment, query_segment):
        symbol = classify(ref_base, query_base)
        if runs and runs[-1][0] == symbol:
            runs[-1][1] += 1
        else:
            runs.append([symbol, 1])

    return "".join(f"{count}{symbol}" for symbol, count in runs)

# Cut the alignment at rare matches and collect summary statistics.
def segment_and_cigarize_by_rare_match(ref_aligned, query_aligned, rare_match_positions, chromosome):
    """Split an alignment at rare matches and record per-chromosome stats.

    For every rare match, the interval running from the end of the previous
    rare match to the end of the current one is cut out of both aligned
    sequences and encoded with ``generate_cigar_for_interval``. An interval
    is "reliable" when its CIGAR is a single operation other than ``X``.

    Args:
        ref_aligned: Aligned reference sequence.
        query_aligned: Aligned query sequence.
        rare_match_positions: Iterable of ``(ref_start, query_start, length)``
            tuples, processed in the order given.
        chromosome: Label stored in the "chromosome" column.

    Returns:
        A list of ``{"interval": ((ref_from, ref_to), (query_from, query_to)),
        "cigar": str}`` dicts, one per interval, including the tail after the
        last rare match when both sequences extend past it.

    Side effects:
        Appends one value to every list of the module-level ``total_stats``
        and prints one boolean per rare match (see below). The tail interval
        is recorded in the return value but not included in the statistics.

    NOTE(review): the start positions are used directly as indices into the
    aligned (gapped) strings, with separate offsets for reference and query.
    That is only column-consistent if the positions are alignment columns --
    confirm against the format of the rare-match table.
    """
    total_length = len(ref_aligned)
    reliable_length = 0
    rare_match_length = sum(length for _, _, length in rare_match_positions)
    large_indel = 0
    large_indel_count = 0
    reliable_large_indel = 0
    reliable_large_indel_count = 0

    segment_cigars = []
    prev_ref_end = 0
    prev_query_end = 0

    for ref_start, query_start, length in rare_match_positions:
        ref_end = ref_start + length
        query_end = query_start + length

        # Diagnostic kept from the original script: reports whether the two
        # sequences really are identical over the rare match itself.
        print(ref_aligned[ref_start:ref_end] == query_aligned[query_start:query_end])

        # Interval from the end of the previous rare match to the end of
        # this one.
        ref_segment = ref_aligned[prev_ref_end:ref_end]
        query_segment = query_aligned[prev_query_end:query_end]

        cigar = generate_cigar_for_interval(ref_segment, query_segment)
        segment_cigars.append({
            "interval": ((prev_ref_end, ref_end), (prev_query_end, query_end)),
            "cigar": cigar
        })

        # Work on parsed operations: the CIGAR is a string, so len(cigar)
        # counts characters, not operations.
        operations = [
            (int(op_len), op)
            for op_len, op in re.findall(r'(\d+)([MIDNSHP=X])', cigar)
        ]
        indel_lengths = [op_len for op_len, op in operations if op in ("I", "D")]
        large_indel_count += len(indel_lengths)
        large_indel += sum(indel_lengths)

        # Reliable interval: exactly one operation, and not a mismatch run.
        if len(operations) == 1 and operations[0][1] != "X":
            reliable_length += len(ref_segment)
            reliable_large_indel += sum(indel_lengths)
            # Count indel operations only, so a pure-match interval does not
            # inflate the indel count.
            reliable_large_indel_count += len(indel_lengths)

        prev_ref_end = ref_end
        prev_query_end = query_end

    # Interval from the last rare match to the end of the sequences.
    if prev_ref_end < len(ref_aligned) and prev_query_end < len(query_aligned):
        ref_segment = ref_aligned[prev_ref_end:]
        query_segment = query_aligned[prev_query_end:]
        cigar = generate_cigar_for_interval(ref_segment, query_segment)
        segment_cigars.append({
            "interval": ((prev_ref_end, len(ref_aligned)), (prev_query_end, len(query_aligned))),
            "cigar": cigar
        })

    def percent(part):
        # Guard against an empty alignment to avoid dividing by zero.
        return (part / total_length) * 100 if total_length > 0 else 0

    # Store the results, one value per column.
    total_stats["chromosome"].append(chromosome)
    total_stats["total length"].append(total_length)
    total_stats["reliable length"].append(reliable_length)
    total_stats["rare match length"].append(rare_match_length)
    total_stats["reliable ratio (%)"].append(percent(reliable_length))
    total_stats["rare match ratio (%)"].append(percent(rare_match_length))
    total_stats["large indel"].append(large_indel)
    total_stats["large indel count"].append(large_indel_count)
    total_stats["reliable large indel"].append(reliable_large_indel)
    total_stats["reliable large indel count"].append(reliable_large_indel_count)
    total_stats["large indel ratio (%)"].append(percent(large_indel))
    total_stats["reliable large indel ratio (%)"].append(percent(reliable_large_indel))

    return segment_cigars


# Chromosomes to process: autosomes 1-22 plus X.
chrs = list(range(1, 23)) + ['X']

# "chrom" rather than "chr", so the builtin chr() is not shadowed.
for chrom in chrs:

    # Load the original CHM13 and CHM1 sequences for this chromosome.
    chm13_seq = load_fasta_sequence(chm13_path + f"chr{chrom}_cen.fasta")
    chm1_seq = load_fasta_sequence(chm1_path + f"chr{chrom}_cen.fasta")
    cigar_file_path = f"./result/UniAligner/chr{chrom}/cigar.txt"
    tsv_file_path = f"./result/UniAligner/chr{chrom}/shortest_matches.tsv"

    # Both result files are required. Without this guard, a missing CIGAR
    # file would leave ref_aligned/query_aligned undefined (first iteration)
    # or still holding the previous chromosome's alignment, and the
    # statistics below would silently be computed on the wrong data.
    missing_paths = [
        path for path in (cigar_file_path, tsv_file_path)
        if not os.path.exists(path)
    ]
    if missing_paths:
        print(f"Skipping chr{chrom}: missing {', '.join(missing_paths)}")
        continue

    # Read the CIGAR file (assumed to hold a single CIGAR string) and expand
    # it into the two gapped sequences.
    with open(cigar_file_path, 'r') as cigar_file:
        cigar_str = cigar_file.read().strip()
    ref_aligned, query_aligned = align_sequences_using_cigar(chm13_seq, chm1_seq, cigar_str)
    print(len(ref_aligned), len(query_aligned))

    # Read the TSV of rare matches, ordered by start position in the first
    # sequence.
    df_tsv = pd.read_csv(tsv_file_path, sep='\t')
    df_tsv = df_tsv.sort_values(by="FstStarts")
    # Iterate column-wise: row-wise iteration would coerce each row to one
    # common dtype, which can turn integer positions into floats that cannot
    # be used as slice indices.
    rare_match_positions = list(
        zip(df_tsv["FstStarts"], df_tsv["SndStarts"], df_tsv["Length"])
    )
    # Cut the aligned sequences at the rare matches and record the stats.
    segment_cigars = segment_and_cigarize_by_rare_match(ref_aligned, query_aligned, rare_match_positions, chrom)

# Convert total_stats into a DataFrame (one row per processed chromosome).
df_stats = pd.DataFrame(total_stats)

# Output path. The file-name spelling is kept as-is so existing consumers of
# this CSV keep working.
output_csv_path = "./csv/UniAligner_rare_match_reliablity.csv"

# Create the output directory if it does not exist yet.
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)

# Save as CSV.
df_stats.to_csv(output_csv_path, index=False)
