#! /usr/bin/env python3.5

###########  OVERLAP ALIGNMENT  #############

# Cost of a mismatch or of a gap that is not free. It dwarfs the +1 earned per
# matching character, so any alignment containing one loses to a perfect overlap.
indel_penalty = -(10 ** 13 - 1)

def score_matrix(a, b) :
    """Return the score for aligning character *a* against character *b*.

    A match scores 1; a mismatch scores the module-level ``indel_penalty``,
    i.e. it is as expensive as a paid gap.
    """
    return 1 if a == b else indel_penalty

def backup(back, seq1, seq2) :
    """Rebuild the two aligned strings by walking the traceback table.

    Starts at cell (len(seq1), len(seq2)) and follows ``back[row][col]``
    (the predecessor cell) until both indices reach 0. A step that leaves the
    row unchanged puts a gap '-' in the seq1 string; likewise for the column
    and seq2. Returns ``(aligned_seq1, aligned_seq2)``.
    """
    top = []
    bottom = []
    row, col = len(seq1), len(seq2)
    while row > 0 or col > 0 :
        prev_row, prev_col = back[row][col]
        top.append('-' if prev_row == row else seq1[prev_row])
        bottom.append('-' if prev_col == col else seq2[prev_col])
        row, col = prev_row, prev_col
    # Characters were collected from the end backwards.
    top.reverse()
    bottom.reverse()
    return ''.join(top) , ''.join(bottom)

# Down is free when j == 0 (before beginning of seq2)
# Right is free when i == len(seq1) (after end of seq1)
def update(sm, i, j, seq1, seq2) :
    """Compute one cell of the overlap-alignment DP table.

    Returns ``(best_score, (prev_i, prev_j))`` for cell (i, j), where the
    second element is the cell the best move came from. The cells above and
    to the left of (i, j) in ``sm`` must already be filled in.

    Moves considered, in tie-break priority order:
      * diagonal: +1 on a match, ``indel_penalty`` on a mismatch;
      * down (gap in seq2): ``indel_penalty``, or free when j == 0;
      * right (gap in seq1): ``indel_penalty``, or free when i == len(seq1).
    The origin (0, 0) has no predecessor and returns ``(0, (0, 0))``.
    """
    if i == 0 and j == 0 :
        # Start of every path; nothing to look back at.
        return 0, (0, 0)
    options = {}
    # Each score keeps the FIRST move that reaches it, so ties favour the
    # diagonal, then down, then right (we prefer mismatches to indels).
    # Try diagonal
    if i > 0 and j > 0 :
        sc = sm[i-1][j-1] + score_matrix(seq1[i-1], seq2[j-1])
        if sc not in options :
            options[ sc ] = (i-1, j-1)
    # Try down
    if i > 0 :
        sc = sm[i-1][j] + indel_penalty
        if sc not in options :
            options[ sc ] = (i-1, j)
        # Down is free before the beginning of seq2.
        if j == 0 :
            sc = sm[i-1][j]
            if sc not in options :
                options[ sc ] = (i-1, j)
    # Try right
    if j > 0 :
        sc = sm[i][j-1] + indel_penalty
        if sc not in options :
            options[ sc ] = (i, j-1)
        # Right is free after the end of seq1. Requiring j > 0 keeps
        # sm[i][-1] from wrapping to the not-yet-computed last column.
        if i == len(seq1) :
            sc = sm[i][j-1]
            if sc not in options :
                options[ sc ] = (i, j-1)
    best_score = max(options)
    return best_score , options[best_score]

def trim(res1, res2) :
    """Strip the non-overlapping flanks from a pair of aligned strings.

    First, columns at the end where res1 holds a gap are removed from both
    strings. Then columns at the start where res2 holds a gap are removed from
    both. What remains is the overlapping region. Returns ``(res1, res2)``.
    """
    while res1 and res1[-1] == '-' :
        res1, res2 = res1[:-1], res2[:-1]
    while res2 and res2[0] == '-' :
        res1, res2 = res1[1:], res2[1:]
    return res1 , res2

def overlap_alignment(seq1, seq2) :
    """Align a suffix of *seq1* against a prefix of *seq2*.

    Fills the DP score table and traceback table with ``update``, then
    reconstructs the alignment with ``backup`` and clips the free flanks with
    ``trim``. Returns ``(aligned1, aligned2, score)``, where score is the
    value of the bottom-right cell of the table.
    """
    rows = len(seq1) + 1
    cols = len(seq2) + 1
    sm = [[0] * cols for _ in range(rows)]
    back = [[(0, 0)] * cols for _ in range(rows)]
    for i in range(rows) :
        for j in range(cols) :
            # The origin keeps its initial score 0 and predecessor (0, 0).
            if i or j :
                sm[i][j] , back[i][j] = update(sm, i, j, seq1, seq2)
    aligned1 , aligned2 = trim(*backup(back, seq1, seq2))
    return aligned1, aligned2, sm[rows - 1][cols - 1]

