#!/usr/bin/python3

from genericpath import exists
import shutil
import subprocess
import glob
import os
import pysam
import csv
import threading
import gzip
from random import seed
from random import sample, randint
from Bio import SeqIO
from Bio.Seq import Seq
# Seed the shared module-level RNG used by random.sample / random.randint so
# the genome selection is reproducible between runs.
seed(42)

# Installation locations of the external tools and pbsim2 error models.
centrifuge_bin_path = "/media/jens/xavierSSD/software/centrifuge/install/bin"
kraken2_bin_path = "/home/jens/software/kraken2/install"
pbmodeldir = "/home/jens/software/pbsim2/data"

# Global settings.
# NOTE(review): `threads` and `mutex` are not referenced by the functions in
# this file (select_ref_genomes uses a local list of the same name).
threads = 6
mutex = threading.Lock()

class ThreadPool:
    """Bookkeeping of the names of worker threads that are currently running.

    This class does not start or schedule threads itself; workers register
    with makeActive() and deregister with makeInactive(). Every change to
    ``active`` is made while holding ``lock``.
    """

    def __init__(self):
        super().__init__()
        self.lock = threading.Lock()
        self.active = []

    def makeActive(self, name):
        """Append *name* to the list of running workers."""
        with self.lock:
            self.active += [name]

    def makeInactive(self, name):
        """Drop the first occurrence of *name*; ValueError if it is not listed."""
        with self.lock:
            self.active.remove(name)


def download_ref_seq_genomes(dir):
    """Download RefSeq genomes and build the accession/taxonomy table.

    Steps (each skipped if its output already exists):
      1. Run genome_updater.sh into ``dir/refseq-abfv``.
      2. Pick the newest ``dir/refseq-abfv/20*`` version directory.
      3. Extract ``taxdump.tar.gz`` into ``<version>/taxdump``.
      4. Build ``<version>/refseq_accessions_taxonomy.csv`` from
         ``assembly_summary.txt`` via cut | taxonkit lineage |
         taxonkit reformat | cut.

    Args:
        dir: base directory that holds (or will hold) ``refseq-abfv``.

    Returns:
        The selected version directory, or *dir* unchanged when no
        ``refseq-abfv/20*`` entry exists.
    """
    base_dir = dir

    if not os.path.exists(base_dir + "/refseq-abfv"):
        cmd = ["genome_updater.sh", "-d", "refseq", "-g", "archaea,bacteria,fungi,viral",
               "-l", "complete genome,chromosome", "-f", "genomic.fna.gz",
               "-o", base_dir + "/refseq-abfv", "-t", "12", "-A", "species:1", "-m", "-a", "-p"]
        subprocess.run(cmd)

    # glob() returns entries in arbitrary filesystem order, so sort and take
    # the lexicographically greatest name. This is the newest run assuming the
    # version directories are date-stamped (they match "20*"; TODO confirm).
    versions = sorted(glob.glob(base_dir + "/refseq-abfv/20*"))
    work_dir = versions[-1] if versions else base_dir

    taxdump_dir = work_dir + "/taxdump"
    if not os.path.exists(taxdump_dir):
        os.mkdir(taxdump_dir)
        subprocess.run(["tar", "-zxvf", work_dir + "/taxdump.tar.gz", "-C", taxdump_dir])

    taxonomy_csv = work_dir + "/refseq_accessions_taxonomy.csv"
    if not os.path.exists(taxonomy_csv):
        cut_cmd = ["cut", "-f", "1,7,20", work_dir + "/assembly_summary.txt"]
        lineage_cmd = ["taxonkit", "lineage", "-i", "2", "-r", "-n", "-L", "--data-dir", taxdump_dir]
        reformat_cmd = ["taxonkit", "reformat", "-I", "2", "-P", "-t", "--data-dir", taxdump_dir]
        cut2_cmd = ["cut", "-f", "1,2,3,4,6,7"]
        with open(taxonomy_csv, 'w') as fout:
            cut1 = subprocess.Popen(cut_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
                                    universal_newlines=True)
            lineage = subprocess.Popen(lineage_cmd, stdin=cut1.stdout, stdout=subprocess.PIPE,
                                       stderr=subprocess.DEVNULL, universal_newlines=True)
            # Close the parent's copy of each pipe once the consumer holds it,
            # so an upstream process receives SIGPIPE if its consumer exits.
            cut1.stdout.close()
            reformat = subprocess.Popen(reformat_cmd, stdin=lineage.stdout, stdout=subprocess.PIPE,
                                        stderr=subprocess.DEVNULL, universal_newlines=True)
            lineage.stdout.close()
            subprocess.run(cut2_cmd, stdin=reformat.stdout, stdout=fout)
            reformat.stdout.close()
            # Reap the upstream stages so no zombie processes are left behind.
            for proc in (cut1, lineage, reformat):
                proc.wait()

    return work_dir

