#!/usr/bin/env python

import gzip
import sys

import matplotlib.pyplot as plt
import numpy as np

import anduril
from anduril.args import mappings, unmapped, overmapped, errors_profile, hits_distribution, errors_distribution, logfile

def return_read_num(fi):
    """Return the total copy number of the reads stored in file *fi*.

    The file is read as consecutive two-line records (presumably fasta:
    a header line followed by a sequence line). The second
    whitespace-separated token of each header is taken as that record's
    copy number. Reading stops at the first record whose header or
    sequence line is missing, so a trailing header without a sequence
    line is not counted.
    """
    total = 0
    with open(fi) as handle:
        header = handle.readline()
        sequence = handle.readline()
        # readline() returns '' only at end of file, so both lines must exist.
        while header and sequence:
            total += int(header.split()[1])
            header, sequence = handle.readline(), handle.readline()
    return total


# Largest number of mismatches per alignment that the error plots distinguish
# (presumably bowtie's -v 3 limit); alignments with more abort the run.
MAX_MISMATCHES = 3

# Inclusive upper bounds of the hit-count bins in the hits plot, shown between
# the unmapped (0 hits) and overmapped (>100 hits) categories.
HIT_BIN_UPPER_BOUNDS = (1, 5, 10, 50, 100)


def _open_mappings(path):
    """Open the bowtie output *path* for line iteration, gunzipping if needed.

    Compression is detected from the gzip magic number rather than by trying
    to decompress, so a missing or unreadable file raises its own IOError.
    """
    with open(path, 'rb') as probe:
        magic = probe.read(2)
    if magic == b'\x1f\x8b':
        return gzip.open(path)
    return open(path)


def _parse_mappings(lines):
    """Accumulate weighted statistics from bowtie output lines.

    Each alignment is weighted by copy_number / hits.

    Returns a tuple ``(readlen_weights, hits_dict, errors_at, errors_end)``:
        readlen_weights -- {read length: summed weight of its alignments}
        hits_dict       -- {first column ("name copy_number"): hits}, taken
                           from the first alignment seen for that read
        errors_at       -- list of MAX_MISMATCHES dicts; errors_at[k] maps a
                           0-based read position to the summed weight of
                           alignments whose (k+1)-th mismatch is there
        errors_end      -- list of MAX_MISMATCHES dicts; errors_end[k] maps a
                           read length to the summed weight of alignments of
                           that length that have a (k+1)-th mismatch

    Exits with an error message if an alignment has more than
    MAX_MISMATCHES mismatches.
    """
    readlen_weights = {}
    hits_dict = {}
    errors_at = [{} for _ in range(MAX_MISMATCHES)]
    errors_end = [{} for _ in range(MAX_MISMATCHES)]

    for line in lines:
        fields = line.strip().split('\t')

        copy_num = float(fields[0].split(' ')[1])
        hits = float(fields[6]) + 1
        # A multimapping read appears once per reported hit; keep it once.
        hits_dict.setdefault(fields[0], hits)

        weight = copy_num / hits
        readlen = len(fields[4])
        readlen_weights[readlen] = readlen_weights.get(readlen, 0.0) + weight

        # Mismatch column, e.g. "2:T>A,11:T>C,23:C>A" (0-based positions).
        # strip() above drops the empty trailing column of perfect matches.
        if len(fields) < 8 or not fields[7].strip():
            continue
        mismatches = fields[7].strip().split(',')
        if len(mismatches) > MAX_MISMATCHES:
            sys.exit('Error: More than %d mismatches in %s'
                     % (MAX_MISMATCHES, line.rstrip('\n')))
        for k, mismatch in enumerate(mismatches):
            pos = int(mismatch.split(':')[0])
            errors_at[k][pos] = errors_at[k].get(pos, 0.0) + weight
            errors_end[k][readlen] = errors_end[k].get(readlen, 0.0) + weight

    return readlen_weights, hits_dict, errors_at, errors_end


def _to_array(weights, length):
    """Return a float array of *length* holding ``weights`` ({index: value}).

    Indices outside ``[0, length)`` are ignored.
    """
    arr = np.zeros(length)
    for idx, value in weights.items():
        if 0 <= idx < length:
            arr[idx] += value
    return arr


def _bin_hits(hits_dict):
    """Sum the copy numbers of unique reads into the HIT_BIN_UPPER_BOUNDS bins.

    Reads with more hits than the last bound are not counted here.
    """
    binned = [0] * len(HIT_BIN_UPPER_BOUNDS)
    for key, hits in hits_dict.items():
        copy_num = int(key.split(' ')[1])
        for idx, upper in enumerate(HIT_BIN_UPPER_BOUNDS):
            if hits <= upper:
                binned[idx] += copy_num
                break
    return binned


