#!/usr/bin/env python3

import sys
import os
import gzip
import genutils3
import subprocess

###############################################################################
# define criteria for selecting reads to extract for assembly
def to_extract(samRec, *, minSoftClip=100, minInsertion=100):
    """Decide whether an alignment marks its read for extraction.

    Args:
        samRec: parsed SAM record (as produced by genutils3.parse_sam_line).
            Reads keys 'unMapped', 'mapQ', 'notPrimaryAlignment',
            'isSupplementaryAlignment', 'cigarCounts' (only ['S']) and
            'cigarExpand' (entries indexed as [length, op]).
        minSoftClip: minimum value of cigarCounts['S'] (soft-clipped bases;
            presumably summed over both read ends -- depends on genutils3)
            for the alignment to be taken.
        minInsertion: minimum length of a single CIGAR insertion ('I') for
            the alignment to be taken.

    Returns:
        True if the alignment is mapped with MAPQ > 0, is primary or
        supplementary, and is either supplementary, has at least
        ``minSoftClip`` soft-clipped bases, or contains an insertion of at
        least ``minInsertion`` bases; otherwise False.
    """
    if samRec['unMapped'] is True:
        return False

    if samRec['mapQ'] == 0:  # skip the mapq0s
        return False

    # skip secondary alignments; supplementary ones are also flagged as
    # not-primary by some tools, so they are let through here
    if samRec['notPrimaryAlignment'] is True and samRec['isSupplementaryAlignment'] is False:
        return False

    # supplementary alignments are always taken
    if samRec['isSupplementaryAlignment'] is True:
        return True

    if samRec['cigarCounts']['S'] >= minSoftClip:
        return True

    # any single large insertion also qualifies
    return any(op[1] == 'I' and op[0] >= minInsertion for op in samRec['cigarExpand'])
###############################################################################        
def get_sa_region(tagsList):
    """Return the value of the SA (supplementary alignment) tag.

    Args:
        tagsList: optional SAM tags as 'TAG:TYPE:VALUE' strings.

    Returns:
        The VALUE part of the first 'SA:' tag, e.g.
        'chr1,1000,+,500S300M,60,2;' (may hold several ';'-separated
        entries), or the string 'NOTFOUND' if no SA tag is present.
    """
    for tag in tagsList:
        if tag.startswith('SA:'):
            # split at most twice: the value may itself contain ':' when a
            # contig name does (e.g. GRCh38 HLA contigs 'HLA-A*01:01:01:01')
            return tag.split(':', 2)[2]
    return 'NOTFOUND'
###############################################################################        
def search_region_for_read(bamName,namesToDoDict,targetRegionStr):
    """Collect sequence/quality for selected reads aligned in a region.

    Runs ``samtools view bamName targetRegionStr`` and keeps every alignment
    whose read name is in ``namesToDoDict`` and that carries the full read
    sequence (no hard clips and SEQ not '*').

    Args:
        bamName: sorted, indexed BAM file.
        namesToDoDict: read names to look for (only membership is used).
        targetRegionStr: samtools region string, e.g. 'chr1:1000-1001'.

    Returns:
        dict mapping read name -> [seq, qual] in original read orientation
        (reverse-complemented / reversed for reverse-strand alignments).
        If a read has several usable alignments, the last one wins.

    Raises:
        subprocess.CalledProcessError: if samtools exits with non-zero status.
    """
    fqHits = {}
    # argv list, no shell: file names / region strings are passed verbatim
    cmd = ['samtools', 'view', bamName, targetRegionStr]
    # the context manager closes the pipe and reaps the child process
    with subprocess.Popen(cmd, text=True, stdout=subprocess.PIPE) as samPipe:
        for samLine in samPipe.stdout:
            # SAM is tab-delimited; Z-type tags may legally contain spaces
            samRec = genutils3.parse_sam_line(samLine.rstrip('\n').split('\t'))
            if samRec['seqName'] not in namesToDoDict:
                continue

            # only alignments carrying the whole read sequence are usable
            if samRec['cigarCounts']['H'] != 0 or samRec['seq'] == '*':
                continue

            seq = samRec['seq']
            qual = samRec['qual']
            if samRec['reverseStrand'] is True:
                # restore the original read orientation
                seq = genutils3.revcomp(seq)
                qual = qual[::-1]

            fqHits[samRec['seqName']] = [seq, qual]

    if samPipe.returncode != 0:
        raise subprocess.CalledProcessError(samPipe.returncode, cmd)
    return fqHits
