#!/usr/bin/env python

# Copyright 2009, 2010 Yann Surget-Groba
# This file is part of the STM pipeline.
#
# The STM pipeline is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The STM pipeline is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this file. If not, see <http://www.gnu.org/licenses/>.


import sys, getopt, seqUtils, pp, os, Bio, random, time
from Bio.SeqRecord import SeqRecord

def usage():
    """Print the command-line help of map2contigs.py to standard output."""
    # A parenthesised single string prints exactly the same text with the
    # Python 2 print statement used elsewhere in this script.
    print("""map2contigs.py -i input map file -s input fasta -o output fasta -r reference [-j nJobs -n n local CPUs -h remote_hosts,]
    -i input map file: output file from mapBlastHit.py
    -s input fasta: fasta file with contig sequences
    -o output fasta
    -r reference: reference name, written into the id of every output contig
    optional parameters for parallel processing:
    -j number of jobs
    -n number of local CPU to use
    -h name of remote ppservers (comma separated list if several)""")

###Get command line arguments
try:
    opts, args = getopt.getopt(sys.argv[1:], "i:o:r:s:n:j:h:")
except getopt.GetoptError:
    usage()
    sys.exit(2)

# Defaults: one job on one local CPU, no remote ppservers.
nProc = 1
nJobs = 1
hosts = ()
# rmDupes is not read anywhere in this file; the name is kept so the set of
# module-level variables stays the same.
infile = outfile = ref = seqFile = rmDupes = False

try:
    for o, a in opts:
        if o == '-i':
            infile = a
        elif o == '-o':
            outfile = a
        elif o == '-r':
            ref = a
        elif o == '-s':
            seqFile = a
        elif o == '-n':
            nProc = int(a)
        elif o == '-j':
            nJobs = int(a)
        elif o == '-h':
            # Comma separated list of remote ppserver names.
            hosts = tuple(a.split(','))
except ValueError:
    # -n or -j was given something that is not an integer: report it as a
    # usage error instead of dying with a traceback.
    usage()
    sys.exit(2)

# A job count below 1 cannot be scheduled.
if nJobs < 1:
    usage()
    sys.exit(2)

# The four file/name options are mandatory.
if not (infile and outfile and ref and seqFile):
    usage()
    sys.exit(1)

