#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import logging
import colorama
from pybedtools import BedTool, chromsizes_to_file, chromsizes
from transpo import util


def run_stringtie(bam, args, gtf_name, sample_outdir):
    """Assemble the transcriptome contained in *bam* with StringTie.

    The assembled GTF is written to *gtf_name* through shell redirection.
    The StringTie label (``-l``) is the basename of *gtf_name* with
    ``util.GTFEXT`` stripped. The junction (``-j``) and transcript (``-c``)
    coverage thresholds come from ``args.junction_cov`` and
    ``args.transcript_cov``.

    *sample_outdir* is accepted for interface compatibility; the command
    template does not use it.
    """
    label = os.path.basename(gtf_name.replace(util.GTFEXT, ''))
    template = ("stringtie -j {junction_cov} -c {trans_cov} "
                "-m 1000 -a 5 -t -p 4 -f 0.05 {bam} -l {lab} > {gtf_name}")
    command = template.format(bam=bam,
                              gtf_name=gtf_name,
                              trans_cov=args.transcript_cov,
                              junction_cov=args.junction_cov,
                              lab=label)

    util.logger.info("Finding all transcripts with Stringtie")
    params_msg = "  Params: -j {j} -c {c}".format(j=args.junction_cov,
                                                   c=args.transcript_cov)
    util.logger.info(params_msg)
    util.logger.debug('Calling stringtie...')
    # The '>' redirection requires the shell.
    util.call_w_check(command, shell=True)


def _count_gtf_transcripts(path):
    """Return how many ``transcript`` feature lines the GTF at *path* holds.

    Blank lines and ``#`` comment lines are skipped.

    Raises:
        ValueError: if any other line has fewer than three tab-separated
            columns, i.e. the file is not a well-formed GTF.
    """
    count = 0
    with open(path) as gtf_file:
        for lineno, line in enumerate(gtf_file, start=1):
            stripped = line.strip()
            if not stripped or stripped.startswith('#'):
                continue
            fields = stripped.split('\t')
            if len(fields) < 3:
                raise ValueError('malformed GTF line {} in {}'.format(lineno,
                                                                     path))
            if fields[2] == 'transcript':
                count += 1
    return count


