#!/bin/python3
"""
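description: build SLURM jobs that run `triotrain/summarize/post_process.py` on pickled per-sample results, producing summary statistics for the samples listed in a metadata CSV.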

example:
    python3 triotrain/summarize/summary.py                                     \\
        --metadata ../TRIO_TRAINING_OUTPUTS/final_results/inputs/240329_summary_metrics.csv    \\
        --output ../TRIO_TRAINING_OUTPUTS/final_results/240402_sample_stats.csv        \\
        -r triotrain/model_training/tutorial/resources_used.json  \\
        --dry-run
"""

import argparse
from csv import DictReader
from dataclasses import dataclass, field
from json import load
from logging import Logger
from pathlib import Path
from sys import path
from typing import List, Union

# Put the directory two levels above this file on the module search path so the
# `helpers`, `model_training`, `pantry` and `results` imports below resolve when
# this file is run directly as a script.
abs_path = Path(__file__).resolve()
module_path = str(abs_path.parents[1])
path.append(module_path)

from helpers.files import TestFile, Files
from helpers.iteration import Iteration
from helpers.utils import check_if_all_same, generate_job_id
from model_training.slurm.sbatch import SBATCH, SubmitSBATCH
from pantry import preserve
from results import SummarizeResults


@dataclass
class Summary:
    """
    Build and submit SLURM jobs that run `post_process.py` on pickled per-sample results.

    Callers must set `self._data` (one metadata row) before `check_sample()`, and
    `self._index` before `process_sample()`. Each processed sample queues one
    `post_process.py --pickle-file <pkl>` command. `make_job()` bundles the queued
    commands into one SBATCH script, and `submit_job()` then writes (or, in dry-run
    mode, displays) and submits it. Callers must set `self._slurm_job` before
    `submit_job()`.
    """

    # Required parameters
    args: argparse.Namespace
    logger: Logger

    # Internal state (excluded from __init__ and repr; updated as samples are processed)
    _command_list: List[str] = field(default_factory=list, init=False, repr=False)
    _job_nums: List = field(default_factory=list, init=False, repr=False)
    _num_processed: int = field(default=0, init=False, repr=False)
    _num_skipped: int = field(default=0, init=False, repr=False)
    _num_submitted: int = field(default=0, init=False, repr=False)
    _phase: str = field(default="summary", init=False, repr=False)
    _trio_counter: int = field(default=0, init=False, repr=False)

    def __post_init__(self) -> None:
        """
        Load the SLURM resource config and set the log-message prefix.

        The resource config (JSON) is only read when `--pickle-file` was not
        provided on the command line.
        """
        if "pickle_file" not in self.args:
            with open(str(self.args.resource_config), mode="r") as file:
                self._slurm_resources = load(file)

        self._set_logger_msg()

    def _set_logger_msg(self) -> None:
        """Set the log-message prefix, tagging it with [DRY_RUN] when applicable."""
        if self.args.dry_run:
            self._logger_msg = f"[DRY_RUN] - [{self._phase}]"
        else:
            self._logger_msg = f"[{self._phase}]"

    def load_metadata(self) -> None:
        """
        Read in and save the metadata file as a list of row dictionaries.

        Keys and values are whitespace-stripped. Cells that are empty (before
        stripping) are dropped from their row. Sets `self._data_list` and
        `self._total_samples`.

        Raises:
            ValueError: if the metadata file does not exist.
        """
        # Confirm data input is an existing file
        metadata = TestFile(str(self._metadata_input), self.logger)
        metadata.check_existing(logger_msg=self._logger_msg, debug_mode=self.args.debug)
        if not metadata.file_exists:
            self.logger.error(
                f"{self._logger_msg}: unable to load metadata file | '{metadata.file}'"
            )
            raise ValueError("Invalid Input File")

        # 'utf-8-sig' transparently discards a leading byte-order mark, if present
        with open(metadata.file, mode="r", encoding="utf-8-sig") as data:
            # Remove whitespace within the CSV input
            self._data_list = [
                {key.strip(): value.strip() for key, value in row.items() if value}
                for row in DictReader(data)
            ]
        self._total_samples = len(self._data_list)

    def load_variables(self) -> None:
        """
        Define python variables.

        Loads the metadata CSV (when `--metadata` was given) and resolves the output
        CSV path. When not in dry-run mode, it also creates the directory that will
        hold that CSV. Finally it creates an empty `Iteration` that later stores the
        job/log directories.
        """
        self._set_logger_msg()

        if "metadata" in self.args:
            self._metadata_input = Path(self.args.metadata)
            self.load_metadata()

        output = Path(self.args.outpath).resolve()

        # A final component containing "." is treated as a file name (write next
        # to it); otherwise the path is treated as a directory to write into.
        if "." in output.name:
            _path = output.parent
            _file_name = f"{output.name}.mie.csv"
        else:
            _path = output
            _file_name = "mie.csv"

        # Create the directory that will hold the CSV. Creating `output` itself
        # would make a directory named after a requested *file*, and would fail
        # with FileExistsError if that file already exists.
        if not self.args.dry_run:
            _path.mkdir(parents=True, exist_ok=True)

        self._csv_output = Files(
            path_to_file=_path / _file_name,
            logger=self.logger,
            logger_msg=self._logger_msg,
            dryrun_mode=self.args.dry_run,
            debug_mode=self.args.debug,
        )
        self._csv_output.check_status()

        # Initalize an empty Iteration() to store paths
        self._itr = Iteration(logger=self.logger, args=self.args)

    def check_sample(self) -> None:
        """
        Wrap the current metadata row (`self._data`) in a `SummarizeResults` object.

        In debug mode, logs when the row is neither a parent record nor a valid trio.
        """
        self._pickled_data = SummarizeResults(
            sample_metadata=self._data,
            output_file=self._csv_output,
            args=self.args,
        )
        self._pickled_data.get_sample_info()

        if self.args.debug:
            if not self._pickled_data._parent_record and not self._pickled_data._contains_valid_trio:
                self.logger.debug(
                    f"{self._logger_msg}: not a valid trio... SKIPPING AHEAD"
                )

    def process_sample(self, pkl_suffix: Union[str, None] = None, store_data: bool = False) -> None:
        """
        Generate the pickled data file, and the bash command for processing each sample.

        Args:
            pkl_suffix: optional extra suffix for the pickle file name. If it contains
                "trio" (case-insensitive), the pickle goes in a "TRIOS" directory
                two levels above the input file.
            store_data: when True (and not in dry-run mode), pickle
                `self._pickled_data` to disk via `preserve()`.

        When the pickle file already exists, the sample is skipped (no command
        queued) if `--overwrite` is False or the summary stats output already exists.
        """
        self._pickled_data._index = self._index
        self._clean_file_path = self._pickled_data._input_file._test_file.clean_filename

        if pkl_suffix is None:
            _pickle_path = Path(f"{self._clean_file_path}.pkl")
        elif "trio" in pkl_suffix.lower():
            # Normalize to Path first; `clean_filename` may be a plain string
            _clean_path = Path(self._clean_file_path)
            _sample_name = _clean_path.stem
            _pickle_path = _clean_path.parent.parent / "TRIOS" / f"{_sample_name}.{pkl_suffix}.pkl"
        else:
            _pickle_path = Path(f"{self._clean_file_path}.{pkl_suffix}.pkl")

        self._pickle_file = TestFile(
            _pickle_path,
            logger=self.logger,
        )

        slurm_cmd = [
            "python3",
            "./triotrain/summarize/post_process.py",
            "--pickle-file",
            self._pickle_file.file,
        ]
        cmd_string = " ".join(slurm_cmd)

        self._pickle_file.check_existing(logger_msg=self._logger_msg)

        # Always resolve THIS sample's stats output, so the log messages below never
        # reference an undefined (first sample) or stale (previous sample) object.
        from post_process import Stats

        self._check_stats = Stats(pickled_data=self._pickled_data)
        self._check_stats.find_stats_output()
        _stats_output = self._check_stats._output

        if self._pickle_file.file_exists:
            if self.args.overwrite is False:
                if self.args.debug:
                    self.logger.debug(
                        f"{self._logger_msg}: --overwrite=False; unable to replace an exisiting file... SKIPPING AHEAD"
                    )
                return
            if _stats_output.file_exists:
                if self.args.debug:
                    self.logger.debug(
                        f"{self._logger_msg}: summary stats file already exists... SKIPPING AHEAD"
                    )
                return

        self.logger.info(
            f"{self._logger_msg}: missing summary stats file\t\t| '{_stats_output.file_name}'"
        )

        self._command_list.append(cmd_string)

        if self.args.dry_run:
            self.logger.info(
                f"{self._logger_msg}: pretending to create pickle file\t| '{self._pickle_file.path.name}'"
            )
        elif store_data:
            preserve(
                item=self._pickled_data,
                pickled_path=self._pickle_file,
                overwrite=self.args.overwrite,
                msg=f"{self._logger_msg}: ",
            )

    def make_job(self, job_name: str) -> Union[SBATCH, None]:
        """
        Define the contents of the SLURM job for the summary phase of the TrioTrain Pipeline.

        The job script and its `logs/` directory are placed next to the most recent
        pickle file. The script activates conda and then runs every queued command.

        Args:
            job_name: name of the SLURM job and its `.sh` file.

        Returns:
            The populated SBATCH object. Returns None when no commands are queued,
            or when the job file already exists and `--overwrite` is False.
        """
        # Skip jobs whenever there is nothing to execute
        if not self._command_list:
            return

        self._itr.job_dir = self._pickle_file.path.parent
        self._itr.log_dir = self._pickle_file.path.parent / "logs"
        if not self._itr.log_dir.exists():
            if self.args.dry_run:
                self._itr.logger.info(
                    f"{self._logger_msg}: pretending to create a new directory | '{self._itr.log_dir}'"
                )
            else:
                self._itr.logger.info(
                    f"{self._logger_msg}: creating a new directory | '{self._itr.log_dir}'"
                )
                self._itr.log_dir.mkdir(parents=True, exist_ok=True)

        self._job_name = job_name

        # Initialize a SBATCH Object
        slurm_job = SBATCH(
            itr=self._itr,
            job_name=self._job_name,
            error_file_label=self._pickled_data._caller,
            handler_status_label=None,
            logger_msg=self._logger_msg,
        )

        if slurm_job.check_sbatch_file():
            if self.args.overwrite:
                self._itr.logger.info(
                    f"{self._logger_msg}: --overwrite=True, re-writing the existing SLURM job now..."
                )
            else:
                self._itr.logger.info(
                    f"{self._logger_msg}: --overwrite=False, SLURM job file already exists."
                )
                self._num_skipped += 1
                return
        elif self._itr.debug_mode:
            self._itr.logger.debug(f"{self._logger_msg}: creating file job now... ")

        slurm_cmd = (
            slurm_job._start_conda
            + ["conda activate miniconda_envs/beam_v2.30"]
            + self._command_list
        )

        slurm_job.create_slurm_job(
            None,
            command_list=slurm_cmd,
            overwrite=self.args.overwrite,
            **self._slurm_resources["summary"],
        )
        return slurm_job

    def submit_job(self, index: int = 0, total: int = 1) -> None:
        """
        Submit SLURM jobs to queue.

        Expects `self._slurm_job` to hold the result of `make_job()`; does nothing
        when it is None. In dry-run mode the job is only displayed and a generated
        job id is recorded. Otherwise the job file is written and submitted, and
        its job number (or None on failure) is recorded.

        Args:
            index: zero-based position of this job, used for progress display.
            total: total number of jobs; also stored as `self._total_samples`.
        """
        self._total_samples = total
        # Only submit a job if a new SLURM job file was created
        if self._slurm_job is None:
            return

        if self._itr.dryrun_mode:
            self._slurm_job.display_job()
        else:
            self._slurm_job.write_job()

        # Submit the summary job to queue
        submit_slurm_job = SubmitSBATCH(
            sbatch_dir=self._itr.job_dir,
            job_file=f"{self._job_name}.sh",
            label="None",
            logger=self.logger,
            logger_msg=self._logger_msg,
        )

        submit_slurm_job.build_command()
        submit_slurm_job.display_command(
            current_job=(index + 1),
            total_jobs=total,
            display_mode=self._itr.dryrun_mode,
            debug_mode=self._itr.debug_mode,
        )

        if self._itr.dryrun_mode:
            self._job_nums.append(generate_job_id())
            self._num_processed += 1
        else:
            submit_slurm_job.get_status(
                debug_mode=self._itr.debug_mode,
                current_job=(index + 1),
                total_jobs=total,
            )

            if submit_slurm_job.status == 0:
                self._num_submitted += 1
                self._job_nums.append(submit_slurm_job.job_number)
            else:
                self.logger.error(
                    f"{self._logger_msg}: unable to submit SLURM job",
                )
                self._job_nums.append(None)

    def check_submission(self) -> None:
        """
        Check if the SLURM job file was submitted to the SLURM queue successfully.

        If any job number was recorded, prints them. If none were recorded, this
        is accepted only when every sample was processed or skipped. Otherwise it
        logs a fatal warning and exits with status 1.

        Raises:
            SystemExit: when jobs were expected but none were submitted.
        """
        # Every processed or skipped sample counts as "completed"
        completed = self._num_processed + self._num_skipped

        # Look at job number list to see if all items are 'None'
        _results = check_if_all_same(self._job_nums, None)

        if _results is False:
            print(
                f"============ {self._logger_msg} Job Numbers - {self._job_nums} ============"
            )
        elif completed == self._total_samples:
            self.logger.info(
                f"{self._logger_msg}: no SLURM jobs were submitted... SKIPPING AHEAD"
            )
        else:
            self.logger.warning(
                f"{self._logger_msg}: expected SLURM jobs to be submitted, but they were not",
            )
            self.logger.warning(
                f"{self._logger_msg}: fatal error encountered, unable to proceed further with pipeline.\nExiting... ",
            )
            # Equivalent to sys.exit(1); avoids relying on the interactive `site` helper `exit`
            raise SystemExit(1)
