import os
import numpy as np
import csv
# Chromosome identifiers to process
chromosomes = [*map(str, range(1, 23)), 'X']

# Output CSV files, one per (aligner, reference genome) combination
RaMA_csv_file_ref_chm13 = './csv/CHM13_ref_RaMA_statistics.csv'
RaMA_csv_file_ref_chm1 = './csv/CHM1_ref_RaMA_statistics.csv'
UniAligner_csv_ref_chm13 = './csv/CHM13_ref_UniAligner_statistics.csv'
UniAligner_csv_ref_chm1 = './csv/CHM1_ref_UniAligner_statistics.csv'

# Input locations: per-genome FASTA directories and per-aligner CIGAR results
chm13_path = "../../data/human_genome/chm13/chm13_cen/"
chm1_path = "../../data/human_genome/chm1/chm1_cen/"
cigar_path_RaMA = "./result/RaMA/"
cigar_path_UniAligner = "./result/UniAligner/"

# Minimal FASTA reader
def read_fasta(filepath):
    """Return the sequence stored in a FASTA file as one string.

    Every line starting with '>' is treated as a description line and
    dropped; the remaining lines are stripped of surrounding whitespace and
    concatenated in file order (records are not kept separate).

    Args:
        filepath: Path of the FASTA file to read.

    Returns:
        The concatenated sequence text.
    """
    pieces = []
    with open(filepath, 'r') as handle:
        for raw_line in handle:
            if raw_line.startswith('>'):
                # Description line: not part of the sequence
                continue
            pieces.append(raw_line.strip())
    return ''.join(pieces)

# Expand a CIGAR string into a pair of gapped (aligned) sequences
def parse_cigar(cigar, ref_sequence, query_sequence):
    """Expand a CIGAR string into gapped reference and query sequences.

    Operations follow the SAM specification:
      * M, =, X  consume both sequences and emit an aligned column block.
      * I        consumes the query; the reference row receives '-' gaps.
      * D, N     consume the reference; the query row receives '-' gaps.
      * S        consumes (skips) query bases without emitting columns.
      * H, P     consume nothing and emit nothing.

    Args:
        cigar: CIGAR string such as "10M2I5D"; every operation must be
            preceded by a decimal length.
        ref_sequence: Ungapped reference sequence.
        query_sequence: Ungapped query sequence.

    Returns:
        A tuple (aligned_ref, aligned_query) of gapped strings. An empty
        CIGAR yields ('', '').

    Raises:
        ValueError: If an operation has no length, the string ends with a
            length that has no operation, or the operation is unknown.
    """
    aligned_ref = []
    aligned_query = []

    ref_idx = 0
    query_idx = 0

    pos = 0
    end = len(cigar)
    while pos < end:
        # Scan the run of digits that forms the operation length. The bounds
        # check keeps a trailing number from running off the end of the string.
        start = pos
        while pos < end and cigar[pos].isdigit():
            pos += 1
        if pos == start:
            raise ValueError(
                f"CIGAR operation {cigar[pos]!r} at offset {pos} has no length")
        if pos == end:
            raise ValueError("CIGAR string ends with a length but no operation")

        length = int(cigar[start:pos])
        op = cigar[pos]
        pos += 1

        if op in ('M', 'X', '='):  # aligned columns (match or mismatch)
            aligned_ref.append(ref_sequence[ref_idx:ref_idx + length])
            aligned_query.append(query_sequence[query_idx:query_idx + length])
            ref_idx += length
            query_idx += length
        elif op == 'I':  # insertion: bases present only in the query
            aligned_ref.append('-' * length)
            aligned_query.append(query_sequence[query_idx:query_idx + length])
            query_idx += length
        elif op in ('D', 'N'):  # deletion / skipped region: reference only
            aligned_ref.append(ref_sequence[ref_idx:ref_idx + length])
            aligned_query.append('-' * length)
            ref_idx += length
        elif op == 'S':
            # Soft clip: the bases exist in the query but are not aligned, so
            # advance past them to keep the following columns in register.
            query_idx += length
        elif op in ('H', 'P'):
            # Hard clip / padding: consumes neither sequence.
            pass
        else:
            raise ValueError(f"Unsupported CIGAR operation {op!r}")

    return ''.join(aligned_ref), ''.join(aligned_query)