def get_transpo(args, gtf_name):
    """Given a gtf from stringtie, find transcripts that overlap
    both bed1 and bed2.

    Steps:
      1. *gtf_name* is decompressed with ``gzip -dc`` into a per-sample
         scratch directory ``<args.tmpdir>/<sample>``.
      2. Transcript IDs are collected for transcripts whose exons overlap
         ``args.bed1``. When ``args.first_exon`` is set, only the exon at
         the transcript's 5' end is considered. When ``args.drop_inter`` is
         set, bed1 regions that also intersect ``args.bed2`` are removed
         first.
      3. Transcript IDs are collected for transcripts with any exon
         overlapping ``args.bed2`` on the same strand.
      4. GTF lines mentioning an ID present in both lists are written to
         ``<gtf_name minus util.GTFEXT>_chimeric.gtf``, which is then
         gzipped.
      5. If no chimeric transcript is found, the output file is deleted.

    Scratch files are removed even when a step fails.

    Args:
        args: parsed options. Reads ``bed1``, ``bed2``, ``drop_inter``,
            ``tmpdir`` and ``first_exon``.
        gtf_name: path of the sample's StringTie GTF (read with
            ``gzip -dc``).
    """
    import sys

    # Check that the external tools used below are available.
    util.cmd_exists('bedtools')
    util.cmd_exists('gawk')

    # Optionally drop bed1 regions that also intersect bed2.
    if args.drop_inter:
        bed1_filt = BedTool(args.bed1).intersect(args.bed2, v=True)
    else:
        bed1_filt = BedTool(args.bed1)

    output_name = gtf_name.replace(util.GTFEXT, '_chimeric.gtf')
    sample_name = os.path.basename(gtf_name.replace(util.GTFEXT, ''))

    # Per-sample scratch directory. os.path.join keeps a relative
    # args.tmpdir relative instead of forcing it under '/'.
    tmpdir = os.path.join(args.tmpdir, sample_name)
    os.makedirs(tmpdir, exist_ok=True)
    gtf = os.path.join(tmpdir, '{}.gtf'.format(sample_name))
    tmp1 = os.path.join(tmpdir, '{}_transpotmp1.txt'.format(sample_name))
    tmp2 = os.path.join(tmpdir, '{}_transpotmp2.txt'.format(sample_name))
    tmp3 = os.path.join(tmpdir, '{}_transpotmp3.txt'.format(sample_name))

    # Braces are doubled because these fragments go through str.format.
    any_exon_cmd = "gawk '{{if ($3 == \"exon\") print $0}}'"
    if args.first_exon:
        # Remember each transcript's 5' coordinate (start on '+', end on
        # '-') and keep the following non-transcript lines touching it.
        bed1_exon_cmd = ("gawk '{{if ($3 == \"transcript\") {{if ($7 == \"+\") "
                         "pos=$4; else pos=$5}} else if "
                         "($5 == pos || $4 == pos) print $0;}}'")
    else:
        bed1_exon_cmd = any_exon_cmd

    # Extract transcript_id values, sorted and deduplicated (comm -12
    # below requires sorted input).
    ids_to_file = ("gawk '{{ if (match($0, /transcript_id \"([^;]*)\";/, m)) "
                   "print m[1]}}' | sort | uniq > {tmpfile}")

    try:
        # shell=True is required for the '>' redirection.
        util.call_w_check('gzip -dc {} > {}'.format(gtf_name, gtf),
                          shell=True)

        # Transcripts with a selected exon overlapping (filtered) bed1.
        trans_filt1_f = (bed1_exon_cmd
                         + " {trans} | bedtools intersect -u -a - "
                         "-b {bed1} | " + ids_to_file).format(
                             trans=gtf, bed1=bed1_filt.fn, tmpfile=tmp1)
        util.logger.debug(trans_filt1_f)
        util.call_w_check(trans_filt1_f, shell=True)

        # Transcripts with ANY exon overlapping bed2, strand specific.
        trans_filt2_f = (any_exon_cmd
                         + " {trans} | bedtools intersect -u -s -a - "
                         "-b {bed2} | " + ids_to_file).format(
                             trans=gtf, bed2=args.bed2, tmpfile=tmp2)
        util.logger.debug(trans_filt2_f)
        util.call_w_check(trans_filt2_f, shell=True)

        # Transcript IDs common to both lists.
        comm_cmd = 'comm -12 {} {} > {}'.format(tmp1, tmp2, tmp3)
        # Copy every GTF line containing one of those IDs as a whole-word,
        # fixed-string match in the C locale.
        grep_cmd = 'LC_ALL=C grep -wF -f {} {} > {}'.format(tmp3, gtf,
                                                            output_name)
        util.logger.debug(comm_cmd)
        util.logger.debug(grep_cmd)
        util.call_w_check(comm_cmd, shell=True)
        util.call_w_check(grep_cmd, shell=True)
    finally:
        # Always clear the scratch files, even when a step above failed.
        for path in (tmp1, tmp2, tmp3, gtf):
            if os.path.exists(path):
                os.remove(path)
        try:
            os.rmdir(tmpdir)
        except OSError:
            # E.g. a pre-existing, non-empty directory: leave it in place
            # rather than masking an earlier error.
            util.logger.debug('Could not remove {}'.format(tmpdir))

    # Count the number of chimeric transcripts found.
    try:
        count = _count_gtf_transcripts(output_name)
    except ValueError as err:
        util.logger.error(f'\n ERROR  found in {output_name}: {err}')
        sys.exit(1)

    counts_str = colorama.Fore.GREEN + str(count) + colorama.Fore.RESET
    util.logger.info('\nFound {} transpochimeric transcripts for {}...'.format(
        counts_str, sample_name))
    if count == 0:
        # Nothing to keep. Stop here instead of gzipping a deleted file.
        os.remove(output_name)
        return
    util.call_w_check(f"gzip -f {output_name}")