def prepare_centrifuge_input(dir):
    """Prepare centrifuge inputs: a conversion table and one merged FASTA.

    * ``dir/centrifuge_data/conversion_table.tsv`` is produced by
      ``./centrifuge_conversion_table.py`` from ``dir/kmcp_data/taxid.map``.
    * ``dir/files/all.fna.gz`` concatenates every FASTA record of the other
      ``dir/files/*.fna.gz`` files.

    Each output is only created if it does not exist yet.
    """
    centrifuge_dir = dir + "/centrifuge_data"
    if not os.path.exists(centrifuge_dir):
        os.mkdir(centrifuge_dir)

    conversion_table = centrifuge_dir + "/conversion_table.tsv"
    if not os.path.exists(conversion_table):
        cmd = ["./centrifuge_conversion_table.py", "-t", dir + "/kmcp_data/taxid.map", "-s", dir + "/files",
               "-o", conversion_table]
        subprocess.run(cmd)

    merged = dir + "/files/all.fna.gz"
    if not os.path.exists(merged):
        # The merged file lives in the same directory and matches the same
        # "*.fna.gz" pattern, so list the inputs BEFORE creating any output
        # (and exclude it explicitly) to avoid reading our own output back in.
        inputs = sorted(f for f in glob.glob(dir + "/files/*.fna.gz") if f != merged)
        # Write to a temporary name and rename at the end: an interrupted run
        # must not leave a truncated all.fna.gz that later runs would skip.
        tmp_path = merged + ".tmp"
        with gzip.open(tmp_path, 'wt') as f_out:
            for filename in inputs:
                with gzip.open(filename, 'rt') as f_in:
                    for record in SeqIO.parse(f_in, "fasta"):
                        r = SeqIO.write(record, f_out, 'fasta')
                        if r != 1:
                            print('Error while writing sequence:  ' + record.id)
        os.replace(tmp_path, merged)
        
def prepare_kraken2_input(dir):
    """Prepare the kraken2 working directory and a kraken-tagged FASTA.

    Creates ``dir/kraken2_data`` if needed and, unless
    ``dir/kraken2_data/all_seq.fna`` already exists, runs
    add_kraken_header.py with the centrifuge conversion table.
    """
    kraken_dir = dir + "/kraken2_data"
    if not os.path.exists(kraken_dir):
        os.mkdir(kraken_dir)

    # Taxonomy download command; built but currently not executed.
    download_cmd = [kraken2_bin_path + "/kraken2-build", "--download-taxonomy", "--db", kraken_dir + "/refseq-abfv"]
    # subprocess.run(download_cmd)

    tagged_fasta = kraken_dir + "/all_seq.fna"
    if os.path.exists(tagged_fasta):
        return

    # NOTE(review): this reads files/all.fna, while prepare_centrifuge_input
    # writes files/all.fna.gz — confirm which input is intended.
    header_cmd = [
        "/home/jens/software/TaxorScripts/python/add_kraken_header.py",
        "-t", dir + "/centrifuge_data/conversion_table.tsv",
        "-f", dir + "/files/all.fna",
        "-o", tagged_fasta,
    ]
    subprocess.run(header_cmd)