def getPerfectOverlap(seq1, seq2) :
    """Return the overlap between the end of *seq1* and the start of *seq2*.

    The result is the shared string when the two trimmed alignment rows are
    identical; otherwise it is ''.
    """
    aligned1, aligned2, _ = overlap_alignment(seq1, seq2)
    return aligned1 if aligned1 == aligned2 else ''

#########   MAIN SCRIPT   ###########

class Contig :
    """An ordered list of k-mers (``self.contig``) that chain by overlaps."""

    def __init__(self, _contig = None) :
        # A fresh list per instance. The former ``[]`` default was one list
        # object shared by every Contig constructed without an argument.
        self.contig = [] if _contig is None else _contig

    def toSeq(self, kmerPvals) :
        """Collapse the k-mers into one sequence plus a per-position score string.

        Each k-mer after the first is placed where its perfect overlap with
        the sequence built so far begins; bases past the current end are
        appended. *kmerPvals* maps each k-mer to its score. Each position's
        score is the mean over all k-mers covering it, truncated to int and
        encoded as ``chr(48 + score)``.

        Returns ``(seq, scores)``; both are '' for an empty contig.
        """
        if not self.contig :
            return '', ''
        first = self.contig[0]
        seq = first
        logPvals = {}
        for i in range(len(first)) :
            logPvals[i] = [kmerPvals[first]]
        for kmer in self.contig[1:] :
            overlap = getPerfectOverlap(seq, kmer)
            # Index in seq where this kmer starts.
            pos = len(seq) - len(overlap)
            for i, base in enumerate(kmer) :
                # pos + i never passes len(seq): seq grows one base per step.
                if pos + i == len(seq) :
                    seq = seq + base
                logPvals.setdefault(pos + i, []).append(kmerPvals[kmer])
        scores = []
        for i in range(len(seq)) :
            score = sum(logPvals[i]) / len(logPvals[i])
            scores.append(chr(48 + int(score)))
        return seq, ''.join(scores)

    def merge(self, other, kmerPvals, requiredIdentity = 3) :
        """Return the interleaved k-mer list of self and other, or [].

        Creates a new merged version; does not modify either contig.
        requiredIdentity is the minimum length of exact overlap (in either
        order) between the two contigs' sequences. Returns [] if neither
        ordering reaches it. The contig whose sequence end gives the longer
        overlap goes on the left (ties put *other* on the left). K-mers are
        then taken from the two lists in turn, each time choosing the head
        that overlaps the last chosen k-mer more. Exits the process with
        status 1 if the two heads overlap equally.
        """
        seqS = self.toSeq(kmerPvals)[0]
        seqO = other.toSeq(kmerPvals)[0]
        overlap1 = getPerfectOverlap(seqS, seqO)
        overlap2 = getPerfectOverlap(seqO, seqS)
        if len(overlap1) < requiredIdentity and len(overlap2) < requiredIdentity :
            return []
        if len(overlap1) > len(overlap2) :
            # Better to overlap self on the left.
            OVERLAP, seq1, seq2 = overlap1, seqS, seqO
            c1, c2 = list(self.contig), list(other.contig)
        else :
            # Better to overlap other on the left.
            OVERLAP, seq1, seq2 = overlap2, seqO, seqS
            c1, c2 = list(other.contig), list(self.contig)
        if not c1 :
            # Nothing on the left to start from; the result is the other
            # contig's k-mers ([] when both are empty).
            return c2
        cEND = [c1.pop(0)]
        while c1 or c2 :
            # Base cases: only one list left, so absorb the rest of it.
            if not c1 :
                cEND.extend(c2)
                break
            if not c2 :
                cEND.extend(c1)
                break
            # Both lists are non-empty here: pick the head that overlaps better.
            nc1 = c1[0]
            nc2 = c2[0]
            p1 = getPerfectOverlap(cEND[-1], nc1)
            p2 = getPerfectOverlap(cEND[-1], nc2)
            if len(p1) > len(p2) :
                cEND.append(c1.pop(0))
            elif len(p1) < len(p2) :
                cEND.append(c2.pop(0))
            else :
                print(self.contig, seqS)
                print(other.contig, seqO)
                print('HEY! YOU SHOULD NEVER GET TWO KMERS OVERLAPPING THE SAME AMOUNT LIKE THIS!\nsequences: %s , %s\noverlap: %s\nchoices: %s, %s\nperfect overlaps: %s, %s' %(seq1, seq2, OVERLAP, nc1, nc2, p1, p2))
                print('\n\n\n\n          FAILED\n\n\n\n')
                import sys
                sys.exit(1)
        return cEND