def get_elem_tss(args, gtf_name):
    """Overlap the found transcripts with the elem of bed1 to do analysis
    on the bed1 elem.

    Processing:
      * Each ``transcript`` record of
        ``<gtf_name minus util.GTFEXT>_chimeric.gtf`` is reduced to a 1-bp
        BED interval at its TSS (start on '+', end on '-').
      * If only the gzipped ``.gz`` copy made by ``get_transpo`` exists, it
        is decompressed on the fly.
      * The interval is widened by ``args.tss_dist`` bp on both sides with
        ``bedtools slop``.
      * ``args.bed1`` elements overlapping any widened TSS are written to
        ``<gtf_name minus util.GTFEXT>_elem_on_TSS.bed``.

    Args:
        args: parsed options. Reads ``bed1``, ``tss_dist`` and ``org``.
            ``org`` is passed to pybedtools ``chromsizes``.
        gtf_name: path of the sample's StringTie GTF.
    """
    import tempfile

    util.logger.info(("\nGetting the elem of {bed1} overlapping TSS "
                      "regions".format(bed1=args.bed1)))
    util.logger.info("\tDist around TSS: " + colorama.Fore.GREEN
                     + "{}bp".format(args.tss_dist) + colorama.Fore.RESET)
    chimeric_name = gtf_name.replace(util.GTFEXT, '_chimeric.gtf')
    outname = gtf_name.replace(util.GTFEXT, '_elem_on_TSS.bed')

    # Chromosome sizes for bedtools slop, cached per organism. A single
    # shared cache file would silently reuse another assembly's sizes.
    genome_file = os.path.join(tempfile.gettempdir(),
                               'transpo_{}.genome'.format(args.org))
    if not os.path.exists(genome_file):
        chromsizes_to_file(chromsizes(args.org), genome_file)

    # get_transpo gzips the chimeric GTF, so fall back to the .gz copy.
    gz_name = chimeric_name + '.gz'
    if not os.path.exists(chimeric_name) and os.path.exists(gz_name):
        reader = 'gzip -dc {}'.format(gz_name)
    else:
        reader = 'cat {}'.format(chimeric_name)

    # GTF is 1-based inclusive and BED is 0-based half-open, so a TSS at
    # GTF position p is the BED interval [p-1, p) on both strands.
    tss_overlap_cmd = (
        "{reader} | gawk -v OFS='\t' '{{if ($3 == \"transcript\") "
        "{{ if ($7 == \"+\") {{print $1,$4-1,$4,\"trans\",0,$7}}"
        " else {{print $1,$5-1,$5,\"trans\",0,$7}}}}}}' "
        "| bedtools slop -b {dist} -i - -g {genome_file} "
        "| bedtools intersect -u -b - -a {bed1} > {outname}"
    ).format(reader=reader, dist=args.tss_dist, genome_file=genome_file,
             bed1=args.bed1, outname=outname)

    util.logger.debug(tss_overlap_cmd)
    util.call_w_check(tss_overlap_cmd, shell=True)
    util.logger.info('\tNumber of elements found: '
                     + colorama.Fore.GREEN + '{}\n'.format(len(BedTool(outname)))
                     + colorama.Fore.RESET)


def get_fasta_seq(args, gtf_name):
    """Get the fasta sequence of the elements overlapping the chimeric
    transcripts TSS.

    Reads ``<gtf_name minus util.GTFEXT>_elem_on_TSS.bed``, the file name
    ``get_elem_tss`` writes. Strand-aware (``-s``) sequences are extracted
    from ``args.genome_fasta`` and written next to it with a ``.fa``
    extension.
    """
    util.logger.info("Getting fasta sequences...")
    bedTss = gtf_name.replace(util.GTFEXT, '_elem_on_TSS.bed')
    # Swap only the final extension. str.replace('.bed', '.fa') would also
    # rewrite any '.bed' appearing in a directory name.
    output = os.path.splitext(bedTss)[0] + '.fa'
    bed2fasta_cmd = ("bedtools getfasta -s -fi {genome_fasta} "
                     "-bed {bed} -fo {output}")
    util.call_w_check(bed2fasta_cmd.format(genome_fasta=args.genome_fasta,
                                           bed=bedTss, output=output), shell=True)


def main(bam, args, logger, sample_outdir, sample):
    """Main function of discovery. Find all the chimeric transcripts in the
    data and create the output gtfs.

    Args:
        bam: path of the sample BAM. The literal ``'nobam'`` skips the
            StringTie assembly and reuses an existing GTF.
        args: parsed options forwarded to the pipeline steps.
        logger: logger that receives a per-sample ``<sample>.log`` file
            handler (INFO level) for the duration of this call.
        sample_outdir: directory holding this sample's outputs.
        sample: sample name, used as the base name of the output files.
    """
    gtf_name = os.path.join(sample_outdir, sample + util.GTFEXT)
    logfile = gtf_name.replace(util.GTFEXT, '.log')

    if os.path.isfile(logfile):
        logger.debug('Removing old logs...')
        os.remove(logfile)

    handler = logging.FileHandler(logfile)
    handler.setLevel('INFO')
    logger.addHandler(handler)
    try:
        if bam != 'nobam':
            run_stringtie(bam, args, gtf_name, sample_outdir)

        get_transpo(args, gtf_name)

        # NOTE(review): get_fasta_seq reads *_elem_on_TSS.bed. Only
        # get_elem_tss writes that file, and it is not called here. Confirm
        # the file is produced elsewhere.
        if args.genome_fasta:
            get_fasta_seq(args, gtf_name)
    finally:
        # Detach and close even on failure. Otherwise the next sample's
        # messages would also land in this log and the file handle leaks.
        logger.removeHandler(handler)
        handler.close()