def execute_thread(s, pool, filename, simdir):
    """Thread target: simulate reads for *filename* while holding semaphore *s*.

    The current thread's name is registered in *pool* for the duration of the
    simulation and removed afterwards, even if the simulation raises.

    Args:
        s: semaphore bounding the number of concurrent simulations.
        pool: ThreadPool-like object with makeActive/makeInactive.
        filename: reference genome passed to create_simulated_read_sets.
        simdir: output directory passed to create_simulated_read_sets.
    """
    with s:
        # current_thread().name replaces the deprecated currentThread().getName().
        name = threading.current_thread().name
        pool.makeActive(name)
        try:
            create_simulated_read_sets(filename, simdir)
        finally:
            # Always deregister so pool.active does not keep dead workers.
            pool.makeInactive(name)

def select_ref_genomes(dir, gen_select_dir, simdir):
    """Sample up to 100 reference genomes, simulate reads for them, merge reads.

    Steps:
      1. Read genome file names from ``dir/refseq_accessions_taxonomy.csv``
         (column 3 holds the assembly path; its last component plus
         ``_genomic.fna.gz`` is the file under ``dir/files``).
      2. Randomly pick up to 100 of them using the module-level RNG.
      3. For each pick: record its genome size from ``seqkit stats -T``,
         write a copy via cut_nruns.py into *gen_select_dir* (presumably with
         N-runs removed — confirm), and simulate reads in a background thread
         (at most 8 concurrent simulations).
      4. Merge this run's simulated reads of length >= 1000 into
         ``dir/refseq-abfv-sim100.fastq.gz``, counting reads/bases.

    Returns:
        dict accession -> {'bases': int, 'reads': int, 'genome_size': str}.
    """
    files = []
    with open(dir + "/refseq_accessions_taxonomy.csv", newline='') as csvfile:
        for line in csvfile:
            assembly_path = line.split("\t")[2]
            files.append(dir + "/files/" + assembly_path.split("/")[-1] + "_genomic.fna.gz")

    # sample() raises ValueError when k exceeds the population, so cap k. For
    # >= 100 genomes the call, and hence the seeded selection, is unchanged.
    sampled = sample(range(0, len(files)), min(100, len(files)))

    if not os.path.exists(gen_select_dir):
        os.mkdir(gen_select_dir)
    if not os.path.exists(simdir):
        os.mkdir(simdir)

    pool = ThreadPool()
    s = threading.Semaphore(8)
    workers = []
    sim_genomes = {}
    sim_fastqs = []  # simulated read files belonging to this run, in sampling order
    for i, value in enumerate(sampled, start=1):
        genome_file = files[value]
        tail = os.path.basename(genome_file)
        newname = os.path.splitext(os.path.splitext(tail)[0])[0]
        namesplit = newname.split("_")
        acc = namesplit[0] + "_" + namesplit[1]

        # `seqkit stats -T` prints a header plus one tab-separated row; the
        # 5th column (sum_len) is the genome size. Split on tabs so paths
        # containing spaces do not shift the columns.
        seqkit_cmd = ["seqkit", "stats", "-T", genome_file]
        out = subprocess.run(seqkit_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
                             universal_newlines=True).stdout
        statlines = out.split("\n")
        if len(statlines) > 1 and acc not in sim_genomes:
            entries = statlines[1].split("\t")
            sim_genomes[acc] = {'bases': 0, 'reads': 0, 'genome_size': entries[4]}

        ref_file = gen_select_dir + "/" + tail
        if not os.path.exists(ref_file):
            cutnruns_cmd = ["/home/jens/software/Scripts/cut_nruns.py", "-i", genome_file, "-o", ref_file]
            subprocess.run(cutnruns_cmd)

        # create_simulated_read_sets writes simdir/<stem>.fastq.gz for this genome.
        sim_fastqs.append(simdir + "/" + newname + ".fastq.gz")
        t = threading.Thread(target=execute_thread, name='thread_' + str(i),
                             args=(s, pool, ref_file, simdir))
        workers.append(t)
        t.start()

    for t in workers:
        t.join()

    # Merge only the files simulated for this run's genomes (de-duplicated, in
    # sampling order). Globbing all of simdir would also pick up leftovers from
    # other runs, whose accessions are not in sim_genomes (KeyError), and would
    # make the read order depend on filesystem order.
    with gzip.open(dir + "/refseq-abfv-sim100.fastq.gz", 'wt') as outfile:
        for filename in dict.fromkeys(sim_fastqs):
            if not os.path.exists(filename):
                continue
            with gzip.open(filename, 'rt') as fq_in:
                for record in SeqIO.parse(fq_in, "fastq"):
                    if len(record.seq) < 1000:
                        continue
                    idsplit = str(record.id).split("_")
                    acc = idsplit[0] + "_" + idsplit[1]
                    r = SeqIO.write(record, outfile, 'fastq')
                    if r != 1:
                        print('Error while writing sequence:  ' + record.id)
                    sim_genomes[acc]['bases'] += len(record.seq)
                    sim_genomes[acc]['reads'] += 1

    return sim_genomes

