#! /usr/bin/env python3.5
#Edited by Alan 05/03/2019, changed consensus check in readPhenotypes to check if line has 'consensus' string

def readFasta(fname) :
    """Read a FASTA file into a dict mapping record ID to sequence.

    The ID is the header line with the leading '>' removed and surrounding
    whitespace stripped. Sequence lines are concatenated and upper-cased.
    Blank lines are ignored. Records whose sequence is empty are not stored,
    and a repeated ID overwrites the earlier record. Any sequence text that
    appears before the first header is stored under the ID ''.
    """
    seqs = {}
    ID = ''
    chunks = []
    # The context manager guarantees the file is closed even on error.
    with open(fname, 'rt') as fh :
        for line in fh :
            line = line.strip()
            if not line :
                # Blank lines (e.g. a trailing empty line) carry no data;
                # indexing line[0] on them would raise IndexError.
                continue
            if line.startswith('>') :
                if chunks :
                    seqs[ID] = ''.join(chunks)
                ID = line[1:].strip()
                chunks = []
            else :
                # Collect pieces and join once, avoiding quadratic string '+'.
                chunks.append(line.upper())
    if chunks :
        seqs[ID] = ''.join(chunks)
    return seqs

def readPhenotypes(phenFile) :
    """Read a whitespace-delimited phenotype table.

    The first line is a header. Its first column is the ID column, and every
    following column names a phenotype. In each data row, a value of '1'
    marks the ID as a case for that phenotype and '0' marks it as a control.
    Any other value leaves the ID out of both groups. Lines containing the
    string 'consensus' anywhere are skipped, as are blank lines.

    Returns {phenotype: {'cases': [IDs], 'controls': [IDs]}}, with IDs in
    file order.

    Raises ValueError if a data row has fewer columns than the header.
    """
    with open(phenFile, 'rt') as ifh :
        header = ifh.readline().strip().split()
        phenNames = header[1:]
        phens = {}
        for name in phenNames :
            phens[name] = {'cases': [], 'controls': []}
        # Line numbers start at 2 because the header was line 1.
        for lineNo, line in enumerate(ifh, start=2) :
            if 'consensus' in line :
                continue
            fields = line.split()
            if not fields :
                # Blank line: nothing to record (fields[0] would raise).
                continue
            if len(fields) < len(header) :
                raise ValueError('%s line %d: expected %d columns, found %d'
                                 % (phenFile, lineNo, len(header), len(fields)))
            ID = fields[0]
            for name, value in zip(phenNames, fields[1:]) :
                if value == '1' :
                    phens[name]['cases'].append(ID)
                elif value == '0' :
                    phens[name]['controls'].append(ID)
    return phens

# Keep only unique k-mers within each sequence: a k-mer already seen in that sequence is not yielded (and therefore not counted) a second time.
def generateKmers(seq, k) :
    """Yield each distinct length-k substring of seq once, in first-occurrence order.

    Window start offsets run from 0 through len(seq) - k inclusive, so the
    final k-mer ending at the last character is included. Nothing is
    yielded when k exceeds len(seq).
    """
    seen = set()
    # The +1 makes the last window (ending at len(seq)) part of the range.
    for i in range(len(seq) - k + 1) :
        kmer = seq[i:i+k]
        if kmer not in seen :
            seen.add(kmer)
            yield kmer

def getAllKmers(group, seqs, k) :
    """Count, for each k-mer, how many sequences in `group` contain it.

    IDs in `group` that have no entry in `seqs` are skipped. Because
    generateKmers yields each k-mer at most once per sequence, the counts
    are per-sequence presence counts. An ID listed twice in `group` is
    counted twice.
    """
    counts = {}
    present = (seqID for seqID in group if seqID in seqs)
    for seqID in present :
        for kmer in generateKmers(seqs[seqID], k) :
            counts[kmer] = counts.get(kmer, 0) + 1
    return counts

def getKmerNums(cases, controls, seqs, k) :
    """Tally k-mer presence in cases vs. controls for one phenotype.

    Returns (nums, diffs):
      nums  -- kmer -> {'cases': n, 'controls': m}: the number of case and
               control sequences containing the k-mer (from getAllKmers).
      diffs -- kmer -> |case frequency - control frequency|, where each
               frequency is the count divided by len(cases) or len(controls).

    K-mers are visited in a deterministic order: first those seen in cases,
    then those seen only in controls. This keeps output built from these
    dicts reproducible across runs. Iterating a set of strings would not,
    because str hashing is randomized per process.

    NOTE(review): the denominators are the full group sizes. They include IDs
    with no entry in seqs, which getAllKmers skips. Confirm this is intended.
    """
    caseNums = getAllKmers(cases, seqs, k)
    controlNums = getAllKmers(controls, seqs, k)
    ncase = len(cases)
    ncontrol = len(controls)
    kmers = list(caseNums)
    kmers.extend(kmer for kmer in controlNums if kmer not in caseNums)
    nums = {}
    diffs = {}
    for kmer in kmers :
        kCaN = caseNums.get(kmer, 0)
        kCoN = controlNums.get(kmer, 0)
        # A non-zero count implies that group is non-empty. Guarding on the
        # count therefore never divides by the size of an empty group.
        kCaF = kCaN / ncase if kCaN else 0.0
        kCoF = kCoN / ncontrol if kCoN else 0.0
        diffs[kmer] = abs(kCaF - kCoF)
        nums[kmer] = {'cases': kCaN, 'controls': kCoN}
    return nums, diffs

def reportFreqDiff(diffs, ofh, n=5) :
    """Write the n k-mers with the largest frequency difference to ofh.

    Each output line is '<kmer> , <diff>'. Lines are in descending order of
    diff; ties keep the iteration order of diffs, since the sort is stable.
    When diffs has fewer than n entries, all of them are written.
    """
    ranked = sorted(diffs, key=diffs.get, reverse=True)
    for kmer in ranked[:n] :
        # Format into one str: text-file write() rejects float arguments.
        ofh.write('%s , %s\n' % (kmer, diffs[kmer]))

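# Arguments are the cells of a 2x2 contingency table: a = A1, b = A2, c = B1, d = B2.
# An Xsq of 19 (with 1 degree of freedom) corresponds to a p-value of about 1.3e-5; p-values higher than that are not of interest. Note that getXsq does not apply this cutoff itself: it always computes the full statistic (or returns -1).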
def getXsq(a,b,c,d) :
    """Return the chi-squared statistic for the 2x2 table [[a, b], [c, d]].

    Uses n * (b*c - a*d)**2 / (row and column totals multiplied together),
    with no continuity correction. Returns -1 when any row or column total
    is zero, because the statistic is undefined then.
    """
    rowA = a + b
    colOne = a + c
    colTwo = b + d
    rowB = c + d
    if 0 in (rowA, colOne, colTwo, rowB) :
        return -1
    total = a + b + c + d
    return (total * (b*c - a*d) ** 2) / (rowA * colOne * colTwo * rowB)

def reportNumDiffs(nums, phen, ncase, ncontrol, ofh) :
    """Write one tab-separated row per k-mer with its 2x2 counts and Xsq.

    Columns: phen, kmer, cases with the k-mer, controls with the k-mer,
    cases without it, controls without it, and the statistic from getXsq
    (formatted with %f, so -1 appears as -1.000000).

    nums maps kmer -> {'cases': int, 'controls': int}, as built by
    getKmerNums. The "absent" counts are derived as ncase - present and
    ncontrol - present.
    """
    for kmer, counts in nums.items() :
        nKCa = counts['cases']
        nKCo = counts['controls']
        nKCaA = ncase - nKCa
        nKCoA = ncontrol - nKCo
        Xsq = getXsq(nKCa, nKCaA, nKCo, nKCoA)
        ofh.write('%s\t%s\t%d\t%d\t%d\t%d\t%f\n'
                  % (phen, kmer, nKCa, nKCo, nKCaA, nKCoA, Xsq))

def main(k, reportFlag, phenFile, fullFasta, outfile) :
    """Compute per-phenotype k-mer statistics and write them to outfile.

    reportFlag 'c' writes a tab-separated table: a header row, then the rows
    from reportNumDiffs for every phenotype. reportFlag 'f' writes, for each
    phenotype, a '<phen> : N cases, M controls' line followed by
    reportFreqDiff's top k-mers. Any other flag leaves outfile empty.
    """
    phens = readPhenotypes(phenFile)
    seqs = readFasta(fullFasta)
    # The context manager closes the output even if a report step raises.
    with open(outfile, 'wt') as ofh :
        if reportFlag == 'c' :
            ofh.write('phen\tkmer\tcasePresent\tcontrolPresent\tcaseAbsent\tcontrolAbsent\tXsq\n')
        for phen, groups in phens.items() :
            cases = groups['cases']
            controls = groups['controls']
            nums, diffs = getKmerNums(cases, controls, seqs, k)
            ncase = len(cases)
            ncontrol = len(controls)
            if reportFlag == 'f' :
                ofh.write('%s : %d cases, %d controls\n' % (phen, ncase, ncontrol))
                reportFreqDiff(diffs, ofh)
            elif reportFlag == 'c' :
                reportNumDiffs(nums, phen, ncase, ncontrol, ofh)

if __name__ == '__main__' :
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument('k', type=int, help='k-mer length (must be at least 1)')
    parser.add_argument('phenFile', help='file with phenotypes')
    parser.add_argument('fullFasta', help='file with sequences')
    parser.add_argument('out', help='filename for output')
    parser.add_argument('-report', choices=['f', 'c'], default='c', help='Report option, use f for frequency report, use c for count report')
    args = parser.parse_args()
    if args.k < 1 :
        # k-mers are taken as seq[i:i+k]: k == 0 gives empty strings and a
        # negative k slices relative to the end of the sequence.
        parser.error('k must be at least 1, got %d' % args.k)
    main(args.k, args.report, args.phenFile, args.fullFasta, args.out)

