#!/usr/bin/env python3

# Copyright (C) 2017 Tobias Jakobi
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import pprint
import re
import subprocess
import tempfile
import argparse
import pybedtools


def return_bed_line(location, gene_name, comment, mode, threshold):
    """Render a "chr_start_stop_strand" location as one tab-separated BED-like line.

    Output columns: chr, start, stop, gene_name, "0", strand, "none", comment.

    :param location: "_"-separated key; fields 0-3 are chr, start, stop, strand
    :param gene_name: text placed in column 4
    :param comment: text placed in the last column
    :param mode: 1 -> cap stop at start + threshold;
                 2 -> raise start to stop - threshold;
                 anything else -> coordinates unchanged
    :param threshold: maximum interval length kept by modes 1 and 2
    :return: the assembled line (no trailing newline)
    """
    parts = location.split('_')
    begin = int(parts[1])
    end = int(parts[2])

    if mode == 1:
        # keep only the first `threshold` bases of the interval
        end = min(end, begin + threshold)
    elif mode == 2:
        # keep only the last `threshold` bases of the interval
        begin = max(begin, end - threshold)

    return "\t".join([parts[0], str(begin), str(end), gene_name,
                      "0", parts[3], "none", comment])


def extract_start(string):
    """Return the second "_"-separated field of *string* (the start) as an int."""
    return int(string.split('_')[1])


def extract_chr(string):
    """Return the text before the first "_" of *string* (the chromosome)."""
    chrom, _sep, _rest = string.partition('_')
    return chrom


def extract_strand(string):
    """Return the fourth "_"-separated field of *string* (the strand)."""
    return string.split('_')[3]


def return_wobble(string, wobble):
    """Shift the start of a "chr_start_stop[_...]" key by *wobble*.

    Returns "chr_<start + wobble>_stop"; any fields after the third
    (e.g. the strand) are dropped.
    """
    parts = string.split('_')
    return "_".join([parts[0], str(int(parts[1]) + wobble), parts[2]])


def extract_stop(string):
    """Return the third "_"-separated field of *string* (the stop) as an int."""
    return int(string.split('_')[2])


def _nearest_repeat_columns(repeat_dict, chrom, position, strand, step,
                            min_wobble, max_wobble):
    """Return "<TAB>distance<TAB>value" for the closest repeat key, or "<TAB>NA<TAB>NA".

    Probes keys "<chrom>_<position + step * d>_<strand>" for d in
    range(min_wobble, max_wobble); step is -1 to search below *position*
    and +1 to search above it.
    """
    for wobble in range(min_wobble, max_wobble):
        key = chrom + "_" + str(position + step * wobble) + "_" + strand
        if key in repeat_dict:
            return "\t" + str(wobble) + "\t" + str(repeat_dict[key])
    return "\tNA\tNA"


def print_results(dcc_dict, repeat_dict, min_wobble=5, max_wobble=1000000):
    """Print every circRNA together with the nearest repeat on each side.

    For each key "chr_start_stop_strand" of *dcc_dict*, *repeat_dict* is
    probed for "chr_<pos>_strand" keys at distances min_wobble .. max_wobble-1
    below start and above stop. The first hit on each side contributes
    "distance<TAB>value"; no hit contributes "NA<TAB>NA". On the "+" strand the
    stop-side columns are printed first, otherwise the start-side columns.

    :param dcc_dict: mapping (only its keys are used) of circRNA locations
    :param repeat_dict: mapping "chr_pos_strand" -> value to report
    :param min_wobble: smallest distance probed (default 5, as before)
    :param max_wobble: exclusive upper bound of distances probed
    """
    for entry in dcc_dict:
        fields = entry.split('_')
        chrom = fields[0]
        start = int(fields[1])
        stop = int(fields[2])
        strand = fields[3]

        new_line = entry.replace("_", "\t").rstrip()

        addition_down = _nearest_repeat_columns(repeat_dict, chrom, start, strand,
                                                -1, min_wobble, max_wobble)
        addition_up = _nearest_repeat_columns(repeat_dict, chrom, stop, strand,
                                              1, min_wobble, max_wobble)

        # Compare the strand field itself; the previous code tested entry[3],
        # i.e. the fourth character of the key, which is almost never "+".
        if strand == "+":
            new_line += addition_up + addition_down
        else:
            new_line += addition_down + addition_up

        print(new_line)


def parse_dcc_file(input_file):
    """Read circRNA locations from a DCC CircCoordinates (tab-separated) file.

    :param input_file: path to the file
    :return: OrderedDict mapping "col1_col2_col3_col6" (chr_start_end_strand)
             to column 4 of the first line with that location (presumably the
             gene annotation in DCC output -- verify). Later duplicates are
             ignored; file order is preserved.
    """
    from collections import OrderedDict
    loc_list = OrderedDict()

    with open(input_file) as fp:
        for line in fp:
            line = line.rstrip()

            # skip blank lines (e.g. a trailing newline at EOF) as well as the
            # "Chr..." column header and "track ..." lines
            if not line or line.startswith(('Chr', 'track')):
                continue

            current_line = line.split('\t')
            loc = "_".join([current_line[0], current_line[1],
                            current_line[2], current_line[5]])

            # first occurrence wins
            loc_list.setdefault(loc, current_line[3])

    return loc_list


