#!/usr/bin/env python

# MIDAS: Metagenomic Intra-species Diversity Analysis System
# Copyright (C) 2015 Stephen Nayfach
# Freely distributed under the GNU General Public License (GPLv3)

import sys, os, subprocess, gzip, csv, Bio.SeqIO, numpy as np
from collections import defaultdict
from time import time
from midas import utility
from midas.run import stream_bam

def get_barcode_from_readname(args, read_name):
    """ Return the barcode string a read should be grouped under.

    With args['use_10x_barcodes'] set, the barcode is taken from the last
    ':'-separated field of the read name, cut at the first ','. Otherwise
    each read is its own pseudo-barcode and the read name is returned as is.
    """
    if not args['use_10x_barcodes']:
        # read-based pseudo-barcoding
        return read_name

    # The barcode string encoded in the fastq read name
    last_field = read_name.rsplit(":", 1)[-1]
    return last_field.partition(",")[0]
    
class BarcodeDB:
    """ Interns barcode strings as sequential integer ids and counts lookups per id """
    def __init__(self):

        self.barcode_id_map = {}   # barcode string -> id
        self.id_barcode_map = {}   # id -> barcode string
        self.id_count_map = {}     # id -> number of get_barcode_id calls for it
        self.current_id = 0        # next id to hand out

    def get_barcode_id(self, barcode_str):
        """ Return the id for barcode_str, assigning the next free id on first sight.

        Every call (new or repeat) increments that barcode's count.
        """
        try:
            barcode_id = self.barcode_id_map[barcode_str]
        except KeyError:
            barcode_id = self.current_id
            self.current_id = barcode_id + 1
            self.barcode_id_map[barcode_str] = barcode_id
            self.id_barcode_map[barcode_id] = barcode_str
            self.id_count_map[barcode_id] = 0

        self.id_count_map[barcode_id] += 1
        return barcode_id

    def get_barcode_str_count(self, barcode_id):
        """ Return (barcode_str, count) for an id issued by get_barcode_id. """
        return self.id_barcode_map[barcode_id], self.id_count_map[barcode_id]

    def get_all_barcode_ids(self):
        """ Return the ids issued so far (the keys of the id -> barcode map). """
        return self.id_barcode_map.keys()
        
class Species:
    """ Per-species container for reference file paths and coverage tallies """
    def __init__(self, id):
        # the id doubles as the species' directory name under pan_genomes/
        self.id = id
        self.paths = {}
        self.genes = []
        self.pangenome_size = 0
        self.reads = 0
        self.bases = 0.0
        self.depth = 0.0
        self.markers = defaultdict(float)

    def init_ref_db(self, ref_db):
        """ Record which pan-genome input files exist for this species.

        Sets self.dir to <ref_db>/pan_genomes/<id> and, for each of
        centroids.ffn, cluster_info.txt and gene_info.txt that exists there,
        stores its path in self.paths keyed by the plain file name. When both
        a plain and a '.gz' copy exist, the '.gz' path is kept.
        """
        self.dir = '%s/pan_genomes/%s' % (ref_db, self.id)
        candidates = [
            (name, '%s/%s%s' % (self.dir, name, suffix))
            for suffix in ('', '.gz')
            for name in ('centroids.ffn', 'cluster_info.txt', 'gene_info.txt')
        ]
        for name, candidate in candidates:
            if os.path.isfile(candidate):
                self.paths[name] = candidate

def initialize_species(args):
    """ Build the Species objects for this run.

    When args['build_db'] is set, species are chosen with
    midas.run.species.select_species and their ids are written one per line
    to <outdir>/<barcode_dir>/species.txt. Otherwise ids are read back from
    that file if it exists (blank lines are ignored). Each Species then has
    its reference paths resolved against args['db'].

    Returns a dict mapping species id -> Species (empty if no list exists).
    """
    species = {}
    splist = '%s/%s/species.txt' % (args['outdir'], args['barcode_dir'])
    if args['build_db']:
        from midas.run.species import select_species
        with open(splist, 'w') as outfile:
            for species_id in select_species(args):
                species[species_id] = Species(species_id)
                outfile.write(species_id + '\n')
    elif os.path.isfile(splist):
        with open(splist) as infile:
            for line in infile:
                species_id = line.rstrip()
                if not species_id:
                    continue
                # key by the id read from the file so every species is kept
                species[species_id] = Species(species_id)
    for sp in species.values():
        sp.init_ref_db(args['db'])
    return species

