import pandas as pd
import numpy as np
import os
import multiprocessing
import warnings

# Silence every warning category process-wide (same effect as
# filterwarnings("ignore")).
# NOTE(review): this also hides pandas FutureWarning/DeprecationWarning
# messages; consider narrowing to RuntimeWarning (log2 of 0, mean of empty).
warnings.simplefilter("ignore")


def meanFst(n, inputpath, outputpath):
    """Tabulate the mean Weir & Cockerham Fst for every population pair.

    For each pair ``1 <= i < j <= n``, ``P{j}vsP{i}.Fst.weir.fst`` is read
    from *inputpath* as a tab-separated table. Rows containing any NaN are
    dropped, and the ``WEIR_AND_COCKERHAM_FST`` column is averaged.

    Three ``n x n`` matrices are written to *outputpath* as header-less CSVs,
    with the pair (i, j) stored at ``[i-1, j-1]``:

    * ``Fstreslt{n}.csv``   -- mean Fst
    * ``locinumber{n}.csv`` -- number of rows kept after ``dropna``
    * ``filenumber{n}.csv`` -- currently the same count as ``locinumber``

    Cells for pairs whose file is missing are NaN (an empty CSV field). The
    diagonal and the lower triangle are left at 0.

    Args:
        n: Number of populations (P1..Pn).
        inputpath: Directory holding the ``*.Fst.weir.fst`` files.
        outputpath: Directory the CSV matrices are written to.
    """
    available = set(os.listdir(inputpath))
    # Float storage so a missing pair can hold NaN; an int array raises
    # "cannot convert float NaN to integer".
    Fstreslt = np.zeros([n, n], dtype=float)
    filenumber = np.zeros([n, n], dtype=float)
    locinumber = np.zeros([n, n], dtype=float)

    for i in range(1, n + 1):
        for j in range(i + 1, n + 1):
            file = f"P{j}vsP{i}.Fst.weir.fst"
            row, col = i - 1, j - 1
            if file in available:
                df = pd.read_csv(os.path.join(inputpath, file), sep="\t").dropna()
                # NOTE(review): identical to locinumber; possibly intended to be
                # the row count before dropna -- confirm.
                filenumber[row, col] = df.shape[0]
                Fstreslt[row, col] = np.mean(df["WEIR_AND_COCKERHAM_FST"])
                locinumber[row, col] = df.shape[0]
            else:
                Fstreslt[row, col] = np.nan
                filenumber[row, col] = np.nan
                locinumber[row, col] = np.nan

    # Join paths instead of os.chdir. This leaves the process cwd untouched,
    # and a relative outputpath is no longer resolved against inputpath.
    pd.DataFrame(Fstreslt).to_csv(
        os.path.join(outputpath, f"Fstreslt{n}.csv"), header=None, index=None)
    # Nullable Int64 writes counts as integers and missing cells as empty.
    pd.DataFrame(locinumber).astype("Int64").to_csv(
        os.path.join(outputpath, f"locinumber{n}.csv"), header=None, index=None)
    pd.DataFrame(filenumber).astype("Int64").to_csv(
        os.path.join(outputpath, f"filenumber{n}.csv"), header=None, index=None)


def meantheta(n, inputpath, outputpath):
    """Tabulate the mean log2 ratio of per-site Pi for every population pair.

    For each population i, ``P{i}.Pi.sites.pi`` is read from *inputpath* as a
    tab-separated table with a ``PI`` column. For each later population j
    whose file also exists, ``log2(PI_i / PI_j)`` is computed. Infinite and
    NaN ratios (a zero Pi on either side) are discarded, and the remainder is
    averaged.

    Written to *outputpath* as header-less CSVs:

    * ``thetaReslt{n}.csv``   -- ``n x n`` mean ratio at ``[i-1, j-1]``
    * ``thetaLoci{n}.csv``    -- ``n x n`` number of finite ratios averaged
    * ``thetaLociNum{n}.csv`` -- ``n x 1`` row count of each population's
      file, or -99999 when that file is missing

    Cells for pairs with a missing file stay 0.
    """
    available = set(os.listdir(inputpath))
    thetaReslt = np.zeros([n, n], dtype=float)
    fileLoci = np.zeros([n, 1], dtype=int)
    thetaLoci = np.zeros([n, n], dtype=int)

    for i in range(1, n + 1):
        row = i - 1
        file1 = f"P{i}.Pi.sites.pi"
        if file1 not in available:
            fileLoci[row] = -99999  # sentinel: no site-Pi file for population i
            continue

        df1 = pd.read_csv(os.path.join(inputpath, file1), sep="\t")
        fileLoci[row] = df1.shape[0]
        for j in range(i + 1, n + 1):
            file2 = f"P{j}.Pi.sites.pi"
            if file2 not in available:
                continue
            df2 = pd.read_csv(os.path.join(inputpath, file2), sep="\t")
            # NOTE(review): the division aligns rows by DataFrame index (line
            # order), not by CHROM/POS. This is only meaningful if both files
            # list the same sites in the same order -- confirm.
            ratio = np.log2(df1["PI"] / df2["PI"])
            ratio = ratio[(np.abs(ratio) != np.inf) & ratio.notna()]
            thetaReslt[row, j - 1] = np.mean(ratio)
            thetaLoci[row, j - 1] = ratio.shape[0]

    # Join paths instead of os.chdir. This leaves the process cwd untouched,
    # and a relative outputpath is no longer resolved against inputpath.
    pd.DataFrame(thetaReslt).to_csv(
        os.path.join(outputpath, f"thetaReslt{n}.csv"), header=None, index=None)
    pd.DataFrame(thetaLoci).to_csv(
        os.path.join(outputpath, f"thetaLoci{n}.csv"), header=None, index=None)
    pd.DataFrame(fileLoci).to_csv(
        os.path.join(outputpath, f"thetaLociNum{n}.csv"), header=None, index=None)


