from sklearn.metrics import recall_score, accuracy_score, roc_auc_score, f1_score, confusion_matrix, precision_score
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn import tree
from sklearn import svm
from sklearn.preprocessing import label_binarize
from imblearn.over_sampling import SMOTE
import joblib
import numpy as np

def assess_binary(y_true, y_pred, y_prob):
    """Compute a set of classification metrics for one batch of predictions.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        y_prob: Two-dimensional score array; column 1 is the score passed
            to the ROC AUC computation.

    Returns:
        A dict with the keys "recall", "accuracy", "precision", "f1",
        "roc_auc" and "confusion_matrix". Recall and precision are
        weighted averages, while F1 is a macro average.
    """
    # Only column 1 of the score array feeds the ROC AUC.
    positive_scores = y_prob[:, 1]

    scores = {}
    scores["recall"] = recall_score(y_true, y_pred, average='weighted')
    scores["accuracy"] = accuracy_score(y_true, y_pred)
    scores["precision"] = precision_score(y_true, y_pred, average='weighted')
    scores["f1"] = f1_score(y_true, y_pred, average='macro')
    scores["roc_auc"] = roc_auc_score(y_true, positive_scores)
    scores["confusion_matrix"] = confusion_matrix(y_true, y_pred)
    return scores

def prob_dict2(prob_list):
    """Rank class probabilities per row and stack the results of several runs.

    Args:
        prob_list: Sequence of mappings, one per simulation. For every key
            ``k`` the value ``item[k][1]`` must be a 2-D array with one row
            per sample and one column per class. The keys of the first
            mapping decide which keys are processed, so every later mapping
            must contain them too.

    Returns:
        A ``(dic_breed, dic_prob)`` tuple of dicts keyed like the input.
        ``dic_prob[k]`` holds each row's values sorted in descending order
        and ``dic_breed[k]`` holds the column indices that produce that
        order. Both are float64 arrays, and the rows of all simulations are
        concatenated along axis 0 in list order. An empty ``prob_list``
        yields two empty dicts.
    """
    dic_breed, dic_prob = {}, {}

    # Nothing to rank: avoid the IndexError that prob_list[0] would raise.
    if not prob_list:
        return dic_breed, dic_prob

    for k in prob_list[0]:
        breed, prob = [], []

        for item in prob_list:
            indi = item[k][1]
            # Sorting the negated values gives a descending order for every
            # row in a single vectorized call.
            order = np.argsort(-indi, axis=1)
            # Cast to float64 so the outputs keep the dtype callers have
            # always received, whatever the dtype of the input array.
            breed.append(order.astype(np.float64))
            prob.append(np.take_along_axis(indi, order, axis=1).astype(np.float64))

        dic_breed[k] = np.concatenate(breed, axis=0)
        dic_prob[k] = np.concatenate(prob, axis=0)

    return dic_breed, dic_prob

def prob_dict(prob):
    """Rank class probabilities per row for a single dataset.

    Args:
        prob: Mapping from a key to a 2-D array with one row per sample and
            one column per class.

    Returns:
        A ``(dic_breed, dic_prob)`` tuple of dicts keyed like ``prob``.
        ``dic_prob[k]`` holds each row's values sorted in descending order
        and ``dic_breed[k]`` holds the column indices that produce that
        order. Both are float64 arrays with the same shape as the input.
    """
    dic_breed, dic_prob = {}, {}

    for k, indi in prob.items():
        # Sorting the negated values gives a descending order for every row
        # in a single vectorized call. Results go to their own names so the
        # ``prob`` argument is never rebound while it is being iterated.
        order = np.argsort(-indi, axis=1)
        # Cast to float64 so the outputs keep the dtype callers have always
        # received, whatever the dtype of the input array.
        dic_breed[k] = order.astype(np.float64)
        dic_prob[k] = np.take_along_axis(indi, order, axis=1).astype(np.float64)

    return dic_breed, dic_prob

def load_hybrid_data(sim_count=5):
    """Read the hybrid simulation files and rank their probabilities.

    Files ``hybrid_sim1_prob.npy`` through ``hybrid_sim<sim_count>_prob.npy``
    are read from the current working directory.

    Args:
        sim_count: Number of simulation files to read, numbered from 1.

    Returns:
        Whatever ``prob_dict2`` returns for the list of loaded objects.
    """
    hybrid_data = []
    for i in range(1, sim_count + 1):
        # NOTE: allow_pickle=True unpickles the file contents, so these
        # files must come from a trusted source.
        stored = np.load(f"hybrid_sim{i}_prob.npy", allow_pickle=True)
        # .item() pulls the single stored Python object out of the array.
        hybrid_data.append(stored.item())
    return prob_dict2(hybrid_data)

def load_pure_population_data():
    """Read the pure-population file and rank its probabilities.

    The file ``pure_population_prob.npy`` is read from the current working
    directory.

    Returns:
        Whatever ``prob_dict`` returns for the per-key probability arrays.
    """
    # NOTE: allow_pickle=True unpickles the file contents, so the file must
    # come from a trusted source.
    stored = np.load("pure_population_prob.npy", allow_pickle=True)
    pred_all = stored.item()

    # Element 1 of every entry is the array handed on for ranking.
    prob_all = {}
    for name, entry in pred_all.items():
        prob_all[name] = entry[1]

    return prob_dict(prob_all)