###############################################################################        
def _samtools_view_records(bamName, regionStr):
    """Yield parsed SAM records from ``samtools view bamName regionStr``.

    Runs without a shell, always reaps the child process, and raises
    subprocess.CalledProcessError after the output has been fully consumed
    if samtools exited with a non-zero status.
    """
    cmd = ['samtools', 'view', bamName, regionStr]
    with subprocess.Popen(cmd, text=True, stdout=subprocess.PIPE) as samPipe:
        for samLine in samPipe.stdout:
            # SAM is tab-delimited; Z-type tags may legally contain spaces
            yield genutils3.parse_sam_line(samLine.rstrip('\n').split('\t'))
    if samPipe.returncode != 0:
        raise subprocess.CalledProcessError(samPipe.returncode, cmd)


def _oriented_seq_qual(samRec):
    """Return [seq, qual] in original read orientation, or None if unusable.

    None when the alignment does not carry the whole read: it is hard
    clipped, or SEQ is '*' (minimap2 sometimes stores supplementary
    alignments without SEQ to save space).
    """
    if samRec['cigarCounts']['H'] != 0 or samRec['seq'] == '*':
        return None
    seq = samRec['seq']
    qual = samRec['qual']
    if samRec['reverseStrand'] is True:
        seq = genutils3.revcomp(seq)
        qual = qual[::-1]
    return [seq, qual]