# Compute the SNV rate
def calculate_snv_rate(mismatches, ref_length):
    """Return mismatches per reference base.

    Args:
        mismatches: Number of mismatching aligned columns.
        ref_length: Length of the reference sequence.

    Returns:
        mismatches / ref_length, or 0 when ref_length is 0.
    """
    return mismatches / ref_length if ref_length != 0 else 0

# Count mismatching columns in an alignment
def calculate_mismatches(aligned_ref, aligned_query):
    """Count aligned columns whose two bases differ.

    Columns where either row holds a '-' gap are not counted. The rows are
    walked in lockstep, so comparison stops at the end of the shorter one.

    Args:
        aligned_ref: Gapped reference row.
        aligned_query: Gapped query row.

    Returns:
        The number of gap-free columns with differing characters.
    """
    return sum(
        1
        for ref_char, query_char in zip(aligned_ref, aligned_query)
        if ref_char != query_char and '-' not in (ref_char, query_char)
    )

def calculate_identity(aligned_ref, aligned_query):
    """Return the gap-compressed identity of an alignment.

    Each maximal run of consecutive gap columns (a '-' in either row) is
    counted once, regardless of its length. Identity is
    matches / (matches + mismatches + gap runs).

    Args:
        aligned_ref: Gapped reference row.
        aligned_query: Gapped query row.

    Returns:
        The identity as a fraction, or 0 when there are no columns.
    """
    match_count = 0
    mismatch_count = 0
    gap_runs = 0
    previous_was_gap = False

    for ref_char, query_char in zip(aligned_ref, aligned_query):
        is_gap = '-' in (ref_char, query_char)
        if is_gap:
            # Only the first column of a gap run contributes
            if not previous_was_gap:
                gap_runs += 1
        elif ref_char == query_char:
            match_count += 1
        else:
            mismatch_count += 1
        previous_was_gap = is_gap

    denominator = match_count + mismatch_count + gap_runs
    if denominator > 0:
        return match_count / denominator
    return 0


def calculate_identity_in_reliable_region(aligned_ref, aligned_query, query_length, window_size=10000,
                                          reliable_threshold=0.9):
    """Compute gap-compressed identity per window of reference bases.

    The alignment is cut into consecutive windows that each hold up to
    ``window_size`` reference bases. Columns where the reference row is a gap
    (insertions in the query) are skipped entirely and do not count toward the
    window. Inside a window, each run of query gaps is counted once, and the
    window identity is matches / (matches + mismatches + gap runs).

    A window is "reliable" when its identity is at least
    ``reliable_threshold``; only reliable windows contribute to the returned
    base and mismatch totals.

    Args:
        aligned_ref: Gapped reference row.
        aligned_query: Gapped query row (must cover every non-gap reference
            column).
        query_length: Unused; kept so existing callers keep working.
        window_size: Number of reference bases per window; must be positive.
        reliable_threshold: Minimum identity for a window to count as reliable.

    Returns:
        A tuple (identity_list, total_mismatch, reliable_ref_bases,
        reliable_query_bases): the identity of every window, then the
        mismatches, reference bases and aligned (non-gap) query bases summed
        over the reliable windows.

    Raises:
        ValueError: If window_size is not positive.
    """
    if window_size <= 0:
        raise ValueError("window_size must be a positive integer")

    identity_list = []
    reliable_ref_bases = 0
    reliable_query_bases = 0
    total_mismatch = 0

    column_count = len(aligned_ref)
    i = 0

    while i < column_count:
        window_ref_bases = 0
        matches = 0
        mismatches = 0
        compressed_gap_count = 0
        in_gap_region = False  # True while inside a run of query gaps

        # Consume columns until the window holds window_size reference bases
        while window_ref_bases < window_size and i < column_count:
            r_base = aligned_ref[i]
            if r_base == '-':
                # Query insertion: skipped, and it does not end a gap run
                i += 1
                continue
            q_base = aligned_query[i]
            i += 1
            window_ref_bases += 1

            if q_base == '-':
                if not in_gap_region:  # first column of a new gap run
                    compressed_gap_count += 1
                    in_gap_region = True
            else:
                in_gap_region = False
                if r_base == q_base:
                    matches += 1
                else:
                    mismatches += 1

        if window_ref_bases == 0:
            # Only reference-gap columns were left (alignment ends with an
            # insertion): there is no window to score, and dividing by the
            # zero column total would raise ZeroDivisionError.
            break

        total_bases = matches + mismatches + compressed_gap_count
        identity = matches / total_bases
        identity_list.append(identity)

        if identity >= reliable_threshold:
            reliable_ref_bases += window_ref_bases
            reliable_query_bases += matches + mismatches
            total_mismatch += mismatches

    return identity_list, total_mismatch, reliable_ref_bases, reliable_query_bases