def load_gene_cluster_map(args, species):
    """ Attach gene -> cluster mappings to every species.

    For each Species in `species`, reads <db>/pan_genomes/<id>/gene_info.txt
    (the '.gz' copy wins if both exist) with utility.parse_file and sets:
      sp.gene_info   -- path of the file that was read
      sp.map         -- centroid_99 gene id -> centroid_<pid> gene id
      sp.inverse_map -- centroid_<pid> gene id -> set of centroid_99 gene ids
    where db = args['db'] and pid = args['cluster_pid'].

    Raises IOError when a species has no gene_info file.
    """
    for sp in species.values():
        db = args['db']
        # column holding the cluster centroid at the requested identity
        cluster_key = 'centroid_%s' % args['cluster_pid']

        sp.map = {}
        sp.inverse_map = {}

        gene_info = None
        for ext in ['', '.gz']:
            path = '/'.join([db, 'pan_genomes', sp.id, 'gene_info.txt%s' % ext])
            if os.path.isfile(path):
                gene_info = path
        if gene_info is None:
            raise IOError("gene_info.txt not found for species %s under %s/pan_genomes"
                          % (sp.id, db))
        sp.gene_info = gene_info

        for r in utility.parse_file(gene_info):
            centroid_99 = r['centroid_99']
            cluster_id = r[cluster_key]
            sp.map[centroid_99] = cluster_id
            sp.inverse_map.setdefault(cluster_id, set()).add(centroid_99)
            

class Gene:
    """ Per-gene record of read and barcode tallies for one pan-genome centroid """
    def __init__(self, id):
        self.id = id
        # sequence length; filled in from the centroid FASTA by initialize_genes
        self.length = 0
        # number of kept alignments credited to this gene
        self.reads = 0
        # barcode_id -> number of kept reads carrying that barcode
        self.barcodes = {}
        # marker id from marker_genes/phyeco.map, or None if not a marker gene
        self.marker_id = None

class SNP:
    """ Barcode tallies for a single (contig, position) site """
    def __init__(self, position):
        self.position = position
        # reads counted at this site across all allele classes
        self.reads = 0
        # barcode_id -> read count, split by whether the read carries the
        # reference base, a different base, or no base (deletion / ref skip)
        self.ref_barcodes, self.alt_barcodes, self.other_barcodes = {}, {}, {}
        

def initialize_genes(args, species):
    """ Initialize Gene objects for every pan-genome centroid.

    Reads each species' centroids FASTA (sp.paths['centroids.ffn']) to get
    gene ids and sequence lengths, incrementing sp.pangenome_size per gene,
    then sets marker_id on genes listed in <db>/marker_genes/phyeco.map.

    Returns (genes, desired_genes): genes maps gene id -> Gene; desired_genes
    maps species id -> set of gene ids to collate (currently every gene).
    """
    genes = {}
    # fetch gene_id, species_id, gene length
    for sp in species.values():
        handle = utility.iopen(sp.paths['centroids.ffn'])
        try:
            for seq in Bio.SeqIO.parse(handle, 'fasta'):
                gene = Gene(seq.id)
                gene.species_id = sp.id
                gene.length = len(seq.seq)
                genes[seq.id] = gene
                sp.pangenome_size += 1
        finally:
            handle.close()

    # fetch marker_id
    handle = utility.iopen('%s/marker_genes/phyeco.map' % args['db'])
    try:
        for r in csv.DictReader(handle, delimiter='\t'):
            if r['gene_id'] in genes:
                genes[r['gene_id']].marker_id = r['marker_id']
    finally:
        handle.close()

    # calculate desired genes (for now, all!)
    desired_genes = {}
    for gene_id, gene in genes.items():
        desired_genes.setdefault(gene.species_id, set()).add(gene_id)

    return genes, desired_genes

def initialize_snps(args, species):
    """ Initialize SNP containers and, optionally, the set of SNPs to collate.

    Returns (snps, desired_snps):
      snps         -- species id -> empty list, later filled with SNP objects
      desired_snps -- (contig, 1-based position) -> species id; populated only
                      when args['collate_snp_barcodes'] is set, from the
                      tab-separated file args['desired_snps_filename'] (one
                      header line, then species_id, contig, position columns).
    Blank lines in the desired-SNP file are skipped.
    """
    # create SNPs dictionary
    snps = dict((sp.id, []) for sp in species.values())

    desired_snps = {}
    if args['collate_snp_barcodes']:
        with open(args['desired_snps_filename'], "r") as infile:
            infile.readline()  # header
            for line in infile:
                if not line.strip():
                    continue
                items = line.split("\t")
                species_id = items[0]
                contig = items[1]
                # int() works on Python 2 and 3 (Python 2 promotes to long as
                # needed); the Python-2-only long() raises NameError on Python 3
                loc = int(items[2])
                desired_snps[(contig, loc)] = species_id

    return snps, desired_snps
    
