#!/usr/bin/env python

'''
Simple copy number profiling script for tumour / normal pairs.

Adam Ewing
adam.ewing@mater.uq.edu.au

'''

import pysam
import argparse
import multiprocessing as mp
import subprocess

import logging
# Root handler prints a timestamp followed by the message; this module's
# own logger reports at INFO and above.
FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

from math import log10


class Segment:
    '''
    One bin (chrom, start, end) together with a tumour value (t_cpm) and a
    normal value (n_cpm).

    ratio is t_cpm / n_cpm and logratio is its base-10 logarithm. If either
    value is not greater than zero, ratio stays at 1.0 and logratio at 0.0.

    Segments order by chrom, then by start.
    '''

    def __init__(self, chrom, start, end, t_cpm, n_cpm):
        self.chrom = chrom
        self.start = start
        self.end   = end
        self.t_cpm = t_cpm
        self.n_cpm = n_cpm

        # neutral values, kept whenever a ratio cannot be formed
        self.ratio, self.logratio = 1.0, 0.0

        both_positive = self.t_cpm > 0.0 and self.n_cpm > 0.0
        if both_positive:
            self.ratio = self.t_cpm / self.n_cpm
            self.logratio = log10(self.ratio)

    def __lt__(self, other):
        # chromosome name decides first; start breaks ties within a chromosome
        return (self.chrom, self.start) < (other.chrom, other.start)

    def __str__(self):
        # tab-separated columns: chrom, start, end, normal, tumour, ratio, log10 ratio
        columns = (
            self.chrom, self.start, self.end,
            self.n_cpm, self.t_cpm,
            self.ratio, self.logratio,
        )
        return '%s\t%d\t%d\t%f\t%f\t%f\t%f' % columns



def cpm(chrom, start, end, bamfn):
    '''
    Return the per-base read count of one interval, scaled per million
    mapped reads in the BAM.

    Counts the alignments yielded by fetch(chrom, start, end) that are not
    flagged secondary and have a mapping quality above 10, divides that by
    the BAM's mapped read total in millions, then by the interval length.

    chrom, start, end -- interval handed to fetch()
    bamfn             -- path to the BAM file (fetch() needs its index)

    Returns 0.0 if the BAM reports zero mapped reads or the interval has
    zero length.
    '''
    # This runs once per bin per sample, so the handle must be released
    # every time; the context manager closes it even if fetch() raises.
    with pysam.AlignmentFile(bamfn, 'rb') as bam:
        n = bam.mapped / float(1e6)

        # mapping_quality is the current name of the deprecated mapq alias
        count = sum(1 for read in bam.fetch(chrom, start, end)
                    if not read.is_secondary and read.mapping_quality > 10)

    try:
        return (count / n) / (end-start)
    except ZeroDivisionError:
        return 0.0


def calc_seg(chrom, binstart, binend, tbam, nbam):
    '''
    Build the Segment for one bin from a tumour / normal BAM pair.

    chrom, binstart, binend -- the bin to profile
    tbam, nbam              -- paths to the tumour and normal BAM files

    Returns a Segment holding the cpm() value of each BAM for the bin.
    '''
    # Use the paths passed in by the caller, not the module-level args:
    # that global only exists when the file is run as a script, and pool
    # workers started with the "spawn" method re-import the module without
    # executing the __main__ block.
    normal_cpm = cpm(chrom, binstart, binend, nbam)
    tumour_cpm = cpm(chrom, binstart, binend, tbam)

    return Segment(chrom, binstart, binend, tumour_cpm, normal_cpm)
 
def main(args):
    '''
    Profile every bin of every sequence listed in the .fai and write the
    resulting segments to a file.

    args -- parsed command line; the attributes read here are tumour,
            normal, fai, binsize, outfile and procs

    Each sequence is cut into consecutive bins of binsize (the last bin is
    clipped to the sequence length), calc_seg() is run for every bin in a
    worker pool, and the sorted segments are written one per line to
    args.outfile.

    Raises ValueError if binsize is not a positive integer.
    '''
    binsize = int(args.binsize)
    if binsize <= 0:
        # range() rejects a zero step, and a negative step would silently
        # produce no bins at all; fail before any workers are started
        raise ValueError('bin size must be a positive integer: %d' % binsize)

    pool = mp.Pool(processes=int(args.procs))

    try:
        reslist = []

        with open(args.fai) as fai:
            for line in fai:
                fields = line.split()
                if not fields:
                    # tolerate blank lines (e.g. a trailing newline)
                    continue

                chrom, chrlen = fields[:2]
                chrlen = int(chrlen)

                for binstart in range(0, chrlen, binsize):
                    binend = min(binstart + binsize, chrlen)

                    res = pool.apply_async(calc_seg, [chrom, binstart, binend, args.tumour, args.normal])
                    reslist.append(res)

        # get() blocks until the bin is finished and re-raises worker errors
        cn_segs = [res.get() for res in reslist]
    finally:
        # Every result has been collected on the success path, so stopping
        # the workers loses nothing; on failure it abandons the pending bins
        # instead of leaving worker processes behind.
        pool.terminate()
        pool.join()

    with open(args.outfile, 'w') as out:
        for s in sorted(cn_segs):
            out.write('%s\n' % str(s))


if __name__ == '__main__':
    # Command-line entry point. The parsed namespace is kept in the
    # module-level name args so that code reading it as a global still works.
    parser = argparse.ArgumentParser(description='copy number profile for t/n pair')
    parser.add_argument('-t', '--tumour', required=True, help='tumour sample, indexed BAM')
    parser.add_argument('-n', '--normal', required=True, help='normal sample, indexed BAM')
    parser.add_argument('-f', '--fai', required=True, help='fasta index (.fai)')
    # type=int makes argparse reject non-numeric values with a usage message
    parser.add_argument('-b', '--binsize', required=True, type=int, help='bin size')
    parser.add_argument('-o', '--outfile', required=True, help='output file, one tab-separated line per bin')
    parser.add_argument('-p', '--procs', type=int, default=1, help='number of worker processes (default: 1)')

    args = parser.parse_args()
    main(args)