def create_simulated_read_sets(reffile, simdir, depth=None):
    """Simulate reads for one reference genome with pbsim.

    The gzipped reference is decompressed in place for pbsim and recompressed
    afterwards. pbsim writes its output into the current working directory
    using the genome stem as prefix. The ``<stem>_*.fastq`` chunks are
    concatenated into ``simdir/<stem>.fastq``, which is then gzipped, and all
    ``<stem>_*`` files in the working directory are deleted.

    Args:
        reffile: path to a ``*.fna.gz`` reference whose name starts with
            ``<part1>_<part2>_`` (used as the read-id prefix).
        simdir: directory that receives ``<stem>.fastq.gz``.
        depth: sequencing depth for pbsim. When None (the default), a random
            integer in [5, 100] is drawn from the shared module RNG.
    """
    tail = os.path.basename(reffile)
    newname = os.path.splitext(os.path.splitext(tail)[0])[0]
    namesplit = newname.split("_")
    # Read ids start with "<part1>_<part2>_" so reads can be mapped back to
    # their accession when merging.
    acc = namesplit[0] + "_" + namesplit[1] + "_"

    subprocess.run(["gunzip", reffile])
    reffile = os.path.splitext(reffile)[0]  # path without ".gz"
    if depth is None:
        # NOTE: when several threads call this, which genome receives which
        # draw depends on scheduling, so depths are not reproducible despite
        # the fixed seed. Pass `depth` explicitly for reproducible runs.
        depth = randint(5, 100)
    pbsimcmd = ["pbsim", "--prefix", newname, "--id-prefix", acc, "--depth", str(depth),
                "--accuracy-mean", "0.95", "--length-min", "1000",
                "--hmm_model", pbmodeldir + "/R103.model", "--seed", "20230412", reffile]
    subprocess.run(pbsimcmd)

    subprocess.run(["gzip", reffile])

    outfilename = simdir + "/" + newname + ".fastq"
    with open(outfilename, 'wb') as outfile:
        # Sorted so chunks are concatenated in a deterministic order (glob order
        # is arbitrary). The pattern requires "_" after the stem, so it cannot
        # match outfilename itself.
        for filename in sorted(glob.glob(newname + "_*.fastq")):
            with open(filename, 'rb') as readfile:
                shutil.copyfileobj(readfile, outfile)

    # Remove every per-genome file pbsim left in the working directory.
    for leftover in glob.glob(newname + "_*"):
        os.remove(leftover)

    subprocess.run(["gzip", outfilename])