def barcode_coverage(args, species, genes, desired_genes, snps, desired_snps, barcode_db,debug=False):
    """ Run the enabled barcode collation steps, then write all results.

    Gene collation runs when args['collate_gene_barcodes'] is set and SNP
    collation when args['collate_snp_barcodes'] is set (in that order);
    write_results is always called afterwards.
    """
    steps = (
        ('collate_gene_barcodes', collate_gene_barcodes, genes, desired_genes),
        ('collate_snp_barcodes', collate_snp_barcodes, snps, desired_snps),
    )
    for flag, collate, items, wanted in steps:
        if args[flag]:
            collate(args, species, items, wanted, barcode_db, debug)

    write_results(args, species, genes, desired_genes, snps, desired_snps, barcode_db)


        
#############
#
# Gets barcodes for reads that align to pangenome genes
#
##############
def collate_gene_barcodes(args, species, genes, desired_genes, barcode_db, debug=False):
    """ Collate barcodes that map to each gene across pangenomes.

    Streams <outdir>/genes/temp/pangenomes.bam and keeps alignments passing
    the args['mapid'], args['aln_cov'], args['readq'] and args['mapq']
    thresholds. Each kept read is credited to the cluster gene its reference
    maps to via species[...].map, provided that gene is in desired_genes:
    the read's barcode is registered in barcode_db and tallied in
    gene.barcodes, and gene.reads is incremented.

    Afterwards, per species, sets reads, genes (barcodes per base for each
    gene), covered_genes, mean_coverage and fraction_covered.

    When debug (or args['debug']) is true, stops once more than 100000 reads
    have been kept.
    """
    import pysam
    debug = debug or args.get('debug', False)
    bam_path = '/'.join([args['outdir'], 'genes/temp/pangenomes.bam'])
    i, j = 0, 0
    aln_file = pysam.AlignmentFile(bam_path, "rb")
    try:
        # loop over alignments, sum values per gene
        for aln in aln_file.fetch(until_eof=True):
            i += 1
            if (stream_bam.compute_perc_id(aln) < args['mapid']
                    or stream_bam.compute_aln_cov(aln) < args['aln_cov']
                    or np.mean(aln.query_qualities) < args['readq']
                    or aln.mapping_quality < args['mapq']):
                continue

            # Read matched gene! Map it onto its cluster representative.
            raw_gene_id = aln.reference_name
            species_id = genes[raw_gene_id].species_id
            gene_id = species[species_id].map[raw_gene_id]
            if gene_id not in desired_genes[species_id]:
                continue

            # A compressed version of the barcode suitable for storage
            barcode_str = get_barcode_from_readname(args, aln.query_name)
            barcode_id = barcode_db.get_barcode_id(barcode_str)

            gene = genes[gene_id]
            gene.reads += 1
            gene.barcodes[barcode_id] = gene.barcodes.get(barcode_id, 0) + 1
            j += 1

            if debug and j > 100000:
                break
    finally:
        aln_file.close()

    print("  total aligned reads: %s" % i)
    print("  total mapped reads: %s" % j)

    # loop over genes, sum values per species
    for gene in genes.values():
        sp = species[gene.species_id]
        sp.reads += gene.reads
        # barcodes per base; zero-length sequences get 0 instead of dividing by 0
        density = len(gene.barcodes) * 1.0 / gene.length if gene.length else 0.0
        sp.genes.append(density)

    # loop over species, compute summaries
    for sp in species.values():
        covered = [c for c in sp.genes if c > 0]
        sp.covered_genes = len(covered)
        sp.mean_coverage = sum(covered) * 1.0 / len(covered) if covered else 0
        if sp.pangenome_size:
            sp.fraction_covered = sp.covered_genes / float(sp.pangenome_size)
        else:
            sp.fraction_covered = 0.0