def calculate_statistics(identity_list, reliable_ref_bases, total_ref_bases, reliable_query_bases, total_query_bases, reliable_total_mismatch):
    """Summarise per-window identities and reliable-region coverage.

    Windows whose identity is exactly 0 are excluded before the mean and
    standard deviation are taken.

    Args:
        identity_list: Identity value of every window.
        reliable_ref_bases: Reference bases inside reliable windows.
        total_ref_bases: Total reference length.
        reliable_query_bases: Aligned query bases inside reliable windows.
        total_query_bases: Total query length.
        reliable_total_mismatch: Mismatches inside reliable windows.

    Returns:
        A tuple (mean identity, standard deviation, reference reliable ratio,
        query reliable ratio, reliable SNV rate). The mean and standard
        deviation cover only the non-zero windows and are 0.0 when there are
        none; each ratio is 0 when its denominator is 0.
    """
    non_zero_identities = [x for x in identity_list if x > 0]

    if non_zero_identities:
        non_zero_avg_identity = np.mean(non_zero_identities)
        non_zero_std_identity = np.std(non_zero_identities)
    else:
        # np.mean / np.std of an empty list give nan plus a RuntimeWarning;
        # report 0.0 instead, like the zero-denominator guards below.
        non_zero_avg_identity = 0.0
        non_zero_std_identity = 0.0

    # Share of each sequence that lies inside reliable windows
    ref_reliable_ratio = reliable_ref_bases / total_ref_bases if total_ref_bases > 0 else 0
    query_reliable_ratio = reliable_query_bases / total_query_bases if total_query_bases > 0 else 0
    reliable_SNV = reliable_total_mismatch / reliable_ref_bases if reliable_ref_bases != 0 else 0

    return non_zero_avg_identity, non_zero_std_identity, ref_reliable_ratio, query_reliable_ratio, reliable_SNV