def generate_virtual_region_downstream(circid, window=2000):
    """Build a one-interval BedTool covering the *window* bp before the circRNA start.

    :param circid: "chr_start_end_strand" location key
    :param window: width of the region in bp (default 2000, as before)
    :return: pybedtools.BedTool with the single record
             chr, max(0, start - window), start, name=start, score 0, strand
    """
    data = circid.split('_')
    # BED starts are 0-based and must not be negative; circRNAs closer than
    # `window` bp to the chromosome start would otherwise yield an invalid record.
    region_start = max(0, int(data[1]) - window)
    return pybedtools.BedTool(
        "\t".join([data[0], str(region_start), data[1], data[1], "0", data[3]]),
        from_string=True)


def generate_virtual_region_upstream(circid, window=2000):
    """Build a one-interval BedTool covering the *window* bp after the circRNA end.

    :param circid: "chr_start_end_strand" location key
    :param window: width of the region in bp (default 2000, as before)
    :return: pybedtools.BedTool with the single record
             chr, end, end + window, name=end, score 0, strand
    """
    data = circid.split('_')
    region_end = int(data[2]) + window
    return pybedtools.BedTool(
        "\t".join([data[0], data[2], str(region_end), data[2], "0", data[3]]),
        from_string=True)


def run_bedtools(alu_fh, virtual_region):
    """Return the features of *alu_fh* overlapping *virtual_region*.

    Thin wrapper that delegates to ``alu_fh.intersect(virtual_region)``.
    """
    return alu_fh.intersect(virtual_region)


def extract_sequences(bed_input, fasta):
    """Fetch the sequences of *bed_input*'s intervals from *fasta*.

    Calls ``bed_input.sequence(fi=fasta, tab=True)`` and reads the resulting
    tab-separated "name<TAB>sequence" file.

    :return: OrderedDict mapping each name to its sequence, in file order
             (a repeated name keeps the last sequence but its first position)
    """
    from collections import OrderedDict

    fasta_tab = bed_input.sequence(fi=fasta, tab=True)

    sequences = OrderedDict()
    with open(fasta_tab.seqfn) as handle:
        for raw in handle:
            fields = raw.rstrip().split('\t')
            sequences[fields[0]] = fields[1]
    return sequences


