import os
import re
import pandas as pd

# First decimal number on a line; the parser treats it as that line's timestamp.
_TIMESTAMP_RE = re.compile(r"\d+\.\d+")
# Memory figure such as "512.00 MB" or "1.50 GB" (value and unit letter captured).
_MEMORY_RE = re.compile(r"(\d+\.\d+) ([MG])B")
# Peak-memory figure reported after the "MaxMemory:" label.
_MAX_MEMORY_RE = re.compile(r"MaxMemory:\s+(\d+\.\d+) ([MG])B")

# (metric key prefix, start marker, end marker) for each timed stage, in the
# order the stages are checked on every line.
_STAGE_MARKERS = (
    (
        "suffix_array",
        "The suffix array is constructing...",
        "The suffix array construction is finished",
    ),
    (
        "sparse_table",
        "The sparse table is constructing...",
        "The sparse table construction is finished",
    ),
    (
        "anchor_search",
        "Begin to search anchors",
        "Finish searching anchors",
    ),
    (
        "alignment",
        "Begin to align intervals using wavefront alignment method",
        "Wavefront alignment of intervals has been completed",
    ),
)


def _first_number(line):
    """Return the first decimal number found in *line* as a float, or None."""
    match = _TIMESTAMP_RE.search(line)
    return float(match.group(0)) if match else None


def _memory_in_gb(match):
    """Convert a (value, unit letter) regex match to GB; None if there is no match."""
    if match is None:
        return None
    value = float(match.group(1))
    # "M" means the figure is in MB and is scaled by 1024; "G" is already GB.
    return value / 1024 if match.group(2) == "M" else value


def extract_metrics_from_log(log_file):
    """Parse one RaMA log file into a dict of per-stage time and memory metrics.

    For each stage in ``_STAGE_MARKERS`` the elapsed time is the first decimal
    number on the end-marker line minus the first decimal number on the
    start-marker line. The stage memory is the memory figure on the end-marker
    line (converted to GB) minus the figure recorded at the previous stage end.

    Lines that carry a marker but no parsable number, and end markers that
    appear before their start marker, are skipped; the affected metric keeps
    its default of 0 instead of raising.

    Args:
        log_file: Path of the log file to read.

    Returns:
        A dict with ``<stage>_time`` and ``<stage>_memory`` for the stages
        ``suffix_array``, ``sparse_table``, ``anchor_search``, ``alignment``
        and ``rare_alignment``, plus ``total_time`` and ``total_memory``.
        Memory values are in GB. Metrics not found in the log stay 0.

    Raises:
        OSError: If the log file cannot be opened.
    """
    metrics = {
        "suffix_array_time": 0,
        "suffix_array_memory": 0,
        "sparse_table_time": 0,
        "sparse_table_memory": 0,
        "anchor_search_time": 0,
        "anchor_search_memory": 0,
        "alignment_time": 0,
        "alignment_memory": 0,
        "rare_alignment_time": 0,
        "rare_alignment_memory": 0,
        "total_time": 0,
        "total_memory": 0,
    }

    start_times = {}
    previous_memory = 0  # memory figure recorded at the end of the previous stage

    # errors="replace" keeps a stray non-UTF-8 byte from aborting the parse.
    with open(log_file, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            for stage, start_marker, end_marker in _STAGE_MARKERS:
                if start_marker in line:
                    timestamp = _first_number(line)
                    if timestamp is not None:
                        start_times[stage] = timestamp
                if end_marker in line:
                    end_time = _first_number(line)
                    start_time = start_times.get(stage)
                    if end_time is not None and start_time is not None:
                        metrics[f"{stage}_time"] = end_time - start_time
                    memory = _memory_in_gb(_MEMORY_RE.search(line))
                    if memory is not None:
                        metrics[f"{stage}_memory"] = memory - previous_memory
                        previous_memory = memory

            # NOTE(review): "rare alignment" reuses the anchor-search markers.
            # Its "time" is the raw number on the start line (not a duration)
            # and its memory is the absolute figure on the finish line (not a
            # delta). Kept as-is; confirm that this is the intended definition.
            if "INFO: Begin to search anchors" in line:
                timestamp = _first_number(line)
                if timestamp is not None:
                    metrics["rare_alignment_time"] = timestamp
            if "Finish searching anchors" in line:
                memory = _memory_in_gb(_MEMORY_RE.search(line))
                if memory is not None:
                    metrics["rare_alignment_memory"] = memory

            # Overall peak memory and the number on the final log line.
            if "Max memory used" in line:
                memory = _memory_in_gb(_MAX_MEMORY_RE.search(line))
                if memory is not None:
                    metrics["total_memory"] = memory
            if "End RaMA!" in line:
                timestamp = _first_number(line)
                # Truthiness check: a missing or zero value never overwrites.
                if timestamp:
                    metrics["total_time"] = timestamp

    return metrics




# Average the per-run metrics of every chromosome and export them as CSV.
def calculate_averages(log_dir, output_csv, num_runs=10):
    """Average RaMA log metrics over repeated runs and write them to a CSV file.

    For each chromosome name (``chr1`` .. ``chr22`` and ``chrX``) the logs at
    ``<log_dir>/run_<k>/<chromosome>/RaMA.log`` for ``k`` in ``1..num_runs``
    are parsed with ``extract_metrics_from_log``. Runs whose log file does not
    exist are ignored, and a chromosome with no logs at all produces no row.

    Args:
        log_dir: Root directory that contains the ``run_<k>`` folders.
        output_csv: Path of the CSV file to write. Its parent directory is
            created when missing; a bare file name writes to the current
            working directory.
        num_runs: Number of run folders to look for (defaults to 10).

    Returns:
        None. The averaged table is written to ``output_csv`` and a
        confirmation line is printed.
    """
    chromosome_names = [f"chr{i}" for i in range(1, 23)] + ["chrX"]

    result_data = []
    for chromosome_name in chromosome_names:
        metrics_list = []
        for run in range(1, num_runs + 1):
            log_file = os.path.join(log_dir, f"run_{run}", chromosome_name, "RaMA.log")
            if os.path.exists(log_file):
                metrics_list.append(extract_metrics_from_log(log_file))

        if not metrics_list:
            continue

        # Mean of each metric over the runs that actually produced a log.
        run_count = len(metrics_list)
        avg_metrics = {
            key: sum(m[key] for m in metrics_list) / run_count
            for key in metrics_list[0]
        }
        avg_metrics["chromosome"] = chromosome_name
        result_data.append(avg_metrics)

    df = pd.DataFrame(result_data)
    output_folder = os.path.dirname(output_csv)
    # dirname is "" for a bare file name, and os.makedirs("") would raise.
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)
    df.to_csv(output_csv, index=False)
    print(f"Results written to {output_csv}")

# Run the aggregation with hard-coded input and output locations.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so this call
# also executes whenever the file is imported.
# log_dir is the root that calculate_averages searches for
# run_<k>/<chromosome>/RaMA.log; output_csv is the CSV file it writes.
log_dir = "/mnt/d/Result/RaMA/exp/new_chm13_chm1/result/RaMA"
output_csv = "/mnt/d/Result/RaMA/exp/new_chm13_chm1/csv/RaMA_performance.csv"

calculate_averages(log_dir, output_csv)