# Main program
# For every aligner and every choice of reference genome, compute
# per-chromosome alignment statistics and append them to the matching CSV.
for method in ['RaMA', 'UniAligner']:
    for ref_type in ['chm13', 'chm1']:  # two passes: CHM13 as reference, then CHM1
        print('*'* 100)
        if method == 'RaMA':
            if ref_type == 'chm13':
                output_csv_file = RaMA_csv_file_ref_chm13
            else:
                output_csv_file = RaMA_csv_file_ref_chm1
            cigar_path = cigar_path_RaMA
        else:
            if ref_type == 'chm13':
                output_csv_file = UniAligner_csv_ref_chm13
            else:
                output_csv_file = UniAligner_csv_ref_chm1
            cigar_path = cigar_path_UniAligner

        # Write the header only when the CSV does not exist yet, so repeated
        # runs keep appending to the same file.
        if not os.path.exists(output_csv_file):
            output_dir = os.path.dirname(output_csv_file)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            # mode='w' is required: open() defaults to read mode, which raises
            # FileNotFoundError for a file that does not exist.
            with open(output_csv_file, mode='w', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([
                    'chr', 'reference type', 'mean identity (%)','mean identity in reliable region (%)', 'standard deviation in reliable region', 
                    'aligned reference bases', 'ref aligned (%)', 'aligned query bases', 'query aligned (%)', 
                    'ref SNV (%)', 'ref SNV in reliable region (%)'
                ])

        for chrom in chromosomes:
            print('#' * 100)
            print(f"Processing with {ref_type} as reference")
            print(f"Processing chromosome {chrom}")

            # Input files: the reference/query roles swap with ref_type
            if ref_type == 'chm13':
                ref_file = os.path.join(chm13_path, f"chr{chrom}_cen.fasta")
                query_file = os.path.join(chm1_path, f"chr{chrom}_cen.fasta")
            else:
                ref_file = os.path.join(chm1_path, f"chr{chrom}_cen.fasta")
                query_file = os.path.join(chm13_path, f"chr{chrom}_cen.fasta")

            cigar_file = os.path.join(cigar_path, f"chr{chrom}/cigar.txt")

            # Load both ungapped sequences
            ref_sequence = read_fasta(ref_file)
            query_sequence = read_fasta(query_file)

            # Load the CIGAR alignment result
            with open(cigar_file, 'r') as f:
                cigar_string = f.read().strip()

            # Ungapped lengths, used as denominators for the rates below
            ref_length = len(ref_sequence)
            query_length = len(query_sequence)

            # Both branches hand the CHM13 sequence to parse_cigar as its
            # reference argument, so the first returned row is always CHM13.
            # When CHM1 is the reference for the statistics, the returned rows
            # are therefore unpacked in swapped order so that aligned_ref
            # really is the CHM1 row and matches ref_length.
            if ref_type == "chm13":
                aligned_ref, aligned_query = parse_cigar(cigar_string, ref_sequence, query_sequence)
            else:
                aligned_query, aligned_ref = parse_cigar(cigar_string, query_sequence, ref_sequence)

            mismatches = calculate_mismatches(aligned_ref, aligned_query)

            # SNV rate over the whole reference
            snv_rate = calculate_snv_rate(mismatches, ref_length)

            identity = calculate_identity(aligned_ref, aligned_query)

            # Per-window identity and reliable-region totals
            window_size = 10000
            identity_list, reliable_total_mismatch, reliable_ref_bases, reliable_query_bases = calculate_identity_in_reliable_region(aligned_ref, aligned_query, query_length, window_size)

            # Summary statistics
            avg_identity, std_identity, ref_reliable_ratio, query_reliable_ratio, reliable_SNV = calculate_statistics(
                identity_list, reliable_ref_bases, ref_length, reliable_query_bases, query_length, reliable_total_mismatch)

            # Report on stdout
            print(f"Chromosome {chrom} alignment completed.")
            print(f"Reference length: {ref_length}")
            print(f"Mismatches: {mismatches}")
            print(f"SNV rate %: {snv_rate* 100:.8f}")
            print(f"Average identity: {identity * 100:.8f}")
            print(f"Average identity in reliable region: {avg_identity * 100:.8f}")
            print(f"Standard deviation of identity: {std_identity:.8f}")
            print(f"Reliable reference bases: {reliable_ref_bases} / {ref_length} ({ref_reliable_ratio * 100:.2f}%)")
            print(f"Reliable query bases: {reliable_query_bases} / {query_length} ({query_reliable_ratio * 100:.2f}%)")
            print(f"Reliable SNV rate: {reliable_SNV* 100:.8f}")

            # Append this chromosome's row to the CSV
            with open(output_csv_file, mode='a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([
                    chrom, ref_type, f"{identity* 100:.8f}",
                    f"{avg_identity* 100:.8f}", f"{std_identity:.8f}", 
                    reliable_ref_bases, f"{ref_reliable_ratio * 100:.2f}", 
                    reliable_query_bases, f"{query_reliable_ratio * 100:.2f}",
                    f"{snv_rate* 100:.8f}", f"{reliable_SNV* 100:.8f}"
                ])


