#!/usr/bin/env python

# Copyright 2009, 2010 Yann Surget-Groba
# This file is part of the STM pipeline.
#
# The STM pipeline is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The STM pipeline is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this file. If not, see <http://www.gnu.org/licenses/>.


import Bio
from Bio.Align import AlignInfo
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq

class Hit:
    """A named hit with integer start/end coordinates and a strand tag.

    ``start`` and ``end`` are coerced with ``int()``, so numeric strings are
    accepted. ``name`` and ``strand`` are stored exactly as given.
    """

    def __init__(self, name, start, end, strand):
        self.name = name
        self.start, self.end = int(start), int(end)
        self.strand = strand


# IUPAC ambiguity code for each sorted combination of two or three
# nucleotides; any combination not listed here falls back to 'N'.
_AMBIGUITY_CODES = {
    ('C', 'T'): 'Y',
    ('A', 'G'): 'R',
    ('A', 'T'): 'W',
    ('C', 'G'): 'S',
    ('G', 'T'): 'K',
    ('A', 'C'): 'M',
    ('A', 'G', 'T'): 'D',
    ('A', 'C', 'G'): 'V',
    ('A', 'C', 'T'): 'H',
    ('C', 'G', 'T'): 'B',
}


def consensus(align, ambCodes=0, threshold=0.3):
    """Compute a consensus sequence for the alignment *align*.

    Per-column base counts come from Biopython's position-specific score
    matrix, with 'N', 'X' and '-' ignored. For each column:

    * no counted bases -> 'N';
    * exactly one observed base -> that base;
    * several observed bases -> the column is counted as variable (``score``
      is incremented). Only bases whose fraction of the column's counts is
      >= *threshold* are kept:

      - one surviving base is used as is;
      - several survivors become the matching IUPAC ambiguity code when
        *ambCodes* is true, and 'N' otherwise;
      - combinations without a code, or no survivors at all, give 'N'.

    Returns a ``(consensus_string, score)`` tuple. ``score`` is the number
    of columns with more than one observed base.
    """
    summary = AlignInfo.SummaryInfo(align)
    maxLength = align.get_alignment_length()
    pssm = summary.pos_specific_score_matrix(chars_to_ignore=['N', 'X', '-'])
    consBases = []  # one character per column, joined once at the end
    score = 0

    for position in range(maxLength):
        column = pssm[position]
        countBases = {}
        nBases = 0
        for base in column:
            if column[base] > 0:
                countBases[base] = column[base]
                nBases += column[base]

        if not countBases:
            consBase = 'N'
        elif len(countBases) == 1:
            # list(...)[0] instead of keys()[0]: dict views are not
            # indexable on Python 3.
            consBase = list(countBases)[0]
        else:
            score += 1
            bases = [b for b in countBases
                     if float(countBases[b]) / nBases >= threshold]
            if len(bases) == 1:
                consBase = bases[0]
            elif ambCodes:
                consBase = _AMBIGUITY_CODES.get(tuple(sorted(bases)), 'N')
            else:
                consBase = 'N'
        consBases.append(consBase)

    return (''.join(consBases), score)


def mapBlastHit(queryStart, subjectStart, strand, lengthRead):
    """Project a hit on a read onto the subject's nucleotide coordinates.

    ``subjectStart`` is converted with ``(subjectStart - 1) * 3 + 1``. The
    factor 3 suggests subject coordinates are in codons/amino acids, as in a
    translated search (TODO confirm against callers). The read offset is
    then subtracted:

    * '+' strand: ``queryStart - 1`` bases precede the hit on the read.
    * '-' strand: ``lengthRead - queryStart`` bases are used instead, because
      the read is reverse-complemented.

    Returns the (1-based) mapped position, which may be <= 0 when the read
    overhangs the subject start.

    Raises ValueError if *strand* is neither '+' nor '-'.
    """
    nucleotideStart = (subjectStart - 1) * 3 + 1
    if strand == '+':
        return nucleotideStart - (queryStart - 1)
    elif strand == '-':
        return nucleotideStart - (lengthRead - queryStart)
    # Previously this fell through to an UnboundLocalError.
    raise ValueError("strand must be '+' or '-', got %r" % (strand,))


def _skipFastqBody(handle):
    """Consume the sequence, '+' and quality lines of one FASTQ record.

    Expects *handle* to be positioned just after a '@' header line. Sequence
    lines are read up to the '+' separator. Quality lines are then read until
    their total length matches the sequence length, so multi-line records are
    handled. Quality lines are never mistaken for headers, even when they
    start with '@'. Stops quietly at end of file.
    """
    seqLen = 0
    while True:
        line = handle.readline()
        if not line:
            return
        if line.startswith('+'):
            break
        seqLen += len(line.strip())
    qualLen = 0
    while qualLen < seqLen:
        line = handle.readline()
        if not line:
            return
        qualLen += len(line.strip())