def bed2contigs(contigDic, seqDic, ref, jobNumber, seqFile, nJobs):
    """Build one supercontig per reference contig from its mapped hits.

    For each entry of contigDic:
      * a contig with exactly one hit is not assembled; the sequence of that
        hit is appended to the orphans list;
      * otherwise the hit sequences are fetched ('-' strand hits are reverse
        complemented), every group of overlapping hits is merged into a
        consensus (or re-assembled with the external program `cap` when the
        consensus score is 1% of its length or more), and the resulting
        fragments are laid out by start position and written as one FASTA
        record to the per-job output file.

    Parameters:
      contigDic -- dict: contig name -> list of hits. Each hit must provide
                   .name, .start, .end and .strand. Overlaps are detected by
                   comparing each start with the end of the previous hit, so
                   the hits are presumably sorted by start position
                   (TODO confirm against the producer of the map file).
      seqDic    -- dict: read name -> value handed to
                   seqUtils.getSeqByPosition to locate the read in seqFile.
      ref       -- reference name embedded in the id of every output record.
      jobNumber -- integer used in the per-job file names and record ids.
      seqFile   -- path of the file holding the read sequences.
      nJobs     -- total number of jobs; progress is printed to standard
                   output only when it equals 1.

    Returns:
      (orphans, rejectedContigs). orphans holds the records of single-hit
      contigs plus CAP contigs built from a single read; rejectedContigs
      holds CAP contigs built from several reads. CAP contigs are only
      diverted to these lists when CAP did not return exactly one contig.

    Side effects:
      writes "outfile_<jobNumber>" and "logfile_<jobNumber>", appends to
      "runtime.log", and may run `cap` on temporary ".<n>.fa"/".<n>.out"
      files created in the current directory.

    The script also submits this function to pp workers together with a
    fixed tuple of module names, so the body only relies on those modules
    and keeps all of its logic inline.
    """

    outHandle = open("outfile_%i" % jobNumber, 'w')
    logHandle = open("logfile_%i" % jobNumber, 'w')
    seqHandle = open(seqFile)
    runHandle = open("runtime.log", 'a')

    orphans = []
    rejectedContigs = []
    nHits = 1       # 1-based position of the current contig, for progress messages
    nContigs = 1    # 1-based number of the next supercontig written by this job
    alphabet = Bio.Alphabet.Gapped(Bio.Alphabet.IUPAC.unambiguous_dna)

    for contig in contigDic:
        progress = "Job %i: contig %s (%i out of %i)" % (jobNumber, contig, nHits, len(contigDic))
        if nJobs == 1:
            print(progress)
        runHandle.write(progress + "\n")
        runHandle.flush()

        #Only one hit in contig: put hit in orphans
        if len(contigDic[contig]) == 1:
            readName = contigDic[contig][0].name
            seq = seqUtils.getSeqByPosition(seqHandle, seqDic[readName])
            orphans.append(seq)
            nHits += 1

        #More than one hit: build supercontig
        else:
            end = None
            overlap = False
            seqList = []
            #Make a list of all sequences in contig
            for hit in contigDic[contig]:
                seq = seqUtils.getSeqByPosition(seqHandle, seqDic[hit.name])
                if hit.strand == '-':
                    seq.seq = seq.seq.reverse_complement()
                # For the first hit end is still None, which Python 2 orders
                # below every other value, so the test is False.
                if hit.start < end:
                    overlap = True
                end = hit.end
                seqList.append([hit.start, hit.end, seq])

            ###Presence of overlapping fragments
            if overlap:
                end = None
                overlapGroups = []

                #Get groups of overlapping fragments
                for sequence in seqList:
                    start = sequence[0]
                    #No overlap with previous sequence, create new overlap group
                    #(always the case for the first sequence, end being None):
                    if start >= end:
                        overlapGroups.append([sequence])
                    #overlap with previous sequence, append sequence to previous overlap group
                    else:
                        overlapGroups[-1].append(sequence)
                    #update end with last sequence
                    end = sequence[1]

                seqList = []

                for listOverlappingReads in overlapGroups:
                    #if there's only one sequence in current group, append to seqList
                    if len(listOverlappingReads) == 1:
                        seqList.append(listOverlappingReads[0])

                    #else compute consensus sequence from overlapping fragments
                    else:
                        ###Align overlapping fragments using mapping position:
                        #each read is left-padded with gaps by its distance
                        #from the start of the first read of the group
                        align = Bio.Align.Generic.Alignment(alphabet)
                        offset = 0
                        previousStart = listOverlappingReads[0][0]
                        for sequence in listOverlappingReads:
                            currentStart = sequence[0]
                            offset += currentStart - previousStart
                            previousStart = currentStart
                            seq = '-'*offset + str(sequence[2].seq)
                            align.add_sequence(sequence[2].description, seq)
                        consensus, score = seqUtils.consensus(align,0,0.75)

                        ###If less than 1% mismatches in consensus, accept consensus sequence
                        if float(score)/len(consensus) < 0.01:
                            #Build seqrecord from consensus:
                            listReads = [read[2].id for read in listOverlappingReads]
                            seqRecord = Bio.SeqRecord.SeqRecord(Bio.Seq.Seq(consensus), description='|'.join(listReads))

                            #Get position of consensus sequence and add it to seqList:
                            start = listOverlappingReads[0][0]
                            end = start + len(consensus)
                            seqList.append([start, end, seqRecord])

                        ###Else assemble fragments using CAP and reject non-matching fragments
                        else:
                            logHandle.write("contig %s: Ambiguous assembly, using CAP\n" % contig)
                            logHandle.flush()
                            runHandle.write("\tAmbiguous assembly, using CAP\n")
                            runHandle.flush()

                            #Removing duplicate sequences (first occurrence is kept)
                            seenSeqs = set()
                            uniqueReads = []
                            for read in listOverlappingReads:
                                readSeq = str(read[2].seq)
                                if readSeq not in seenSeqs:
                                    seenSeqs.add(readSeq)
                                    uniqueReads.append(read)
                            listOverlappingReads = uniqueReads

                            #make temporary input file for CAP
                            #NOTE: the names are only checked, not reserved, so
                            #two jobs sharing a directory could in principle
                            #still draw the same number.
                            while True:
                                tmpName = int(random.random() * 1000000)
                                fastaName = '.%i.fa' % tmpName
                                outName   = '.%i.out' % tmpName
                                if not (os.path.exists(fastaName) or os.path.exists(outName)):
                                    break

                            #reads are handed to CAP under their position in
                            #listOverlappingReads; dicReads maps that temporary
                            #name back to the real description
                            dicReads = {}
                            try:
                                tmpHandle = open(fastaName,'w')
                                try:
                                    for posRead, read in enumerate(listOverlappingReads):
                                        dicReads[str(posRead)] = read[2].description
                                        tmpHandle.write('>%i\n%s\n'%(posRead, read[2].seq))
                                finally:
                                    tmpHandle.close()

                                #run CAP and get contigs of assembled fragments
                                failed = os.system('cap %s > %s' % (fastaName, outName))
                                if failed:
                                    #NOTE(review): the reads of this group then end
                                    #up in neither seqList, orphans nor
                                    #rejectedContigs - confirm this is intended.
                                    contigs = []
                                    logHandle.write('CAP assembly failed !\n')
                                else:
                                    tmpHandle = open(outName)
                                    try:
                                        contigs = seqUtils.parseCapRes(tmpHandle)
                                    finally:
                                        tmpHandle.close()
                            finally:
                                #always clean up, also when CAP failed or its
                                #output could not be parsed
                                for tmpFile in (fastaName, outName):
                                    if os.path.exists(tmpFile):
                                        os.remove(tmpFile)

                            #Replace temporary fragment names by real names:
                            #(each assembly is used as: [0] length, [1] record,
                            #[2] list of temporary read names)
                            for assembly in contigs:
                                seqRecord = assembly[1]
                                tmpListReads = assembly[2]
                                listReads = [dicReads[readId] for readId in tmpListReads]
                                seqRecord.id = '|'.join(listReads)

                            #if only one contig reconstructed add it to seqList
                            if len(contigs) == 1:
                                assembly = contigs[0]
                                start = listOverlappingReads[int(assembly[2][0])][0]
                                end = start + assembly[0]
                                seqList.append([start, end, assembly[1]])

                            #else remove contigs from scaffold and add them to orphans
                            else:
                                for assembly in contigs:
                                    seqRecord = assembly[1]
                                    seqRecord.description = ''
                                    #keep contigs made from several sequences
                                    if len(assembly[2]) > 1:
                                        rejectedContigs.append(seqRecord)
                                    #put the other ones in orphans
                                    else:
                                        orphans.append(seqRecord)
                            logHandle.write('\n')
                            logHandle.flush()

            ###join non-overlapping fragments
            if len(seqList) == 0:
                #nothing left to write; still count the contig so the
                #progress messages stay in step
                nHits += 1
                continue

            seqList.sort()
            align = Bio.Align.Generic.Alignment(alphabet)
            offset = 0
            previousStart = seqList[0][0]
            previousEnd = previousStart
            for sequence in seqList:
                currentStart = sequence[0]
                currentEnd = sequence[1]
                #If overlap between "non-overlapping" fragments, shift fragments to remove overlap
                #(the overlap is rounded up to the next multiple of 3, plus 3 more columns)
                shift = currentStart - previousEnd
                if shift < 0:
                    offset += abs(shift) + 3 - (abs(shift)%3) + 3
                    logHandle.write('Contig %s: shift problem\n\n' % contig)

                offset += currentStart - previousStart
                previousStart = currentStart
                previousEnd = currentEnd
                seq = '-'*offset + str(sequence[2].seq)
                align.add_sequence(sequence[2].description, seq)
            consensus, score = seqUtils.consensus(align,0,0.75)

            ###Write contig to file
            contigName = contig
            seqId = "contig_%i_%06d|%s_%s|length=%i" \
                            % (jobNumber,nContigs,ref,contigName,len(consensus))
            seqRecord = Bio.SeqRecord.SeqRecord(Bio.Seq.Seq(consensus), id= seqId, description='')
            outHandle.write(seqRecord.format('fasta'))
            nContigs += 1
            nHits+=1
    outHandle.close()
    seqHandle.close()
    logHandle.close()
    runHandle.close()
    return (orphans,rejectedContigs)