def _windowed_mean_fst(path, bin_size):
    """Average per-site Fst from a ``*.weir.fst`` file into fixed-size windows.

    The file has one header line, then tab-separated ``CHROM POS FST`` records.
    Sites whose Fst parses as NaN (``nan`` / ``-nan``) are skipped. Every other
    site is assigned to the 1-based window ``[k*bin_size + 1, (k+1)*bin_size]``
    containing its position.

    Assumes the file is sorted by chromosome then position, so each window's
    sites are contiguous, and that CHROM parses as an integer.

    Returns:
        DataFrame with columns CHROM, BIN_START, FstVariant (number of sites
        averaged) and meanFst. Windows without any usable site are omitted.
    """
    records = []
    current = None  # (chrom, window_start) of the window being accumulated
    values = []
    with open(path, "r") as f:
        f.readline()  # skip the header line
        for line in f:
            fields = line.strip().split("\t")
            if len(fields) < 3:
                continue  # blank or truncated line
            fst = float(fields[2])
            if np.isnan(fst):
                continue
            chrom = int(fields[0])
            # Jump straight to the window containing this site, so gaps longer
            # than one window and a new chromosome's first site land correctly.
            window_start = (int(fields[1]) - 1) // bin_size * bin_size + 1
            if (chrom, window_start) != current:
                if values:
                    records.append((*current, len(values), np.mean(values)))
                current = (chrom, window_start)
                values = []
            values.append(fst)
    if values:  # the final window is never followed by a boundary
        records.append((*current, len(values), np.mean(values)))
    return pd.DataFrame(records, columns=["CHROM", "BIN_START", "FstVariant", "meanFst"])


def selective_window(i, j, bin, inputpath1, inputpath2, outputpath):
    """Join windowed Pi and windowed mean Fst for populations i and j.

    Reads ``P{i}_{kb}kb.Pi.windowed.pi`` and ``P{j}_{kb}kb.Pi.windowed.pi``
    from *inputpath1*, where ``kb = int(bin / 1000)``. It also reads the
    per-site ``P{j}vsP{i}.Fst.weir.fst`` from *inputpath2* and averages it
    into windows of *bin* bp.

    The output ``P{i}vsP{j}_{bin/1000}kb.csv`` (e.g. ``_100.0kb`` for
    ``bin=100000``) is written to *outputpath*. It contains both Pi tables
    (suffixes ``_x`` for i and ``_y`` for j), ``Pitheta = log2(PI_x / PI_y)``,
    and the ``FstVariant`` / ``meanFst`` columns.
    """
    kb = int(bin / 1000)
    df1 = pd.read_csv(os.path.join(inputpath1, f"P{i}_{kb}kb.Pi.windowed.pi"), sep="\t")
    df2 = pd.read_csv(os.path.join(inputpath1, f"P{j}_{kb}kb.Pi.windowed.pi"), sep="\t")

    # Outer join keeps windows reported for only one of the two populations.
    df3 = pd.merge(df1, df2, how="outer", on=["CHROM", "BIN_START", "BIN_END"])
    df3["Pitheta"] = np.log2(df3["PI_x"] / df3["PI_y"])

    # DataFrame.append was removed in pandas 2.0; the helper builds the table
    # from a list of records instead.
    df4 = _windowed_mean_fst(os.path.join(inputpath2, f"P{j}vsP{i}.Fst.weir.fst"), bin)

    pd.merge(df3, df4, how="outer", on=["CHROM", "BIN_START"]).to_csv(
        os.path.join(outputpath, f"P{i}vsP{j}_{bin/1000}kb.csv"), index=None)


