#! /usr/bin/env python3.5

'''

    Takes a FASTA alignment file.
    Creates .ped and .map files to be used with PLINK.

    Pretends positions can be broken into SNPs like this:

    For position rd1 :

        rd1
    1 : A
    2 : T
    3 : T
    4 : C
    5 : -

    Becomes pseudo-biallelic at four pseudo-locations like so (1 = carries
    that nucleotide, 0 = does not; the .ped output codes these as 2 and 1):

      rd1.A   rd1.T   rd1.G   rd1.C
    1:  1       0       0       0
    2:  0       1       0       0
    3:  0       1       0       0
    4:  0       0       0       1
    5:  0       0       0       0


'''

def readFasta(fname) :
    '''
    Reads a (possibly line-wrapped) FASTA file into a dict.

    fname : path of the FASTA file to read.

    Returns a dict mapping each header to its sequence. The header is the
    '>' line with every '>' character removed and surrounding whitespace
    stripped. Wrapped sequence lines are joined and upper-cased. A record
    whose sequence is empty is not stored, and if a header occurs more
    than once the last record wins. Sequence lines found before the first
    header are stored under the empty-string key.

    Blank lines anywhere in the file are ignored.
    '''
    seqs = {}
    h = ''
    parts = []
    # 'with' guarantees the handle is closed even if reading fails
    with open(fname, 'rt') as fh :
        for line in fh :
            line = line.strip()
            if not line :
                # blank line: nothing to record (line[0] would raise IndexError)
                continue
            if line.startswith('>') :
                # a new header closes the previous record, if it had any sequence
                if parts :
                    seqs[h] = ''.join(parts)
                h = line.replace('>', '').strip()
                parts = []
            else :
                parts.append(line.upper())
    # flush the final record
    if parts :
        seqs[h] = ''.join(parts)
    return seqs

def posToPseudoAllelesExpand(snpMap, posNames, seqs, i) :
    '''
    Expands alignment column i into four pseudo-markers, one per nucleotide
    in the order A, T, C, G.

    snpMap   : dict of sample name -> list of genotype codes (extended in place).
    posNames : list of marker names (extended in place).
    seqs     : dict of sample name -> aligned sequence.
    i        : 0-based column index into the sequences.

    For each nucleotide a marker named 'rd<i>.<nucleotide>' is added, and every
    sample in snpMap receives '2' if it carries that nucleotide at column i and
    '1' otherwise (so a gap or any other character yields four '1' codes).

    Returns the same (snpMap, posNames) objects that were passed in.
    '''
    for nucleotide in 'ATCG' :
        posNames.append('rd%d.%s' %(i, nucleotide))
        for sample, calls in snpMap.items() :
            calls.append('2' if seqs[sample][i] == nucleotide else '1')
    return snpMap , posNames

def getMajor(snps, threshold) :
    '''
    Returns the most common nucleotide of a column if it is frequent enough.

    snps      : sequence of single characters (one per sample); anything
                supporting len() and .count() works.
    threshold : minimum fraction of len(snps) the nucleotide must reach;
                must be greater than 0.25.

    Returns 'A', 'C', 'T' or 'G' -- the nucleotide with the highest count
    among those whose frequency is >= threshold -- or None if no nucleotide
    reaches the threshold or snps is empty. Exact ties are resolved in the
    order A, C, T, G. Frequencies are relative to all characters in snps,
    so gaps and other symbols count towards the total.
    '''
    assert threshold > 0.25
    total = len(snps)
    if total == 0 :
        # empty column: no major allele (and avoids dividing by zero)
        return None
    major = None
    majorCount = 0
    for allele in ('A', 'C', 'T', 'G') :
        count = snps.count(allele)
        # strict '>' keeps the earlier allele on ties, and picks the most
        # common one when several alleles pass a threshold <= 0.5
        if count > majorCount and (1.0 * count) / total >= threshold :
            major = allele
            majorCount = count
    return major
    

def posToPseudoAllelesMajor(snpMap, posNames, seqs, i) :
    '''
    Converts alignment column i into a single marker with at most two alleles.

    snpMap   : dict of sample name -> list of genotype codes (extended in place).
    posNames : list of marker names (extended in place).
    seqs     : dict of sample name -> aligned sequence.
    i        : 0-based column index into the sequences.

    The marker is named '<i>'. The column is examined over the samples in
    snpMap only, so sequences that are not written out (such as a
    'consensus' record) cannot influence the allele calls.

    - At most two distinct characters in the column: each sample gets its
      own character unchanged (this includes gaps or other symbols).
    - More than two distinct characters and getMajor finds a nucleotide
      present in at least half of the samples: samples carrying it get that
      letter, all others get '1'.
    - More than two distinct characters and no such nucleotide: every
      sample gets '0'.

    Returns the same (snpMap, posNames) objects that were passed in.
    '''
    # only the samples being emitted take part in the allele statistics
    snps = [ seqs[h][i] for h in snpMap ]
    isMulti = len(set(snps)) > 2
    major = getMajor(snps, 1.0/2) if isMulti else None
    posNames.append('%d' %(i))
    for h in snpMap :
        base = seqs[h][i]
        if not isMulti :
            snpMap[h].append(base)
        elif major is None :
            snpMap[h].append('0')
        elif base == major :
            snpMap[h].append(major)
        else :
            snpMap[h].append('1')
    return snpMap , posNames