def _plot_error_distribution(bands, num_overmapped, num_unmapped, tot_reads, path):
    """Save stacked per-position bars of *bands*, then ambiguous and unmapped reads.

    ``bands[k][i]`` is the weight of mapped reads covering position i with
    exactly k mismatches up to i. All bars are percentages of *tot_reads*.
    """
    length = len(bands[0])
    x = np.arange(1, length + 1)
    denominator = float(max(tot_reads, 1))  # avoid dividing by zero on empty input

    plt.figure()
    ax = plt.axes()
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_position(('outward', 10))  # move axis down a little
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    # Shrink the axes and move them up, leaving room for the legend above.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 * 1.5, box.width, box.height * 0.85])

    # Colormaps map 0 (light) .. 1 (dark) to a color.
    greens = plt.get_cmap('Greens')
    jet = plt.get_cmap('jet')
    colors = [greens(0.6), jet(0.7), jet(0.35), jet(0.1), jet(0.9), jet(1.0)]
    labels = ['%derr' % k for k in range(len(bands))] + ['ambiguous', 'unmapped']
    layers = list(bands) + [np.ones(length) * num_overmapped,
                            np.ones(length) * num_unmapped]

    bottom = np.zeros(length)
    for layer, color, label in zip(layers, colors, labels):
        heights = layer / denominator * 100
        plt.bar(x, heights, bottom=bottom, align='center', color=color,
                label=label, linewidth=0, width=0.95)
        bottom = bottom + heights

    plt.ylim([0, 100])
    plt.legend(frameon=False, ncol=3, loc='center', bbox_to_anchor=(0.5, 1.1))
    plt.xlabel('Read Position')
    plt.ylabel('Percentage of all Reads')
    plt.savefig(path)
    plt.close()


def _plot_hits_distribution(counts, path):
    """Save one horizontal stacked bar of read *counts* per hit category.

    *counts* is [unmapped] + one entry per HIT_BIN_UPPER_BOUNDS bin + [overmapped].
    """
    plt.figure(figsize=(12, 2))
    ax = plt.axes()
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_position(('outward', 10))  # move axis down a little
    ax.xaxis.set_ticks_position('bottom')
    ax.xaxis.get_major_formatter().set_powerlimits((0, 1))  # use scientific notation

    # A single bar needs no y ticks or labels.
    ax.yaxis.set_ticks([])
    ax.yaxis.set_ticklabels([])

    # Flatten the axes and move them up, leaving room for title and legend.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 * 3.5, box.width, box.height * 0.3])

    # Unmapped is red, overmapped grey; in between, fewer hits -> darker green.
    greens = plt.get_cmap('Greens')
    shades = np.linspace(0.2, 1, len(counts) - 2)[::-1]
    colors = ['r'] + [greens(s) for s in shades] + ['0.75']
    legends = ['0'] + [str(u) for u in HIT_BIN_UPPER_BOUNDS] + ['>100']

    left = 0
    for count, color, label in zip(counts, colors, legends):
        plt.barh(1, count, left=left, color=color, linewidth=0,
                 align='center', label=label)
        left += count

    plt.xlabel('Number of Reads', fontsize=14)
    plt.title('Number of Hits', position=(0.5, 1.8), fontsize=14)
    plt.legend(frameon=False, ncol=len(counts), loc='center', bbox_to_anchor=(0.5, 1.5))
    # plt.tight_layout() is not used: it does not work together with plt.title here.
    plt.savefig(path)
    plt.close()


def _plot_error_profile(tot_errors, tot_reads_mapped, path):
    """Save the percentage of all mapped reads having a mismatch at each position."""
    x = np.arange(1, len(tot_errors) + 1)
    greens = plt.get_cmap('Greens')
    plt.figure()
    plt.plot(x, tot_errors / float(max(tot_reads_mapped, 1)) * 100,
             linewidth=3, color=greens(0.45))
    plt.xlabel('Read Position')
    plt.ylabel('Sequencing Errors (%)')
    plt.savefig(path)
    plt.close()


def _write_log(path, tot_reads_mapped, num_unmapped, num_overmapped):
    """Write the mapped / unmapped / overmapped read totals to *path*."""
    with open(path, 'w') as out:
        out.write('Total reads mapped:\t%i\n' % tot_reads_mapped)
        out.write('Total unmapped reads:\t%i\n' % num_unmapped)
        out.write('Total overmapped reads (>100 hits):\t%i\n' % num_overmapped)