runHandle = open('runtime.log','w')
date = time.localtime()
date = "%i-%i-%i %i:%i:%i" % (date[0],date[1],date[2],date[3],date[4],date[5])
s = '*' * len(date)
runHandle.write("%s\n%s\n%s\n\n"%(s,date,s))
runHandle.close()

###indexing of sequence file:
print "Indexing sequence file"
try:
    seqHandle = open(seqFile)
except IOError, strerror:
    print "Couldn't open seqFile %s: %s" % (seqFile, strerror)
    sys.exit(1)
seqDic = seqUtils.indexSeqFile(seqHandle)
print "Found %i sequences in file %s\n"%(len(seqDic),seqFile)
seqHandle.close()


###Parse BED file
print "Building contigs list\n"
contigDic = {}
try:
    inHandle = open(infile)
except IOError, strerror:
    print "Couldn't open infile %s: %s" % (infile, strerror)
    
for line in inHandle:
    line = line.strip().split('\t')
    try:
        contigDic[line[0]].append(seqUtils.Hit(name=line[3],start=line[1],end=line[2],strand=line[4]))
    except:
        contigDic[line[0]] = [seqUtils.Hit(name=line[3],start=line[1],end=line[2],strand=line[4])]
inHandle.close()

startTime = time.time()
###Split list of contigs according to number of processors
if nJobs > 1:
    lenContigDic = len(contigDic)
    listContigDic = []
    newDic = {}
    if lenContigDic % nJobs:
        n = lenContigDic / nJobs + 1
    else:
        n = lenContigDic / nJobs
    for contig in contigDic:
        newDic[contig] = contigDic[contig]
        if len(newDic) == n:
            listContigDic.append(newDic)
            newDic = {}
    if len(newDic) != 0:
        listContigDic.append(newDic)