def create_simulated_sample_tsv(dir, simdir, sim_genomes):
    """Write per-genome ground truth for ``dir/refseq-abfv-sim100.fastq.gz``.

    Missing genomes (one per ``simdir/*_genomic.fastq.gz``) are added to
    *sim_genomes* with their size taken from ``seqkit stats -T`` on
    ``dir/files/<stem>.fna.gz``. Read and base counts are then recomputed from
    the merged FASTQ. Taxon id and name come from ``dir/assembly_summary.txt``,
    and the result is written to ``dir/refseq-abfv-sim100-genomes.tsv``.
    genome_abundance is each genome's depth (bases / genome_size) divided by
    the summed depth of all genomes.

    Args:
        dir: working directory holding the merged FASTQ and assembly summary.
        simdir: directory with the per-genome simulated FASTQ files.
        sim_genomes: dict accession -> stats; updated in place.
    """
    for filename in glob.glob(simdir + "/*_genomic.fastq.gz"):
        tail = os.path.basename(filename)
        newname = os.path.splitext(os.path.splitext(tail)[0])[0]
        namesplit = newname.split("_")
        acc = namesplit[0] + "_" + namesplit[1]
        seqkit_cmd = ["seqkit", "stats", "-T", dir + "/files/" + newname + ".fna.gz"]
        out = subprocess.run(seqkit_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
                             universal_newlines=True).stdout
        statlines = out.split("\n")
        # Row 2 of the tab-separated table; column 5 (sum_len) is the genome size.
        if len(statlines) > 1 and acc not in sim_genomes:
            entries = statlines[1].split("\t")
            sim_genomes[acc] = {'bases': 0, 'reads': 0, 'genome_size': entries[4]}

    # Recount from scratch. A dict returned by select_ref_genomes already holds
    # counts for these same reads, and adding to them would double every value.
    for stats in sim_genomes.values():
        stats['bases'] = 0
        stats['reads'] = 0

    with gzip.open(dir + "/refseq-abfv-sim100.fastq.gz", 'rt') as fastq_file:
        for record in SeqIO.parse(fastq_file, "fastq"):
            idsplit = str(record.id).split("_")
            acc = idsplit[0] + "_" + idsplit[1]
            sim_genomes[acc]['bases'] += len(record.seq)
            sim_genomes[acc]['reads'] += 1

    sumdepth = 0.0
    with open(dir + "/assembly_summary.txt", newline='') as csvfile:
        for line in csvfile:
            lsplit = line.split("\t")
            genome = sim_genomes.get(lsplit[0])
            if genome is None:
                continue
            genome["taxon_id"] = lsplit[6]
            genome["name"] = lsplit[7]
            genome["depth"] = float(genome['bases']) / float(genome['genome_size'])
            sumdepth += genome["depth"]

    with open(dir + "/refseq-abfv-sim100-genomes.tsv", 'wt') as csvfile:
        csvfile.write("accession\ttaxon_id\tname\tbases\treads\tgenome_size\tgenome_abundance\n")
        for acc, genome in sim_genomes.items():
            # Avoid ZeroDivisionError when no reads were counted at all.
            abundance = float(genome["depth"]) / float(sumdepth) if sumdepth else 0.0
            csvfile.write(acc + "\t" + genome['taxon_id'] + "\t" + genome['name'] + "\t" +
                          str(genome['bases']) + "\t" + str(genome['reads']) + "\t" +
                          str(genome['genome_size']) + "\t" + str(abundance) + "\n")