def readKmers(fname) :
    """Read a whitespace-delimited k-mer table into ``{phen: {kmer: log_P_chi}}``.

    The first line is a header that must name the columns 'phen', 'kmer',
    'P_chi' and 'log_P_chi' (``list.index`` raises ValueError otherwise).
    Blank lines are skipped. The file is closed even if parsing fails.
    """
    kmersByPhen = {}
    with open(fname, 'rt') as ifh :
        header = ifh.readline().split()
        i_phen = header.index('phen')
        i_kmer = header.index('kmer')
        # P_chi itself is not used, but the column is still required.
        header.index('P_chi')
        i_log_P_chi = header.index('log_P_chi')
        for line in ifh :
            fields = line.split()
            if not fields :
                continue
            byKmer = kmersByPhen.setdefault(fields[i_phen], {})
            byKmer[fields[i_kmer]] = float(fields[i_log_P_chi])
    return kmersByPhen

def mergeContigs(contigs, requiredIdentity, kmers) :
    """Merge the first mergeable pair of contigs, modifying *contigs* in place.

    Pairs (a, b) with a < b are tried in order. ``contigs[a].merge(contigs[b])``
    is tried first, then the reverse. On the first success both inputs are
    removed and the merged Contig is appended. Returns the same list object,
    changed by at most one merge.
    """
    n = len(contigs)
    for a in range(n - 1) :
        for b in range(a + 1, n) :
            left, right = contigs[a], contigs[b]
            merged = (left.merge(right, kmers, requiredIdentity)
                      or right.merge(left, kmers, requiredIdentity))
            if merged :
                # Remove the higher index first so the lower one stays valid.
                del contigs[b]
                del contigs[a]
                contigs.append(Contig(merged))
                return contigs
    return contigs

def assembleKmers(kmers) :
    """Greedily assemble k-mers into a list of Contig objects.

    Every k-mer starts as its own contig. k is taken from the first k-mer
    (the k-mers are presumably all that length). The required overlap starts
    at k - 1. Each pass that merges nothing lowers it by one, and assembly
    stops once it drops below max(k - 5, 3). Returns [] for no k-mers.
    """
    contigs = [ Contig([k]) for k in kmers ]
    if len(contigs) == 0 :
        return contigs
    k = len(contigs[0].contig[0])
    requiredIdentity = k - 1
    minOverlap = max(k-5, 3)
    # (The former ``lastChance`` flag was never set True, so its loop
    # condition and reset branch were dead code and have been removed.)
    while True :
        before = len(contigs)
        contigs = mergeContigs(contigs, requiredIdentity, kmers)
        if len(contigs) == before :
            # Nothing merged at this identity: relax the requirement.
            requiredIdentity -= 1
            if requiredIdentity < minOverlap :
                break
    return contigs


# Contigs are written in a FASTQ-like text format whose "quality" line holds
# scores instead of base qualities. Each record is four lines:
#
# > PHENOTYPE REGION 1
# ACGTTGCAAGTCCATGGACT
# +
# 12222999aaaaaa999222
#
# Each score character encodes, for its position, the average -log10(pvalue)
# (the log_P_chi values read from the input) of the k-mers covering it,
# truncated to an integer and stored as chr(48 + value); use ord(c) - 48 to
# recover the number.
def writeContigs(fname, contigsByPhen, kmersByPhen) :
    """Write every contig to ``fname[:-4] + '.fastk'``.

    The last four characters of *fname* (e.g. '.txt') are replaced by
    '.fastk'. Records are numbered from 1 within each phenotype. Each record
    is written as '> PHEN REGION N', the sequence, '+', then the score
    string from ``Contig.toSeq``. The file is closed even if a write fails.
    """
    seq_ofname = fname[:-4] + '.fastk'
    with open(seq_ofname, 'wt') as seq_ofh :
        for phen in contigsByPhen :
            for i, contig in enumerate(contigsByPhen[phen], 1) :
                seq, scores = contig.toSeq(kmersByPhen[phen])
                seq_ofh.write('> %s REGION %d\n%s\n+\n%s\n' %(phen, i, seq, scores))
    return

# Assemble the k-mers of every phenotype and, if any contigs result, write them all to a single .fastk file whose record headers carry the phenotype.
def main(fname) :
    """Read *fname*, assemble each phenotype's k-mers, and write the contigs.

    The .fastk output is written only if at least one phenotype produced a
    non-empty list of contigs.
    """
    kmersByPhen = readKmers(fname)
    contigsByPhen = {phen: assembleKmers(kmers) for phen, kmers in kmersByPhen.items()}
    if any(contigs for contigs in contigsByPhen.values()) :
        writeContigs(fname, contigsByPhen, kmersByPhen)

if __name__ == '__main__' :
    # Command-line entry point: a single positional input filename.
    import argparse
    cli = argparse.ArgumentParser()
    cli.add_argument('fname', help='Filename of input, as output from analyzeKmerCounts.R usually named \"important#.txt\"')
    main(cli.parse_args().fname)