def get_window(critetion1, c2, c3, SS):
    """Label windows that are outliers for both mean Fst and the Pi ratio.

    Args:
        critetion1: Percentile (0-100) of ``SS["meanFst"]`` a window must
            exceed to count as high-Fst.
        c2: A high-Fst window is labelled ``"Pi>Pj"`` when its ``Pitheta``
            exceeds this percentile (0-100) of ``SS["Pitheta"]``.
        c3: A high-Fst window is labelled ``"Pi<Pj"`` when its ``Pitheta``
            is below the ``100 - c3`` percentile.
        SS: DataFrame with ``meanFst`` and ``Pitheta`` columns (as written by
            ``selective_window``, where ``Pitheta = log2(PI_x / PI_y)``).

    Returns:
        ``SS`` itself, modified in place with a ``"selective region"`` column
        holding ``"None"``, ``"Pi>Pj"`` or ``"Pi<Pj"``.
    """
    tail1_right = np.percentile(SS["meanFst"], critetion1)
    tail2_right = np.percentile(SS["Pitheta"], c2)
    tail2_left = np.percentile(SS["Pitheta"], 100 - c3)

    # Combine boolean masks, not index objects. In pandas >= 2.0,
    # ``Index & Index`` is an element-wise operation, not a set intersection.
    high_fst = SS["meanFst"] > tail1_right
    high_pi = SS["Pitheta"] > tail2_right
    low_pi = SS["Pitheta"] < tail2_left

    SS["selective region"] = "None"
    SS.loc[high_fst & high_pi, "selective region"] = "Pi>Pj"
    # Applied second, so it wins if the two Pitheta tails ever overlap.
    SS.loc[high_fst & low_pi, "selective region"] = "Pi<Pj"
    return SS


def result_window(inputpath, i, j, bin_size):
    """Select candidate selective windows for the population pair (i, j).

    Reads ``P{i}vsP{j}_{bin_size}.0kb.csv`` from *inputpath* and drops rows
    with any NaN. Windows are then labelled with ``get_window``, starting at
    the 99th percentile for all three criteria. While fewer than 10 windows
    carry a label, the criterion limiting that label is relaxed by up to one
    percentile point per step:

    * both labels short  -> relax the Fst criterion (c1)
    * only ``"Pi>Pj"``   -> relax the upper Pitheta criterion (c2)
    * only ``"Pi<Pj"``   -> relax the lower Pitheta criterion (c3)

    Relaxation stops once both labels reach 10 windows, or when a further
    step would push a criterion below 0. In the latter case the last valid
    labelling is kept.

    Returns:
        DataFrame with CHROM, BIN_START, BIN_END and a column named
        ``SR_P{i}vsP{j}`` holding ``"Pi>Pj"`` / ``"Pi<Pj"`` for the selected
        windows only.
    """
    name = f"SR_P{i}vsP{j}"
    keys = ["CHROM", "BIN_START", "BIN_END"]
    SS = pd.read_csv(os.path.join(inputpath, f"P{i}vsP{j}_{bin_size}.0kb.csv")).dropna()

    if SS.empty:
        # np.percentile cannot summarise an empty table; nothing is selected.
        result = SS.loc[:, keys].copy()
        result[name] = pd.Series(dtype=object)
        return result

    c1 = c2 = c3 = 99
    SS = get_window(c1, c2, c3, SS)

    while True:
        n1 = int((SS["selective region"] == "Pi>Pj").sum())
        n2 = int((SS["selective region"] == "Pi<Pj").sum())
        if n1 >= 10 and n2 >= 10:
            break
        # Step size shrinks as the count approaches 10 (0.1 .. 1.0 points).
        if n1 < 10 and n2 < 10:
            c1 -= (10 - max(n1, n2)) / 10
        elif n1 < 10:
            c2 -= (10 - n1) / 10
        else:
            c3 -= (10 - n2) / 10
        if min(c1, c2, c3) < 0:
            # np.percentile only accepts q in [0, 100]; keep the last valid
            # labelling instead of raising.
            break
        SS = get_window(c1, c2, c3, SS)

    result = SS.loc[SS["selective region"] != "None", keys + ["selective region"]]
    return result.rename({"selective region": name}, axis=1)


# Example usage (runs at import time; replace the placeholder paths):
n = 48
inputpath = "/path/to/input"
outputpath = "/path/to/output"

meanFst(n, inputpath, outputpath)
meantheta(n, inputpath, outputpath)

# NOTE(review): result_window reads P{i}vsP{j}_100.0kb.csv from inputpath.
# Those are the files selective_window writes to its outputpath, and this
# script never calls selective_window -- they must already exist there.
merge_keys = ["CHROM", "BIN_START", "BIN_END"]
result_sum = None
for i in range(1, n + 1):
    for j in range(i + 1, n + 1):
        result = result_window(inputpath, i, j, 100)
        # Test for None rather than .empty. Otherwise an empty first result
        # would be silently replaced by the next one, losing its pair column.
        if result_sum is None:
            result_sum = result
        else:
            result_sum = pd.merge(result_sum, result, how="outer", on=merge_keys)

if result_sum is None:  # n < 2: there are no population pairs
    result_sum = pd.DataFrame(columns=merge_keys)

# Every column after the three merge keys is one SR_P{i}vsP{j} label column;
# count how many pairs flag each window.
result_sum["selective number"] = result_sum.iloc[:, 3:].notna().sum(axis=1)
result_sum.to_csv(os.path.join(outputpath, "result_sum_100kb.csv"), index=None)