def prepare_zymo_references(reference_dir, reference_file, refs2taxon_file):
    """Combine Zymo reference FASTAs into one gzipped, kraken-tagged FASTA.

    Every record of ``reference_dir/*.fasta`` is written to *reference_file*
    with id ``<original id>|kraken:taxid|<taxon id>``. The taxon id is looked
    up in *refs2taxon_file*, a tab-separated table with one header line and
    columns reference name, taxon id. The lookup key is the file stem when
    *reference_file* contains "viral"; otherwise it is the first two
    "_"-separated parts of the stem. Nothing is written if *reference_file*
    already exists.

    Raises:
        KeyError: if a FASTA's key is missing from the mapping table.
    """
    refs = {}
    with open(refs2taxon_file, 'rt') as f_in:
        next(f_in, None)  # skip the header line
        for line in f_in:
            if not line.strip():
                continue  # tolerate blank (e.g. trailing) lines
            lsplit = line.strip().split("\t")
            refs[lsplit[0]] = lsplit[1]

    if not os.path.exists(reference_file):
        with gzip.open(reference_file, 'wt') as f_out:
            # Sorted for a deterministic record order in the combined reference.
            for filename in sorted(glob.glob(reference_dir + "/*.fasta")):
                fname = os.path.splitext(os.path.basename(filename))[0]
                if "viral" in reference_file:
                    ref = fname
                else:
                    namesplit = fname.split("_")
                    ref = namesplit[0] + "_" + namesplit[1]
                with open(filename, 'rt') as f_in:
                    for record in SeqIO.parse(f_in, "fasta"):
                        new_id = record.id + "|kraken:taxid|" + str(refs[ref])
                        new_record = SeqIO.SeqRecord(Seq(record.seq), id=new_id, description="")
                        r = SeqIO.write(new_record, f_out, 'fasta')
                        if r != 1:
                            print('Error while writing sequence:  ' + new_record.id)
    