def extract_long_reads(mappedBamFileName,regionStr,outFastqFileName,debug=False):
    """Write reads selected from a BAM region to a FASTQ file.

    Pass 1 streams ``samtools view`` over ``regionStr`` and selects every
    read having an alignment accepted by ``to_extract``; if that alignment
    carries the full read, its sequence/quality are kept.
    Pass 2 re-reads the region to fill in selected reads still lacking a
    sequence: from another alignment of the read in the region if it carries
    the full read, otherwise by searching the position of the first
    alignment listed in the read's SA tag.

    Args:
        mappedBamFileName: sorted, indexed BAM file.
        regionStr: samtools region string, e.g. 'chr1:1000-2000'.
        outFastqFileName: output FASTQ path; gzip-compressed if it ends
            in '.gz'.
        debug: print progress information to stdout.

    Sequences are written in original read orientation. If any selected
    read's sequence cannot be recovered, the missing names are listed on
    stderr and the process exits with status 1 (nothing is written).

    Raises:
        subprocess.CalledProcessError: if samtools view on ``regionStr``
            fails.
    """
    namesToDo = {}    # names of reads to take (dict used as an ordered set)
    extractData = {}  # read name -> [seq, qual]
    allReads = {}     # debug only: every read name seen in the region

    if debug:
        print('cmds is', 'samtools view %s %s ' % (mappedBamFileName, regionStr), flush=True)

    # ---- pass 1: select reads, grabbing seq/qual where this record has it
    for samRec in _samtools_view_records(mappedBamFileName, regionStr):
        if debug:
            allReads[samRec['seqName']] = 1

        if to_extract(samRec) is True:
            namesToDo[samRec['seqName']] = 1
            # a read may have several alignments in the interval; any one
            # that carries the full sequence is enough
            fq = _oriented_seq_qual(samRec)
            if fq is not None:
                extractData[samRec['seqName']] = fq

    if debug:
        print('Total reads in interval: %i' % (len(allReads)), flush=True)
        print('Total reads to extract: %i' % (len(namesToDo)), flush=True)
        nMissing = sum(1 for n in namesToDo if n not in extractData)
        print('There are %i missing seq data after pass one' % nMissing)
        print('Starting second pass to get the missing reads')

    # ---- pass 2: fill in reads selected only via clipped/seq-less records
    for samRec in _samtools_view_records(mappedBamFileName, regionStr):
        name = samRec['seqName']
        if name not in namesToDo or name in extractData:
            continue
        if debug:
            print('pass2 for', name)

        # any other alignment of the read here may carry the full sequence
        fq = _oriented_seq_qual(samRec)
        if fq is not None:
            extractData[name] = fq
            continue

        if debug:
            print('checking other regions', name)
        SAregion = get_sa_region(samRec['otherTags'])
        if debug:
            print(SAregion)
        if SAregion == 'NOTFOUND':
            continue

        # SA value: 'rname,pos,strand,CIGAR,mapQ,NM;...'; search a 1-bp
        # window at the start of the first listed alignment
        saFields = SAregion.split(',')
        targetB = int(saFields[1])
        targetRegionStr = '%s:%i-%i' % (saFields[0], targetB, targetB + 1)

        # look for every still-missing read, not only this one
        namesToDoDict = {n: 1 for n in namesToDo if n not in extractData}
        if debug:
            print('have %i names to search for' % len(namesToDoDict))

        fqRecs = search_region_for_read(mappedBamFileName, namesToDoDict, targetRegionStr)
        if debug:
            print('found %i records that were missing' % len(fqRecs))
        extractData.update(fqRecs)

        if name in extractData:
            if debug:
                print('Found what we were looking for!')
        else:
            print('Did not find %s, do not know what to do' % name, file=sys.stderr)

    missing = [n for n in namesToDo if n not in extractData]

    if debug:
        print('second pass complete!')
        print('There are %i missing seq data after pass two' % len(missing))

    if missing:
        print('Error!!!\nI could not find all the reads to extract. I do not know what to do!!', file=sys.stderr)
        for n in missing:
            print('%s is missing!' % n, file=sys.stderr)
        # non-zero status so calling pipelines notice the failure
        sys.exit(1)

    opener = gzip.open if outFastqFileName.endswith('.gz') else open
    with opener(outFastqFileName, 'wt') as outFile:
        if debug:
            print('writing to', outFastqFileName)
        for name, (seq, qual) in extractData.items():
            outFile.write('@%s\n%s\n+\n%s\n' % (name, seq, qual))

    if debug:
        print('DONE!')

###############################################################################        



#### main program
###############################################################################        

if __name__ == "__main__":
    # NOTE: optparse is kept for compatibility; argparse is the modern stdlib
    # replacement.
    from optparse import OptionParser

    USAGE = """USAGE:python extract_for_assem.py --bam <bam file> --region <region str chr:start-end> --outfq <fastq file name for reads>


    Extracts reads that do not match the reference in the given region and writes them to a FASTQ file.
    Assumes reads are single long reads aligned with minimap2.
    Used to get subsets for local assembly of variants.
    """

    parser = OptionParser(USAGE)
    parser.add_option("--bam",dest="bamFile",help="name of sorted and indexed bam file")
    parser.add_option("--region",dest="regionStr",help="string of region to extract")
    parser.add_option("--outfq",dest="outFq",help="name of output fq file. gzip if ends in .gz")
    parser.add_option('--debug',dest='isDebug', action='store_true', help = 'print debug info to screen', default = False)

    (options,args)=parser.parse_args()

    # all three options are required
    if options.bamFile is None:
        parser.error('bam file not given')
    if options.regionStr is None:
        parser.error('region not given')
    if options.outFq is None:
        parser.error('output fastq file name not given')

    extract_long_reads(options.bamFile,options.regionStr,options.outFq,options.isDebug)

    