##############
#
# Gets barcodes for the subset of SNPs supplied in the desired_snps_file
#
##############
def collate_snp_barcodes(args, species, snps, desired_snps, barcode_db, debug=False):
    """ Collate barcodes that map to the supplied list of SNPs across refgenomes.

    Indexes <outdir>/snps/temp/genomes.bam, then walks its pileup against
    <outdir>/snps/temp/genomes.fa. At each (contig, 1-based position) found
    in desired_snps, an SNP is built from reads that pass the args['mapid'],
    args['aln_cov'], args['readq'] and args['mapq'] thresholds and appended
    to snps[species_id]. See _pileup_snp_barcodes for the per-read rules.

    The walk stops once every desired SNP has been recorded; when debug (or
    args['debug']) is true it also stops after 100 SNPs. Both files are
    closed on exit.
    """
    # TODO: fill in with Stephen's mpileup algorithm.
    import pysam

    debug = debug or args.get('debug', False)

    # open ref genome alignment file
    bam_path = '/'.join([args['outdir'], 'snps/temp/genomes.bam'])
    bai_path = '/'.join([args['outdir'], 'snps/temp/genomes.bam.bai'])

    # index bam file
    sys.stderr.write("Indexing bam file...\t")
    pysam.index(bam_path, bai_path)
    sys.stderr.write("Done!\n")

    fasta_path = '/'.join([args['outdir'], 'snps/temp/genomes.fa'])
    aln_file = pysam.AlignmentFile(bam_path, "rb")
    try:
        fasta_file = pysam.FastaFile(fasta_path)
        try:
            _pileup_snp_barcodes(args, aln_file, fasta_file, snps, desired_snps,
                                 barcode_db, debug)
        finally:
            fasta_file.close()
    finally:
        aln_file.close()


def _pileup_snp_barcodes(args, aln_file, fasta_file, snps, desired_snps, barcode_db, debug):
    """ Walk aln_file's pileup and record one SNP per desired position.

    A kept read is filed under snp.other_barcodes when it has no base at the
    site (deletion / ref skip), is dropped when its base quality there is
    below args['baseq'], and otherwise goes to snp.ref_barcodes or
    snp.alt_barcodes depending on whether its base equals the reference base.
    A barcode is registered in barcode_db only for reads actually counted.
    """
    total_desired = len(desired_snps)
    num_processed = 0
    for pileupcolumn in aln_file.pileup():
        contig_name = pileupcolumn.reference_name
        position = (contig_name, pileupcolumn.pos + 1)

        if position not in desired_snps:
            # not a desired SNP, so do not record anything
            continue

        species_id = desired_snps[position]
        # upper() so soft-masked (lowercase) reference bases still match reads
        ref_base = fasta_file.fetch(reference=contig_name, start=pileupcolumn.pos,
                                    end=pileupcolumn.pos + 1).upper()

        snp = SNP(position)
        for pileupread in pileupcolumn.pileups:
            aln = pileupread.alignment
            if (stream_bam.compute_perc_id(aln) < args['mapid']
                    or stream_bam.compute_aln_cov(aln) < args['aln_cov']
                    or np.mean(aln.query_qualities) < args['readq']
                    or aln.mapping_quality < args['mapq']):
                continue

            # figure out what the allele is (ref/alt/other)
            query_pos = pileupread.query_position
            if query_pos is None:
                # a deletion or ref skip
                barcode_map = snp.other_barcodes
            elif aln.query_qualities[query_pos] < args['baseq']:
                # enforce base quality filter before touching barcode_db
                continue
            elif aln.query_sequence[query_pos] == ref_base:
                barcode_map = snp.ref_barcodes
            else:
                barcode_map = snp.alt_barcodes

            # A compressed version of the barcode suitable for storage
            barcode_str = get_barcode_from_readname(args, aln.query_name)
            barcode_id = barcode_db.get_barcode_id(barcode_str)

            snp.reads += 1
            barcode_map[barcode_id] = barcode_map.get(barcode_id, 0) + 1

        snps[species_id].append(snp)

        num_processed += 1
        if num_processed >= total_desired:
            # every desired SNP has been recorded; the rest of the pileup is irrelevant
            break
        if debug and num_processed >= 100:
            break


def _format_barcode_counts(counts):
    """ Render a {barcode_id: read_count} dict as 'id:count,id:count,...'. """
    return ",".join(["%s:%s" % (barcode_id, count) for barcode_id, count in counts.items()])


def _write_row(handle, values):
    """ Write `values` to handle as one tab-separated line, str()-ing each field. """
    handle.write('\t'.join([str(_) for _ in values]) + '\n')


