import warnings
warnings.filterwarnings("ignore")
from mlxtend.feature_selection import SequentialFeatureSelector as sfs
import joblib
import numpy as np
from sklearn.metrics import accuracy_score
import os
import argparse

from sklearn.naive_bayes import BernoulliNB
from sklearn import svm
from sklearn.ensemble import BaggingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

def extract_test(X_testfilename, y_testfilename, feature_index):
    """Load a feature matrix and its labels, keeping only selected columns.

    Despite the name, this is used for both the training and the test split.

    Args:
        X_testfilename: Path passed to ``np.load`` for the feature matrix.
            The loaded array is indexed as ``[:, feature_index]``, so it
            must have at least two dimensions.
        y_testfilename: Path passed to ``np.load`` for the labels.
        feature_index: Column indices to keep (sequence or array).

    Returns:
        A ``(features, labels)`` tuple: the feature matrix restricted to the
        requested columns, and the labels exactly as loaded.
    """
    features = np.load(X_testfilename)
    labels = np.load(y_testfilename)
    columns = np.asarray(feature_index)
    if columns.size == 0:
        # An empty sequence becomes a float64 array, which NumPy rejects as
        # an index. Cast it so an empty selection yields an (n_rows, 0)
        # matrix instead of raising IndexError.
        columns = columns.astype(np.intp)
    return features[:, columns], labels

# Command-line interface. All path/label options are plain strings; they are
# registered from a table, in the same order as before, because registration
# order determines the order of the generated --help output.
parser = argparse.ArgumentParser(description='manual to this script')

_string_options = (
    ("--WorkSpace", "./"),   # directory the index file is resolved against
    ("--output", "./"),      # directory that receives the result files
    ("--trainX", "0"),
    ("--trainy", "0"),
    ("--indexFile", "0"),
    ("--testX", "0"),
    ("--testy", "0"),
    ("--method", "0"),       # key into the classifier table
    ("--filter", "0"),       # label embedded in output file names
)
for _flag, _default in _string_options:
    parser.add_argument(_flag, type=str, default=_default)

# Integer tuning knobs for the sequential feature selector.
parser.add_argument(
    "--k_features",
    type=int,
    default=500,
    help="Number of features to select",
)
parser.add_argument(
    "--n_jobs",
    type=int,
    default=50,
    help="Number of jobs for parallel processing",
)

args = parser.parse_args()

# Load important feature indices: one integer column index per line of the
# index file, which is resolved relative to --WorkSpace.
with open(os.path.join(args.WorkSpace, args.indexFile), "r") as f:
    # Skip blank lines (for example a trailing empty line), which would
    # otherwise make int() raise ValueError. int() itself tolerates the
    # surrounding whitespace and newline on non-blank lines.
    important_feature = [int(line) for line in f if line.strip()]

print(args.indexFile, len(important_feature))
features_key = np.array(important_feature)

# Prepare train and test data, both restricted to the same feature columns
train_X, train_y = extract_test(args.trainX, args.trainy, features_key)
test_X, test_y = extract_test(args.testX, args.testy, features_key)

# Define classifiers: maps each supported --method value to an unfitted
# estimator instance.
dic_m = {
    "LR2": LogisticRegression(solver="liblinear", penalty='l2'),
    # The wrapped estimator is passed positionally on purpose: scikit-learn
    # renamed the keyword from `base_estimator` to `estimator` (deprecated in
    # 1.2, removed in 1.4), while the first positional slot is the wrapped
    # estimator in both the old and the new signature.
    "NB": BaggingClassifier(BernoulliNB(), n_estimators=100),
    "knn": KNeighborsClassifier(n_neighbors=3),
    "svmL": svm.SVC(C=1, kernel='linear', probability=True),
    "RF": RandomForestClassifier(
        criterion="gini",
        n_estimators=100,
        max_leaf_nodes=70,
        max_features=10,
        min_impurity_decrease=0,
        min_samples_leaf=4,
        class_weight='balanced',
    ),
}

# Feature selection and model training
def fun1(filter, method, clf, k_features, n_jobs):
    """Run forward sequential feature selection and score the chosen subset.

    Fits a forward, non-floating ``SequentialFeatureSelector`` (accuracy
    scoring, ``cv=5``) on the module-level ``train_X`` / ``train_y``, saves
    the fitted selector and the chosen feature indices under ``args.output``,
    then fits ``clf`` on the selected training columns and prints its
    accuracy on the module-level ``test_X`` / ``test_y``.

    Args:
        filter: Label used only in output file names and the printed summary.
            (The name shadows the builtin; it is kept so existing callers
            keep working.)
        method: Classifier label, likewise used only for naming/reporting.
        clf: Estimator handed to the selector and afterwards fitted on the
            selected columns.
        k_features: Number of features to select; also the key used to look
            up the result in the selector's ``subsets_``.
        n_jobs: Forwarded to the selector as ``n_jobs``.

    Returns:
        None. Results are written to disk and printed.
    """
    # Create the output directory *before* the potentially very long search,
    # so a missing directory cannot throw the result away at save time.
    # An empty path means "current directory" for os.path.join; makedirs("")
    # would raise, hence the guard.
    if args.output:
        os.makedirs(args.output, exist_ok=True)

    sfs_model = sfs(
        clf,
        k_features=k_features,
        forward=True,
        floating=False,
        verbose=2,
        scoring='accuracy',
        cv=5,
        n_jobs=n_jobs
    )
    sfs_model.fit(train_X, train_y)
    joblib.dump(sfs_model, os.path.join(args.output, "{}_{}_fs.m".format(filter, method)))

    # Look up the subset of the requested size once and reuse it below.
    chosen = sfs_model.subsets_[k_features]
    max_score = chosen['avg_score']
    selected_cols = list(chosen["feature_idx"])
    # Selector indices are positions within train_X; map them back to the
    # indices read from the index file.
    filterli = features_key[selected_cols]

    # Save feature indices to file, one per line
    with open(os.path.join(args.output, "SFS{}_{}index{}.txt".format(filter, method, k_features)), "w") as f1:
        f1.writelines("{}\n".format(i) for i in filterli)

    # Train model on the selected columns and evaluate on the test set
    model = clf.fit(train_X[:, selected_cols], train_y)
    test_score = accuracy_score(test_y, model.predict(test_X[:, selected_cols]))

    # Print results
    print("SFS_model_{}_{}:".format(filter, method), max_score, test_score, k_features)

# Run feature selection and evaluation
if args.method not in dic_m:
    # Report an unknown method as a usage error instead of a bare KeyError.
    # Note that the default --method value ("0") is not a valid key, so the
    # option is effectively required.
    parser.error(
        "--method must be one of {}; got {!r}".format(sorted(dic_m), args.method)
    )
fun1(args.filter, args.method, dic_m[args.method], args.k_features, args.n_jobs)