def prepare_data(hybrid_simulations, models, *, sim_ids=range(5, 15), n_top=6):
    """Build the feature matrix X and the pure/hybrid label vector y.

    For every model the first ``n_top`` columns of the pure-population array
    are stacked on top of the first ``n_top`` columns of the hybrid arrays;
    the per-model blocks are then joined side by side.

    The pure-population arrays are read from the module-level name
    ``pure_population``, which must map each model name to a 2-D array.

    Args:
        hybrid_simulations: Hybrid arrays in one of two layouts. Either a
            flat mapping ``{model: array}``, or a per-simulation mapping
            ``{"dic_probs<id>": {model: array}}`` with one entry for every
            id in ``sim_ids``.
        models: Model names whose arrays become feature columns.
        sim_ids: Simulation ids looked up in the per-simulation layout.
        n_top: Number of leading columns kept from every array.

    Returns:
        ``(X, y)`` where ``X`` has one row per sample and ``n_top`` columns
        per model, and ``y`` is 0.0 for pure rows and 1.0 for hybrid rows.

    Raises:
        ValueError: If ``models`` is empty, or if the models do not all
            provide the same number of rows (raised by ``np.concatenate``).
    """
    X_list = []
    n_pure = n_hybrid = None

    for model in models:
        pure_probs = pure_population[model][:, :n_top]

        if model in hybrid_simulations:
            # Flat layout: one array per model.
            hybrid_probs = [hybrid_simulations[model][:, :n_top]]
        else:
            # Per-simulation layout: one array per simulation id.
            hybrid_probs = [
                hybrid_simulations[f"dic_probs{sim}"][model][:, :n_top]
                for sim in sim_ids
            ]

        if n_pure is None:
            # Count the rows that are actually stacked so the labels always
            # line up with X, instead of relying on a hard-coded multiplier.
            n_pure = pure_probs.shape[0]
            n_hybrid = sum(block.shape[0] for block in hybrid_probs)

        X_list.append(np.concatenate([pure_probs] + hybrid_probs, axis=0))

    if not X_list:
        raise ValueError("models must contain at least one model name")

    X = np.concatenate(X_list, axis=1)
    y = np.concatenate([np.zeros(n_pure), np.ones(n_hybrid)])

    return X, y

def train_and_evaluate(model, X_train, y_train, X_test, y_test, model_name):
    """Fit a model, persist it to disk, and score it on the test data.

    Args:
        model: Estimator exposing ``fit``, ``predict`` and ``predict_proba``.
        X_train: Training features.
        y_train: Training labels.
        X_test: Test features.
        y_test: Test labels.
        model_name: Base name of the output file; the fitted model is
            written to ``<model_name>.m`` in the current working directory.

    Returns:
        The metrics dict produced by ``assess_binary``.
    """
    model.fit(X_train, y_train)

    # Persist the fitted estimator before scoring it.
    output_path = f"{model_name}.m"
    joblib.dump(model, output_path)

    predicted_labels = model.predict(X_test)
    predicted_scores = model.predict_proba(X_test)
    return assess_binary(y_test, predicted_labels, predicted_scores)

# ===== Load and Process Data =====
# Each loader returns a (ranked column indices, ranked probabilities) pair of
# dicts. The code below indexes these objects by model name, so keep only the
# probability dict of each pair.
_, hybrid_simulations = load_hybrid_data()
_, pure_population = load_pure_population_data()

# Prepare data
models = ["RandomForest_gini_100_10maxFea", "LogisticRegression_l2", "SVM(linear)"]
X, y = prepare_data(hybrid_simulations, models)

# Split and resample
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=4)

# Only the training split is resampled; the test split keeps its original
# class distribution for evaluation.
smote = SMOTE(random_state=20)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)

# ===== Model Training and Evaluation =====

# SVM
svm_model = svm.SVC(C=10, kernel='rbf', probability=True)
svm_metrics = train_and_evaluate(svm_model, X_resampled, y_resampled, X_test, y_test, "svm_model")

# Decision Tree
tree_model = tree.DecisionTreeClassifier(criterion="gini", max_depth=3, min_samples_leaf=5, min_samples_split=5)
tree_metrics = train_and_evaluate(tree_model, X_resampled, y_resampled, X_test, y_test, "tree_model")

# Logistic Regression
lr_model = LogisticRegression(penalty="l1", C=0.5, solver="liblinear")
lr_metrics = train_and_evaluate(lr_model, X_resampled, y_resampled, X_test, y_test, "lr_model")

# Random Forest
rf_model = RandomForestClassifier(criterion="gini", random_state=30, max_depth=25,
                                  min_samples_leaf=5, min_samples_split=5)
rf_metrics = train_and_evaluate(rf_model, X_resampled, y_resampled, X_test, y_test, "rf_model")

# ===== Output Metrics =====
# Only models trained above are reported. A former "Tree4Judge" line printed
# a variable that is never assigned in this script and raised NameError.
print("\nModel Performance Metrics:")
print(f"SVM: {svm_metrics}")
print(f"Decision Tree: {tree_metrics}")
print(f"Logistic Regression: {lr_metrics}")
print(f"Random Forest: {rf_metrics}")