def seqsToPseudoAlleles(seqs, method) :
    '''
    Converts every alignment column into pseudo-allele markers.

    seqs   : dict of sequence name -> aligned sequence. A record named
             'consensus', if present, sets the number of columns and is
             left out of the result; without one, the length of the first
             sequence is used.
    method : 'expand' (four markers per column, see
             posToPseudoAllelesExpand) or 'major' (one marker per column,
             see posToPseudoAllelesMajor).

    Returns (snpMap, posNames): snpMap maps each sample name to its list of
    genotype codes, posNames lists the marker names in the same order.

    Raises ValueError if method is not 'expand' or 'major'; previously such
    a value silently produced a result with no markers at all.
    '''
    if method not in ('expand', 'major') :
        raise ValueError("method must be 'expand' or 'major', got %r" %(method,))
    posNames = []
    snpMap = { h : [] for h in seqs if h != 'consensus' }
    if 'consensus' in seqs :
        nPos = len(seqs['consensus'])
    else :
        # no consensus record: take the alignment length from the first sequence
        nPos = len(next(iter(seqs.values()), ''))
    expand = method == 'expand'
    for i in range(nPos) :
        if expand :
            snpMap , posNames = posToPseudoAllelesExpand(snpMap, posNames, seqs, i)
        else :
            snpMap , posNames = posToPseudoAllelesMajor(snpMap, posNames, seqs, i)
    return snpMap , posNames

def writeMap(posNames, ofname) :
    '''
    Writes the marker list to a .map file, one line per marker.

    posNames : iterable of marker names, in output order.
    ofname   : path of the file to create (overwritten if it exists).

    Each line is 'TE <name> 0'.
    NOTE(review): that is three whitespace-separated fields; presumably read
    by PLINK as chromosome, marker id and a dummy position -- confirm
    against the PLINK options used downstream.
    '''
    # 'with' closes the file even if a write fails
    with open(ofname, 'wt') as ofh :
        ofh.writelines('TE %s 0\n' %pn for pn in posNames)

def writePed(snpMap, ofname) :
    '''
    Writes the genotype table to a .ped file, one line per sample.

    snpMap : dict of sample name -> list of genotype codes.
    ofname : path of the file to create (overwritten if it exists).

    Each line starts with '<name> <name> 0 0 0 0' followed by every genotype
    code written twice ('<code> <code>'), all separated by single spaces.
    NOTE(review): sample names are written verbatim; a name containing
    whitespace would shift the columns -- verify against the FASTA headers
    in use.
    '''
    # 'with' closes the file even if a write fails
    with open(ofname, 'wt') as ofh :
        for h in snpMap :
            fields = ['%s %s 0 0 0 0' %(h, h)]
            fields.extend('%s %s' %(snp, snp) for snp in snpMap[h])
            # join once instead of growing the line one genotype at a time
            ofh.write(' '.join(fields) + '\n')

def main(fname, outName, method) :
    '''
    Runs the whole conversion: FASTA in, <outName>.map and <outName>.ped out.

    fname   : path of the FASTA alignment to read.
    outName : output prefix; '.map' and '.ped' are appended to it.
    method  : conversion method handed to seqsToPseudoAlleles.

    Prints the number of sequences read (every record in the file counts).
    '''
    sequences = readFasta(fname)
    print('Read %d sequences' %len(sequences))
    genotypes , markers = seqsToPseudoAlleles(sequences, method)
    writeMap(markers, '%s.map' %outName)
    writePed(genotypes, '%s.ped' %outName)

if __name__ == '__main__' :
    # Command line: <fasta> <outName> --m {expand,major}
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('fasta')
    parser.add_argument('outName')
    # required: there is no default, so omitting --m would hand method=None
    # to main(), which is not one of the supported methods
    parser.add_argument('--m', required=True, choices=['expand', 'major'], help='Options: expand means each position becomes four pseudo-alleles for the nucleotides, major means all multi-allele places are turned into a most-common allele vs. everything else allele.')
    args = parser.parse_args()
    main(args.fasta, args.outName, args.m)