def main():
    """Summarise bowtie alignments of collapsed reads as plots and a log file.

    The transform component produces a single fasta file whose headers carry
    the copy number of each unique read; that file is the bowtie input. The
    bowtie output then looks like this (tab-separated):

        FC30P42HM_20081212:6:3:487:1935 1       +       chr2    151774464       AGGAAAAAAGGCTTTGTTTTTGTGCAAA    IIIIIIIIIIIIIIIIIIIIIIIIIIII    0
        FC30P42HM_20081212:6:50:519:804 1       +       chr3    82943025        GAGACAGACAGAGAGAACGAAAGAGACA    IIIIIIIIIIIIIIIIIIIIIIIIIIII    2       17:A>C
        FC30P42HM_20081212:6:10:352:1947 2      -       chr4    94645692        ATGTAATTTCTCTTCTCTATGTCCTTC     IIIIIIIIIIIIIIIIIIIIIIIIIII     0

    Columns:
        1. read name (== fasta header, "name copy_number")
        2. strand
        3. chromosome
        4. start (0-based)
        5. read sequence (not the alignment sequence!); the reverse
           complement if aligned to the - strand
        6. quality string
        7. n, where n + 1 is the number of positions the read mapped to (hits)
        8. mismatches, if any, e.g. "2:T>A,11:T>C" (0-based positions)

    Each alignment is weighted by copy_number / hits, i.e.
    (number in column 1) / ((number in column 7) + 1). A read reported at
    all of its hits therefore contributes its copy number once in total.

    Outputs (paths from anduril.args):
        errors_distribution -- for each read position i, stacked bars giving
            the percentage of all reads (mapped, unmapped and overmapped) that
            are mapped reads covering i with exactly 0, 1, 2 or 3 mismatches
            at positions <= i. These are followed by the ambiguous
            (overmapped) and unmapped reads.
        hits_distribution -- one stacked horizontal bar of read counts (copy
            numbers) by number of hits: 0 (unmapped), 1, 2-5, 6-10, 11-50,
            51-100 and >100 (overmapped).
        errors_profile -- the percentage of all mapped reads that have a
            mismatch at read position i.
        logfile -- total mapped, unmapped and overmapped read counts.

    Returns 0. Exits with an error if an alignment has more than
    MAX_MISMATCHES mismatches.
    """
    fin = _open_mappings(mappings)
    try:
        readlen_weights, hits_dict, errors_at, errors_end = _parse_mappings(fin)
    finally:
        fin.close()

    # Round instead of truncating: the weights are float fractions.
    tot_reads_mapped = int(round(sum(readlen_weights.values())))
    max_readlen = max(readlen_weights) if readlen_weights else 0

    num_unmapped = return_read_num(unmapped)
    num_overmapped = return_read_num(overmapped)
    tot_reads = tot_reads_mapped + num_unmapped + num_overmapped

    # Raw weight of k-th mismatches per position (also the error profile).
    errors_per_pos = [_to_array(d, max_readlen) for d in errors_at]

    # Weight of mapped reads that are long enough to cover each position.
    covering = np.zeros(max_readlen)
    for readlen, weight in readlen_weights.items():
        covering[:readlen] += weight

    # at_least[k][i]: weight of reads covering position i with at least k+1
    # mismatches at positions <= i. A read of length L stops covering
    # position L, so its mismatches are subtracted again from there on.
    at_least = [np.cumsum(pos_arr) - np.cumsum(_to_array(ends, max_readlen))
                for pos_arr, ends in zip(errors_per_pos, errors_end)]

    # Convert the "at least k" counts into disjoint "exactly k" bands.
    bands = [covering - at_least[0]]
    bands += [at_least[k] - at_least[k + 1] for k in range(MAX_MISMATCHES - 1)]
    bands.append(at_least[-1])

    _plot_error_distribution(bands, num_overmapped, num_unmapped, tot_reads,
                             errors_distribution)

    counts = [num_unmapped] + _bin_hits(hits_dict) + [num_overmapped]
    _plot_hits_distribution(counts, hits_distribution)

    _plot_error_profile(sum(errors_per_pos), tot_reads_mapped, errors_profile)

    _write_log(logfile, tot_reads_mapped, num_unmapped, num_overmapped)

    return 0


if __name__ == '__main__':
    # Run the component only when executed as a script, and use main()'s
    # return value as the process exit status.
    sys.exit(main())