def mapReads(reference_file, read_file):
    """Map long reads with minimap2 (map-ont) into a sorted, indexed BAM.

    The BAM path is *read_file* with its last two extensions replaced by
    ".bam" (e.g. ``x.fq.gz`` -> ``x.bam``). If that BAM already exists,
    nothing is run.

    Returns:
        Path of the BAM file.

    Raises:
        subprocess.CalledProcessError: if minimap2 or samtools sort exits
            non-zero. The partial BAM is deleted so that a later call retries
            instead of silently reusing a broken file.
    """
    newname = os.path.splitext(os.path.splitext(read_file)[0])[0]
    bamfile = newname + ".bam"

    if not os.path.exists(bamfile):
        mm2_cmd = ["minimap2", "-ax", "map-ont", "--secondary=no", reference_file, read_file]
        samsortcmd = ["samtools", "sort", "-O", "BAM", "-o", bamfile, "--reference", reference_file,
                      "--threads", "3"]
        minimap = subprocess.Popen(mm2_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
        samsort = subprocess.Popen(samsortcmd, stdin=minimap.stdout, stdout=subprocess.DEVNULL,
                                   stderr=subprocess.DEVNULL)
        # Close our copy of the pipe so minimap2 gets SIGPIPE if samtools dies.
        minimap.stdout.close()
        samsort.communicate()
        minimap.wait()  # reap minimap2 and collect its exit status
        for cmd, proc in ((mm2_cmd, minimap), (samsortcmd, samsort)):
            if proc.returncode != 0:
                if os.path.exists(bamfile):
                    os.remove(bamfile)
                raise subprocess.CalledProcessError(proc.returncode, cmd)
        subprocess.run(["samtools", "index", bamfile])

    return bamfile

def mapShortReads(reference_file, read_file):
    """Map short reads with bwa aln/samse into a sorted, indexed BAM.

    The BAM path is *read_file* with its last two extensions replaced by
    ".bam"; the intermediate ``.sai`` file uses the same stem. If the BAM
    already exists, nothing is run. The bwa index is only built when
    ``<reference_file>.bwt`` is missing.

    Returns:
        Path of the BAM file.

    Raises:
        subprocess.CalledProcessError: if bwa index/aln/samse or samtools sort
            exits non-zero. A partial BAM is deleted first.
    """
    newname = os.path.splitext(os.path.splitext(read_file)[0])[0]
    bamfile = newname + ".bam"

    if not os.path.exists(bamfile):
        # bwa index names its outputs after the reference file by default;
        # re-indexing a large reference on every call is wasted work.
        if not os.path.exists(reference_file + ".bwt"):
            subprocess.run(["bwa", "index", reference_file], check=True)
        subprocess.run(["bwa", "aln", "-n", "0.01", "-t", "6", "-f", newname + ".sai",
                        reference_file, read_file], check=True)
        # NOTE(review): for `bwa samse`, -n is an integer (max hits reported in
        # the XA tag); "0.01" is likely parsed as 0 — confirm this is intended.
        bwasamse_cmd = ["bwa", "samse", "-n", "0.01", reference_file, newname + ".sai", read_file]
        samsortcmd = ["samtools", "sort", "-O", "BAM", "-o", bamfile, "--reference", reference_file,
                      "--threads", "3"]
        samse = subprocess.Popen(bwasamse_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
        samsort = subprocess.Popen(samsortcmd, stdin=samse.stdout, stdout=subprocess.DEVNULL,
                                   stderr=subprocess.DEVNULL)
        # Close our copy of the pipe so bwa gets SIGPIPE if samtools dies.
        samse.stdout.close()
        samsort.communicate()
        samse.wait()  # reap bwa samse and collect its exit status
        for cmd, proc in ((bwasamse_cmd, samse), (samsortcmd, samsort)):
            if proc.returncode != 0:
                if os.path.exists(bamfile):
                    os.remove(bamfile)
                raise subprocess.CalledProcessError(proc.returncode, cmd)
        subprocess.run(["samtools", "index", bamfile])

    return bamfile

def get_mapping_data(dir, bamfile, read_file, refs2taxon_file):
    """Derive per-taxon read/base counts from a BAM and write ground-truth files.

    Reference names are split on "|" and their last field is taken as the
    taxon id (e.g. ``...|kraken:taxid|<id>``). Each read is counted once,
    through its primary mapped alignment.

    Outputs (stem = *read_file* without its last two extensions):
      * ``<stem>.metadata`` (only if absent): one line per entry of
        *refs2taxon_file*: reference, taxon id, mapped reads, mapped bases.
      * ``<stem>-len1000-groundTruth.fq.gz``: the mapped reads only, with ids
        rewritten to ``<read id>|<taxon id>``. Despite the name, no length
        filter is applied (the length check is commented out).

    Args:
        dir: unused; kept for interface compatibility.
        bamfile: BAM produced by mapReads/mapShortReads.
        read_file: gzipped FASTQ that was mapped.
        refs2taxon_file: tab-separated table (one header line) mapping
            reference name -> taxon id.
    """
    newname = os.path.splitext(os.path.splitext(read_file)[0])[0]
    new_fastq = newname + "-len1000-groundTruth.fq.gz"
    read_ids = {}
    taxon_data = {}
    with pysam.AlignmentFile(bamfile, "rb") as samfile:
        for read in samfile.fetch(until_eof=True):
            #if read.query_length < 1000:
            #    continue
            if read.reference_name is None or read.is_unmapped:
                continue
            # Secondary/supplementary records describe a read already counted
            # via its primary alignment. Counting them would inflate the read
            # totals, and with hard clipping their query_length may cover only
            # part of the read.
            if read.is_secondary or read.is_supplementary:
                continue
            taxon_id = str(read.reference_name).split("|")[-1]
            read_ids[read.query_name] = read.query_name + "|" + taxon_id
            counts = taxon_data.setdefault(taxon_id, {"reads": 0, "bases": 0})
            counts["reads"] += 1
            counts["bases"] += read.query_length

    metadata_file = newname + ".metadata"
    if not os.path.exists(metadata_file):
        refs = {}
        with open(refs2taxon_file, 'rt') as f_in:
            next(f_in, None)  # skip the header line
            for line in f_in:
                if not line.strip():
                    continue  # tolerate blank (e.g. trailing) lines
                lsplit = line.strip().split("\t")
                refs[lsplit[0]] = lsplit[1]

        with open(metadata_file, 'wt') as f_out:
            for ref, taxon_id in refs.items():
                counts = taxon_data.get(taxon_id)
                if counts is not None:
                    f_out.write(ref + "\t" + taxon_id + "\t" + str(counts["reads"]) + "\t" +
                                str(counts["bases"]) + "\n")
                else:
                    f_out.write(ref + "\t" + taxon_id + "\t0\t0\n")

    with gzip.open(new_fastq, 'wt') as fq_out, gzip.open(read_file, 'rt') as fq_in:
        for record in SeqIO.parse(fq_in, "fastq"):
            if record.id in read_ids:
                record.id = read_ids[record.id]
                record.description = ""
                r = SeqIO.write(record, fq_out, 'fastq')
                if r != 1:
                    print('Error while writing sequence:  ' + record.id)


def main():
    """Script entry point.

    Currently resolves the RefSeq working directory and builds ground truth
    for the Zymo D6330 Illumina run (SRR12324253). Other pipeline stages are
    kept as commented-out toggles below.
    """
    working_dir = download_ref_seq_genomes("/media/jens/ExtremePro")
    mock_dir = working_dir + "/real_mock"
    refs2taxon_d6330 = mock_dir + "/Refs2TaxonIds_D6330.txt"
    zymo_ref_file = mock_dir + "/Genomes/Zymo_D6330_references.fna.gz"

    # Simulated RefSeq sample:
    # sim_genomes = select_ref_genomes(working_dir, working_dir + "/selected_genomes", working_dir + "/simulated")
    # create_simulated_sample_tsv(working_dir, working_dir + "/simulated", sim_genomes)

    # D6330 reference construction:
    # prepare_zymo_references(mock_dir + "/Genomes/D6330", zymo_ref_file, refs2taxon_d6330)

    # Zymo R10.3 ground truth:
    # bamfile = mapReads(zymo_ref_file, mock_dir + "/Zymo-GridION-EVEN-3Peaks-R103-merged.fq.gz")
    # get_mapping_data(working_dir, bamfile, mock_dir + "/Zymo-GridION-EVEN-3Peaks-R103-merged.fq.gz", refs2taxon_d6330)

    # Zymo Q20 ground truth:
    # bamfile = mapReads(zymo_ref_file, mock_dir + "/ERR5396170.fastq.gz")
    # get_mapping_data(working_dir, bamfile, mock_dir + "/ERR5396170.fastq.gz", refs2taxon_d6330)

    # Illumina short reads against the D6330 references (active step).
    illumina_reads = mock_dir + "/SRR12324253.fastq.gz"
    bamfile = mapShortReads(zymo_ref_file, illumina_reads)
    get_mapping_data(working_dir, bamfile, illumina_reads, refs2taxon_d6330)

    # HiFi D6331 ground truth:
    # zymo_ref_file = mock_dir + "/Genomes/Zymo_D6331_references.fna.gz"
    # prepare_zymo_references(mock_dir + "/Genomes/D6331", zymo_ref_file, mock_dir + "/Refs2TaxonIds_D6331.txt")
    # bamfile = mapReads(zymo_ref_file, mock_dir + "/SRR13128014.fastq.gz")
    # get_mapping_data(working_dir, bamfile, mock_dir + "/SRR13128014.fastq.gz", mock_dir + "/Refs2TaxonIds_D6331.txt")

    # Profiler database inputs (NOTE: prepare_kmcp_input is not defined in this file):
    # prepare_kmcp_input(working_dir)
    # prepare_centrifuge_input(working_dir)
    # prepare_kraken2_input(working_dir)


if __name__ == "__main__":
    main()