def run_needle(up_seq, down_seq, file):
    """Globally align two repeat sequences with EMBOSS ``needle``.

    The reverse complement of *down_seq* is passed as ``-asequence`` and
    *up_seq* as ``-bsequence``. The call uses gap open 10.0, gap extend 0.5
    and the brief report.

    :param up_seq: upstream repeat sequence
    :param down_seq: downstream repeat sequence (reverse-complemented first)
    :param file: label used as prefix of the scratch directory ("/" -> "_")
    :return: the percentage from needle's "Similarity:" line as a float, or
             None if the report contains no such line
    :raises FileNotFoundError: if needle is not installed or wrote no report
    """
    import os
    from Bio.Seq import Seq

    # Bio.Alphabet / IUPAC was removed in Biopython 1.78; a plain Seq gives
    # the same reverse complement for DNA input on old and new versions.
    down_rc = str(Seq(down_seq).reverse_complement())

    # A private, auto-deleted scratch directory replaces the fixed
    # /tmp/<file>_*.txt paths, which leaked files, collided between
    # concurrent runs and were predictable (symlink clobbering).
    with tempfile.TemporaryDirectory(prefix=file.replace("/", "_") + "_") as tmp_dir:
        upfile = os.path.join(tmp_dir, "up.txt")
        downfile = os.path.join(tmp_dir, "down.txt")
        resultfile = os.path.join(tmp_dir, "needle.txt")

        with open(upfile, "w") as up_fh:
            up_fh.write(up_seq)
        with open(downfile, "w") as down_fh:
            down_fh.write(down_rc)

        subprocess.run(['needle', '-asequence', downfile, '-bsequence', upfile,
                        '-gapopen', '10.0', '-gapextend', '0.5',
                        '-outfile', resultfile, '-brief', 'Y'],
                       stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

        with open(resultfile) as result_fh:
            for line in result_fh:
                # e.g. "# Similarity:    45/100 (45.0%)"
                match = re.search(r"Similarity:\W+\d+/\d+ \(\W*(.*)%\)", line)
                if match:
                    return float(match.group(1))

    return None


def get_repeats_for_circle(circID):
    """Placeholder, not implemented yet.

    Intended to return two lists (upstream and downstream repeats) for
    *circID*; currently always returns None.
    """
    return None


def select_best_repeat_pair_for_circle(circID):
    """Placeholder, not implemented yet.

    Intended to take the up/downstream repeat lists of *circID* and output
    the best pair (score and distance); currently always returns None.
    """
    return None


def parse_bed_file(input_file):
    """Index the boundaries of every interval in a BED file.

    For each record, the keys "<chrom>_<start>_<strand>" and
    "<chrom>_<end>_<strand>" are both mapped to the interval length
    (end - start). Every "chr" substring is removed from the chromosome name.
    A later record sharing a key overwrites the earlier value. Blank lines
    and "track"/"browser"/"#" header lines are skipped.

    :param input_file: path to a BED file with at least 6 columns
    :return: OrderedDict of boundary key -> interval length
    """
    from collections import OrderedDict
    start_list = OrderedDict()

    with open(input_file) as fp:
        for line in fp:
            # header/comment lines and empty lines carry no interval
            if not line.strip() or line.startswith(('track', 'browser', '#')):
                continue

            current_line = line.split('\t')

            chrom = current_line[0].replace("chr", "")
            strand = current_line[5].rstrip()
            begin = int(current_line[1])
            end = int(current_line[2])
            length = end - begin

            # stop point of previous exon (interval start coordinate)
            start_list[chrom + "_" + str(begin) + "_" + strand] = length

            # start point of next exon (interval end coordinate)
            start_list[chrom + "_" + str(end) + "_" + strand] = length

    return start_list


# main script starts here


def _build_arg_parser():
    """Create the command-line interface of this script."""
    parser = argparse.ArgumentParser(
        description='For each circRNA, report the most similar pair of repeats '
                    'flanking it (EMBOSS needle similarity)')

    group = parser.add_argument_group("Input")

    group.add_argument("-d",
                       "--dcc-file",
                       dest="dcc_file",
                       help="CircCoordinates file from DCC",
                       required=True
                       )

    group.add_argument("-b",
                       "--bed-file",
                       dest="bed_file",
                       help="BED file with all repeats",
                       required=True
                       )

    # NOTE(review): parsed but currently not used by main() -- confirm intent.
    group.add_argument("-t",
                       "--threshold",
                       dest="base_threshold",
                       help="max distance of circRNA to repeat start",
                       type=int,
                       default=1000000
                       )

    group.add_argument("-g",
                       "--genome-file",
                       dest="genome_file",
                       help="Genome file fasta",
                       required=True
                       )

    group.add_argument("-o",
                       "--out-file",
                       dest="out_file",
                       help="result file",
                       required=True
                       )

    return parser


def _best_repeat_pair(data_up, data_down, circstart, circend, label):
    """Pick the flanking repeat pair with the highest needle similarity.

    Every (upstream, downstream) sequence pair is aligned with run_needle.
    Only similarities above 10 qualify, and the first pair reaching the
    maximum wins. Sequence names are expected as "chr:start-end".

    :return: (distdown, distup, similarity) where distup is the upstream repeat
             start minus *circend* and distdown is *circstart* minus the
             downstream repeat end; None if no pair qualifies
    """
    best = None
    highest = -10
    for up_name, up_seq in data_up.items():
        for down_name, down_seq in data_down.items():
            score = run_needle(up_seq, down_seq, label)
            # run_needle returns None when needle reports no similarity;
            # comparing None with a number raises TypeError in Python 3.
            if score is None or not (score > 10 and score > highest):
                continue
            distup = int(up_name.split(':')[1].split('-')[0]) - circend
            distdown = circstart - int(down_name.split(':')[1].split('-')[1])
            highest = score
            best = (distdown, distup, score)
    return best


def main():
    """Parse arguments, pair flanking repeats per circRNA and write the results."""
    args = _build_arg_parser().parse_args()

    dcc_input = parse_dcc_file(args.dcc_file)
    repeat_fh = pybedtools.BedTool(args.bed_file)

    with open(args.out_file, "w") as result_writer:
        for circ in dcc_input:
            tmp_fh_up = generate_virtual_region_upstream(circ)
            tmp_fh_down = generate_virtual_region_downstream(circ)

            fields = circ.split('_')
            circchr = fields[0]
            circstart = int(fields[1])
            circend = int(fields[2])

            data_up = {}
            intersect_up = run_bedtools(repeat_fh, tmp_fh_up)
            if intersect_up:
                data_up = extract_sequences(intersect_up, args.genome_file)

            data_down = {}
            intersect_down = run_bedtools(repeat_fh, tmp_fh_down)
            if intersect_down:
                data_down = extract_sequences(intersect_down, args.genome_file)

            best = _best_repeat_pair(data_up, data_down, circstart, circend,
                                     args.out_file)
            if best is None:
                continue

            distdown, distup, score = best
            result_writer.write("\t".join([circchr,
                                           str(circstart),
                                           str(circend),
                                           str(distdown),
                                           str(distup),
                                           str(score)]))
            result_writer.write("\n")
            # flushed per circRNA, presumably so results are visible during long runs
            result_writer.flush()


if __name__ == "__main__":
    main()