def write_results(args, species, genes, desired_genes, snps, desired_snps, barcode_db):
    """ Write results to disk.

    Under <outdir>/<barcode_dir>/ this writes:
      output/<species_id>.barcodes.gz -- header 'allele_id<TAB>barcodes', then
          one row per gene (sorted by gene id), then for each SNP a
          '<contig>|<pos>|R' row and a '<contig>|<pos>|A' row; the barcodes
          column is comma-joined 'barcode_id:read_count' pairs
      output/all_barcodes.gz -- barcode_id, barcode_str, count per barcode
      summary.txt -- header row only (per-species rows are not written)

    snp.other_barcodes (deletions / ref skips) are not written. Every file
    opened here is closed even if writing fails.
    """
    out_dir = '/'.join([args['outdir'], args['barcode_dir']])

    # open outfiles for each species_id
    opened = []
    try:
        for sp in species.values():
            sp.out = utility.iopen('/'.join([out_dir, 'output/%s.barcodes.gz' % sp.id]), 'w')
            opened.append(sp.out)
            _write_row(sp.out, ['allele_id', 'barcodes'])

        # first do genes
        for gene_id in sorted(genes.keys()):
            gene = genes[gene_id]
            _write_row(species[gene.species_id].out,
                       [gene.id, _format_barcode_counts(gene.barcodes)])

        # then do snps: a reference (R) and an alternate (A) row per site
        for species_id in snps.keys():
            out = species[species_id].out
            for snp in snps[species_id]:
                contig, pos = snp.position[0], str(snp.position[1])
                _write_row(out, ["|".join([contig, pos, "R"]),
                                 _format_barcode_counts(snp.ref_barcodes)])
                _write_row(out, ["|".join([contig, pos, "A"]),
                                 _format_barcode_counts(snp.alt_barcodes)])
    finally:
        # close output files
        for handle in opened:
            handle.close()

    # barcodes
    barcode_file = utility.iopen('/'.join([out_dir, 'output/all_barcodes.gz']), 'w')
    try:
        _write_row(barcode_file, ['barcode_id', 'barcode_str', 'count'])
        for barcode_id in barcode_db.get_all_barcode_ids():
            barcode_str, barcode_count = barcode_db.get_barcode_str_count(barcode_id)
            _write_row(barcode_file, [barcode_id, barcode_str, barcode_count])
    finally:
        barcode_file.close()

    # summary stats
    header = ['species_id', 'pangenome_size', 'covered_genes', 'fraction_covered',
              'mean_coverage', 'marker_coverage', 'count_reads']
    with open('/'.join([out_dir, 'summary.txt']), 'w') as summary_file:
        _write_row(summary_file, header)
        # TODO: per-species rows are disabled. They need sp.marker_coverage,
        # which nothing in this module sets, and covered_genes /
        # fraction_covered / mean_coverage exist only after gene collation.

def run_pipeline(args):
    """ Run entire pipeline.

    Loads reference data (species list, gene cluster maps, genes and desired
    SNPs), printing elapsed time and peak memory. When args['cov'] is set, it
    then collates barcodes and writes results via barcode_coverage. The
    pangenome database and read alignments must already exist; this module
    neither builds nor aligns. Always sets args['cluster_pid'] to 95.
    """
    # identity threshold used to map centroid_99 genes onto their clusters
    args['cluster_pid'] = 95

    # Initialize reference data
    print("\nReading reference data")
    start = time()
    barcode_db = BarcodeDB()
    species = initialize_species(args)
    load_gene_cluster_map(args, species)

    genes, desired_genes = initialize_genes(args, species)
    snps, desired_snps = initialize_snps(args, species)
    print("  %s minutes" % round((time() - start)/60, 2))
    print("  %s Gb maximum memory" % utility.max_mem_usage())

    # Build pangenome database for selected species
    # Should already be built by the time the barcode module is called!

    # Use bowtie2 to align reads to pangenome database
    # Should already be aligned by the time the barcode module is called!

    # Compute pangenome coverage for each species
    if args['cov']:
        start = time()
        print("\nCollating barcodes")
        args['log'].write("\nCollating barcodes\n")
        barcode_coverage(args, species, genes, desired_genes, snps, desired_snps, barcode_db)
        print("  %s minutes" % round((time() - start)/60, 2))
        print("  %s Gb maximum memory" % utility.max_mem_usage())


