from train import Trainer
import argparse
import random
import numpy as np
from pprint import pprint
import torch
import os


def parse_file_pair(text):
    """Convert one ``--files`` command-line token into a ``(data, label)`` tuple.

    The token has the form ``DATA[,LABEL]``. A token without a comma (or with
    nothing after the comma) yields an empty label path, which is the same
    shape the built-in default uses for the query entry.

    Raises:
        argparse.ArgumentTypeError: if the data path part is empty.
    """
    data_path, _, label_path = text.partition(',')
    data_path = data_path.strip()
    if not data_path:
        raise argparse.ArgumentTypeError(
            "expected DATA[,LABEL] with a non-empty DATA path, got %r" % text)
    return (data_path, label_path.strip())


def build_parser():
    """Build the argument parser holding every option handed to ``Trainer``."""
    parser = argparse.ArgumentParser()
    # model definition
    parser.add_argument("--species", default='human', type=str)
    parser.add_argument("--tissue", default='PBMC', type=str)
    parser.add_argument("--gpu", type=int, default=0,
                        help="GPU id, -1 for cpu")
    parser.add_argument("--n_epochs", type=int, default=1000,
                        help="maximum of training epochs")
    parser.add_argument("--dense_dim", type=int, default=400,
                        help="number of PCA units")
    parser.add_argument("--hidden_dim", type=int, default=200,
                        help="number of hidden GNN units")
    parser.add_argument("--n_layers", type=int, default=2,
                        help="number of hidden GNN layers")
    parser.add_argument("--random_seed", type=int, default=10086)
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="dropout probability")
    parser.add_argument("--lr", type=float, default=0.005,
                        help="learning rate")
    parser.add_argument("--weight_decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    parser.add_argument("--threshold", type=float, default=0,
                        help="the threshold to connect edges between cells and genes")
    parser.add_argument("--num_neighbors", type=int, default=0,
                        help="number of neighbors to sample in message passing process. 0 means all neighbors")
    parser.add_argument("--exclude_rate", type=float, default=0.005,
                        help="exclude some cells less than this rate")
    parser.add_argument("--batch_size", type=int, default=1000)
    parser.add_argument("--test_rate", type=float, default=0.2)
    parser.add_argument("--patience", type=int, default=50,
                        help="the window size of early stopping")

    # input and output setting
    # NOTE: ``type=list`` would split a command-line string into single
    # characters, so every token is parsed into a (data, label) tuple instead.
    # argparse does not run ``type`` on a non-string default, so the default
    # list of tuples below is passed through untouched.
    parser.add_argument("--files", type=parse_file_pair, nargs='+', metavar='DATA[,LABEL]',
                        default=[('datasets/pbmc2_10X_v2_data.csv', ''),
                                 ('datasets/pbmc1_10X_v2_data.csv', 'datasets/pbmc1_10X_v2_label.csv'),
                                 ('datasets/pbmc1_10X_v3_data.csv', 'datasets/pbmc1_10X_v3_label.csv')],
                        help="the file list of query expression profile, reference expression profile and reference label, note that the query file should be placed first. "
                             "Each entry is written as DATA[,LABEL]; leave LABEL out for the query file.")
    parser.add_argument("--grn_file", type=str, default='statistics/regulations.txt',
                        help="the gene regulatory network file")
    parser.add_argument("--emb_file", type=str, default='statistics/genenodes.npy',
                        help="the embedding file of gene which is generated by pre-process.py")
    parser.add_argument("--order_file", type=str, default='statistics/gene_statistics.csv',
                        help="the order file of genes sorted by dispersion descending order which is generated by pre-process.py")
    parser.add_argument("--topk", type=int, default=5000,
                        help="the number of genes filtered based on order_file, if topk < 0 then all genes in files are included")
    parser.add_argument("--reads", type=str, default='log',
                        help="data type of scRNA-seq data, log (scHGR do not perform log-nomalization) or reads (scHGR perform log-nomalization)")
    parser.add_argument("--oritation", type=str, default='gc',
                        help="the orientation of scRNA-seq data, gc (rows represent genes, columns represent cells) or cg (rows represent cells, columns represent genes)")
    parser.add_argument("--save_path", type=str, default='output/',
                        help="the save path of models and outputs")
    # Previously hard-coded in the script body; the default keeps that value.
    parser.add_argument("--gold_file", type=str, default='datasets/pbmc2_10X_v2_label.csv',
                        help="the label file handed to trainer.predict for evaluating the query annotations")
    return parser


def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def main(argv=None):
    """Parse options, train the model, then annotate the query data.

    Args:
        argv: optional list of command-line tokens; ``None`` makes argparse
            read ``sys.argv``.

    Returns:
        Whatever ``trainer.predict`` returns.
    """
    params = build_parser().parse_args(argv)

    seed_everything(params.random_seed)
    pprint(vars(params))

    # training model
    trainer = Trainer(params)
    print('Start to fit.')
    trainer.fit()

    # annotating and evaluating query data
    if params.gold_file:
        predictions = trainer.predict(params.gold_file)
    else:
        # NOTE(review): the original script carried a commented-out
        # ``trainer.predict()`` call for annotation without evaluation;
        # confirm that Trainer.predict accepts being called with no argument.
        predictions = trainer.predict()
    return predictions


if __name__ == '__main__':
    main()