if nJobs == 1:
    #Only one job, run it directly
    print "Assembling contigs (this can take a while)..."
    orphans,rejectedContigs = bed2contigs(contigDic, seqDic, ref, 0, seqFile, nJobs)
    
else:
    #Start job server and define arguments
    job_server = pp.Server(ncpus=nProc, ppservers=hosts)
    modules = ('Bio', 'Bio.SeqIO', 'seqUtils', 'os', 'random')
    print "Starting job server with %i local cpus and %i remote hosts" % (job_server.get_ncpus(), len(hosts))
    
    #Run jobs
    print "Assembling contigs (this can take a while)..."
    jobs=[]
    for jobID, contigDic in enumerate(listContigDic):
        #jobID = listContigDic.index(contigDic)
        args=(contigDic, seqDic, ref, jobID, seqFile, nJobs)
        print "submitting job %i" % jobID
        jobs.append(job_server.submit(bed2contigs, args=args, modules=modules))
    print "Jobs submitted"    
    
    orphans = []
    rejectedContigs = []
    for job in jobs:
        result = job()
        orphans.extend(result[0])
        rejectedContigs.extend(result[1])
    
    print
    job_server.print_stats()

print "\tAssembly finished in %d seconds\n" % (time.time() - startTime)

###Create file of unassembled sequences
print 'Get unassembled fragments'
orphansFile = outfile.split('.')[0] + '_orphans.fa'
orphansHandle = open(orphansFile, 'w')
Bio.SeqIO.write(orphans, orphansHandle, 'fasta')
orphansHandle.close()

rejectedFile = outfile.split('.')[0] + '_rejected.fa'
rejectedHandle = open(rejectedFile,'w')
Bio.SeqIO.write(rejectedContigs, rejectedHandle, 'fasta')
rejectedHandle.close()

#Remove previous output files if exist:
for outputFile in ['log.txt', outfile]:
    if outputFile in os.listdir('.'):
        os.remove(outputFile)
        
#Write output files
for jobNumber in range(nJobs):
    os.system('cat logfile_%i >> log.txt && rm -f logfile_%i' % (jobNumber, jobNumber))
    os.system('cat outfile_%i >> %s && rm -f outfile_%i' % (jobNumber, outfile, jobNumber))

runHandle = open('runtime.log','a')
runHandle.write("\nAssembly finished in %d seconds\n" % (time.time() - startTime))
runHandle.close()