def indexSeqFile(handle, format='fasta', sortByNumber=False):
    """Map each record in a FASTA/FASTQ file to the offset of its header.

    The handle is rewound and scanned line by line. Each value is the
    ``handle.tell()`` offset of the record's header line, suitable for
    ``handle.seek()``.

    Keys are the header text without its leading '>'/'@' and surrounding
    whitespace. When *sortByNumber* is true, keys are instead the 1-based
    ordinal of the record in the file.

    For FASTQ, each record's sequence and quality lines are skipped
    explicitly. Quality strings can start with '@', so a plain "line starts
    with '@'" test would index them as bogus records.

    Returns the dict, or 0 if *format* is neither 'fasta' nor 'fastq'.
    """
    if format == 'fasta':
        char = '>'
    elif format == 'fastq':
        char = '@'
    else:
        return 0
    seqDic = {}
    handle.seek(0)
    nSeq = 0
    position = 0
    while True:
        line = handle.readline()
        if not line:
            break
        if line.startswith(char):
            nSeq += 1
            if sortByNumber:
                seqDic[nSeq] = position
            else:
                seqDic[line.strip()[1:]] = position
            if format == 'fastq':
                _skipFastqBody(handle)
        position = handle.tell()
    return seqDic
        

def getSeqByPosition(handle, position=0, format='fasta'):
    """Parse and return the single record that starts at offset *position*.

    Typically *position* is an offset produced by ``indexSeqFile``. The
    handle is seeked there and ``SeqIO.parse`` reads only the first record.
    StopIteration propagates if no record starts at that offset.
    """
    handle.seek(position)
    seqIterator = SeqIO.parse(handle, format)
    # Builtin next() works for both Python 2 (2.6+) and 3; the .next()
    # method only exists on Python 2.
    seqRecord = next(seqIterator)
    # Generator-based parsers expose close(). Some iterator objects returned
    # by SeqIO.parse may not, depending on the Biopython version, so only
    # call it when available.
    close = getattr(seqIterator, 'close', None)
    if close is not None:
        close()
    return seqRecord

def cleanSeq(seqRecord):
    """Truncate *seqRecord* at its first character other than A/C/G/T/N.

    The scan is case-sensitive and stops at the first disallowed character.
    That character and everything after it are dropped, even if valid bases
    follow. The record's ``seq`` is replaced in place by a new ``Seq``, and
    the same record object is returned.
    """
    valid = ['A', 'C', 'G', 'T', 'N']
    kept = []
    for letter in seqRecord:
        if letter not in valid:
            break
        kept.append(letter)
    seqRecord.seq = Seq(''.join(kept))
    return seqRecord

def parseCapRes(handle):
    """Parse an assembly report into contig consensus records.

    The report is apparently CAP3 output, judging by the function name;
    confirm against callers.

    The handle is rewound and read in three stages:

    1. Up to the 'DETAILED DISPLAY OF CONTIGS' line, each stripped line
       starting with '*' names a contig (asterisks removed). Every following
       non-blank line adds a member id: its first whitespace-separated field
       with the last character dropped. That character is presumably the +/-
       orientation suffix; confirm against the report format.
    2. Lines are skipped up to and including 'CONSENSUS SEQUENCES'.
    3. The rest of the handle is parsed as FASTA. Each record is passed
       through ``cleanSeq`` and its description is looked up in the contig
       table.

    Returns a list of ``[length, SeqRecord, members]`` entries, one per
    consensus sequence. The SeqRecord's id is the member ids joined by '|'.

    Raises ValueError if the handle ends before either marker line is found.
    Raises KeyError if a consensus description has no matching contig.
    """
    contigs = {}
    handle.seek(0)

    # Stage 1: contig membership.
    read = False
    while True:
        rawLine = handle.readline()
        # readline() returns an empty string only at EOF; without this check
        # EOF looked like a blank line and the loop never ended.
        if not rawLine:
            raise ValueError("end of file reached before "
                             "'DETAILED DISPLAY OF CONTIGS'")
        line = rawLine.strip()
        if line == '':
            continue
        if line == 'DETAILED DISPLAY OF CONTIGS':
            break
        if line[0] == '*':
            currentContig = line.replace('*', '').strip()
            contigs[currentContig] = []
            read = True
            continue
        if read:
            contigs[currentContig].append(line.split()[0][:-1])

    # Stage 2: skip to the consensus FASTA section.
    while True:
        rawLine = handle.readline()
        if not rawLine:
            raise ValueError("end of file reached before "
                             "'CONSENSUS SEQUENCES'")
        if rawLine.strip() == 'CONSENSUS SEQUENCES':
            break

    # Stage 3: one record per consensus sequence.
    recordsList = []
    for seq in SeqIO.parse(handle, 'fasta'):
        seq = cleanSeq(seq)
        contigMembers = list(contigs[seq.description])
        sequenceName = '|'.join(contigMembers)
        recordsList.append([len(seq), SeqRecord(seq.seq, id=sequenceName),
                            contigMembers])
    return recordsList

