#!/usr/bin/env python

# Copyright 2009, 2010 Yann Surget-Groba
# This file is part of the STM pipeline.
#
# The STM pipeline is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The STM pipeline is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this file. If not, see <http://www.gnu.org/licenses/>.

import getopt
import os
import sys

from Bio.Blast import NCBIXML
from Bio import SeqIO

import seqUtils


def usage():
    """Write the mapBlastHit.py command-line help text to standard output.

    Lists the required options (-i, -o, -s) and the optional ones
    (-d, -c, -v) with their default values.
    """
    helpLines = [
        "Usage: mapBlastHit.py -i infile -o outfile -s seqfile [-d idCutoff -c coverageCutoff -v]",
        "    -i infile: blast report file in xml format",
        "    -o outfile: will be overwritten if exists",
        "    -s seqfile: fasta file with contig sequences",
        "    optional parameters:",
        "    -d idCutoff (integer): minimum percent identity between the query and the subject sequence [default = 60]",
        "    -c coverageCutoff (integer): minimum percentage of the query sequence covered by the hit [default = 90]",
        "    -v: optional flag to print run informations",
    ]
    print("\n".join(helpLines))
    
# Default cutoffs in percent; overridable with -d / -c.
idCutoff  = 60
covCutoff = 90
# Required file arguments start unset (False) so the check below can detect
# a missing -i, -o or -s.
infile = mapfile = seqfile = False
verbose = False

try:
    opts, args = getopt.getopt(sys.argv[1:], "i:o:d:c:s:v")
except getopt.GetoptError:
    usage()
    sys.exit(1)

for opt, value in opts:
    if opt == '-i':
        infile = value
    elif opt == '-o':
        mapfile = value
    elif opt == '-s':
        seqfile = value
    elif opt == '-v':
        verbose = True
    elif opt in ('-d', '-c'):
        # Cutoffs must be integers; report a bad value with the usage text
        # instead of crashing with a ValueError traceback.
        try:
            cutoff = int(value)
        except ValueError:
            print("Option %s expects an integer, got '%s'" % (opt, value))
            usage()
            sys.exit(1)
        if opt == '-d':
            idCutoff = cutoff
        else:
            covCutoff = cutoff

# -s is required as well: the sequence file is opened unconditionally and is
# needed to write the orphan sequences.
if not (infile and mapfile and seqfile):
    usage()
    sys.exit(1)

# Open infile (BLAST report in XML format):
try:
    blastHandle = open(infile)
except IOError as err:
    print("Could not open file %s: %s" % (infile, err))
    sys.exit(1)

# Open seqfile (FASTA file with the contig sequences):
try:
    seqHandle = open(seqfile)
except IOError as err:
    print("Could not open file %s: %s" % (seqfile, err))
    sys.exit(1)

# Create output files. The orphans file is placed next to mapfile and named
# after it, with mapfile's extension replaced by '_orphans.fa'. splitext only
# strips the final extension of the basename, so dots in directory names are
# left untouched.
orphansFile = os.path.splitext(mapfile)[0] + '_orphans.fa'
try:
    mapHandle = open(mapfile, 'w')
except IOError as err:
    print("Could not open file %s: %s" % (mapfile, err))
    sys.exit(1)
try:
    orphansHandle = open(orphansFile, 'w')
except IOError as err:
    print("Could not open file %s: %s" % (orphansFile, err))
    sys.exit(1)


### Build an index of the contig sequence file; the index is used later to
### pull out the sequences of contigs that got no BLAST hit.
if verbose:
    print("Indexing sequence file")

seqDic = seqUtils.indexSeqFile(seqHandle)

if verbose:
    foundMsg = "Found %i sequences in file %s\n" % (len(seqDic), seqfile)
    print(foundMsg)


### Parse BLAST records: for each query, take the first HSP of its first
### alignment (treated as the best hit) and keep it if it passes both cutoffs.
if verbose:
    print("Parsing BLAST records")
n = 0
hits = []     # query names whose best hit passed the cutoffs
mapDic = {}   # alignment title -> [[mapStart, mapEnd, readName, strand], ...]
for record in NCBIXML.parse(blastHandle):
    n += 1
    if verbose and not n % 1000:
        print(n)

    # Queries without any alignment are skipped; they end up as orphans.
    if not record.alignments:
        continue

    # Get data from the blast record.
    readName = record.query
    readLength = record.query_length
    mapFeature = record.alignments[0].title
    bestHit = record.alignments[0].hsps[0]
    frame = bestHit.frame
    queryStart = bestHit.query_start
    subjectStart = bestHit.sbjct_start
    matchLength = bestHit.align_length

    # Percent identity over the aligned columns, and percent of the query
    # covered by the alignment. The factor 3 treats align_length as residues
    # of a translated query (blastx-style search).
    # NOTE(review): confirm the BLAST program used matches this assumption.
    identity = float(bestHit.identities) / matchLength * 100
    coverage = float(matchLength * 3) / readLength * 100

    # Reject hits with identity and/or coverage below the cutoff values.
    if identity < idCutoff or coverage < covCutoff:
        continue

    # The hit passed the cutoffs; compute its map position.
    hits.append(readName)
    # Strand is decided afresh for every hit. A missing or zero query frame is
    # treated as '+', so `strand` is never unset or stale from an earlier record.
    if frame and frame[0] < 0:
        strand = '-'
        # Move queryStart to the other end of the match (3 query nucleotides
        # per aligned residue) before handing it to seqUtils.mapBlastHit.
        queryStart = queryStart + matchLength * 3 - 1
    else:
        strand = '+'

    mapPosition = seqUtils.mapBlastHit(queryStart, subjectStart, strand, readLength)
    mapDic.setdefault(mapFeature, []).append(
        [mapPosition, mapPosition + readLength, readName, strand])
    
    
### Write the map file: for every feature, its placements sorted (in place) by
### start position, one tab-separated line each:
### feature, start, end, read name, strand.
if verbose:
    print("Writing map file")
for featureName, placements in mapDic.items():
    placements.sort()
    for start, end, name, hitStrand in placements:
        mapHandle.write("%s\t%i\t%i\t%s\t%s\n" % (featureName, start, end, name, hitStrand))


### Collect the indexed contigs that have no retained BLAST hit ("orphans")
### and write their sequences in FASTA format.
# NOTE(review): this assumes the BLAST query names (record.query) match the
# keys produced by seqUtils.indexSeqFile -- confirm against seqUtils.
misses = set(seqDic.keys()) - set(hits)
if verbose:
    print("Found %i records in query, %i have hits, making %i orphans"
          % (len(seqDic), len(hits), len(misses)))

    print("Get sequences of orphans")
# Sorted so the orphans file has the same record order on every run (set
# iteration order is arbitrary).
for seqName in sorted(misses):
    seq = seqUtils.getSeqByPosition(seqHandle, seqDic[seqName])
    orphansHandle.write(seq.format('fasta'))

seqHandle.close()
mapHandle.close()
blastHandle.close()
orphansHandle.close()
