import sys
import cPickle
import pysam
import scipy.stats
import random
import getopt
import os.path
from Bio.Seq import Seq
import re
import argparse

## NOTE: this may require python 2.7.5...

# Read-layout constants used when parsing the R1 read.
# Forward-strand form of the spacer, kept for reference only:
#W1_SPACER = 'GAGTGATTGCTTGTGACGCCTT'   # spacer between BC1 and BC2
# The reverse complement is what parseBarcodeAndUmi() searches for in the read sequence.
W1_SPACER = 'AAGGCGTCACAAGCAATCACTC'   # RC of spacer
BC1_LEN = 8    # BC1 length
BC2_LEN = 8    # BC2 length
UMI_LEN = 6    # UMI length
MAX_W1MISMATCH = 3   # max mismatches passed to patMatch() when locating the spacer
MAX_BCMISMATCH = 1   # number of bases that can be error-corrected (Hamming distance limit in correctBC())

# Argument parsing:
def get_args():
    """Parse the command-line options of the script.

    Options (all string-valued):
      -i  input R1 fastq file (required)
      -o  output file directory (optional, defaults to None)
      -b  valid barcode list file (required)

    Returns the argparse namespace with attributes ``i``, ``o`` and ``b``.
    """
    cli = argparse.ArgumentParser()

    # one row per option: (flag, required?, metavar, help text)
    option_table = (
        ('-i', True, '<R1 file>', 'input R1 fastq file'),
        ('-o', False, '<outDir>', 'output file directory'),
        ('-b', True, '<validBcFile>', 'valid barcode list file'),
    )
    for flag, needed, meta, text in option_table:
        cli.add_argument(flag, type=str, required=needed, default=None,
                         metavar=meta, help=text)

    return cli.parse_args()

def patMatch(x, y, n):
    """Find the first approximate occurrence of pattern ``x`` in sequence ``y``.

    Parameters:
      x -- pattern to look for
      y -- query sequence to search in
      n -- maximum number of mismatching positions tolerated (substitutions
           only; insertions/deletions are not considered)

    Returns the 0-based start index of the left-most window of ``y`` that
    differs from ``x`` in at most ``n`` positions, or None if there is no such
    window (including when ``y`` is shorter than ``x``).
    """
    lx = len(x)

    # Valid start positions are 0 .. len(y)-lx inclusive; the "+ 1" makes sure
    # the last window (pattern flush with the end of the sequence) is tested.
    for start in range(len(y) - lx + 1):
        nm = 0    # number of mismatches found in this window
        for j in range(lx):
            if x[j] != y[start + j]:
                nm += 1
                if nm > n:
                    break    # too many mismatches, give up on this window
        if nm <= n:
            return start     # first acceptable window wins

    return None

### Barcode error-correction:
def correctBC(bc, bcDict):
    """Error-correct an observed barcode against the valid barcodes.

    Parameters:
      bc     -- observed barcode sequence
      bcDict -- dict whose keys are the valid barcodes

    Returns the single valid barcode whose Hamming distance to ``bc`` is at
    most MAX_BCMISMATCH, or None when there is no such barcode or when more
    than one qualifies (the correction would be ambiguous).
    """
    # every valid barcode close enough to the observed one
    candidates = [known for known in bcDict.keys()
                  if hammingDist(bc, known) <= MAX_BCMISMATCH]

    # only a unique candidate is an unambiguous correction
    if len(candidates) != 1:
        return None
    return candidates[0]

def hammingDist(x, y):
    """Return the Hamming distance between two strings.

    For equal-length inputs this is the number of positions at which they
    differ. The Hamming distance is undefined for unequal lengths; in that
    case the length of the longer input is returned, i.e. the strings are
    treated as maximally different. (Returning only ``len(x)`` would make an
    empty or very short ``x`` look like a near-perfect match to any ``y``.)
    """
    if len(x) != len(y):
        return max(len(x), len(y))

    # count mismatching positions
    return sum(1 for a, b in zip(x, y) if a != b)

def writeFastqLine(oFile, fqLine):
    """Write one fastq record (four lines) to an open, writable file object.

    ``fqLine`` is expected to expose ``name``, ``comment``, ``sequence`` and
    ``quality`` attributes (e.g. a pysam fastq record). The "name comment"
    title is written on both the '@' and the '+' line.
    """
    title = '%s %s' % (fqLine.name, fqLine.comment)
    record_lines = ('@' + title, fqLine.sequence, '+' + title, fqLine.quality)
    for text in record_lines:
        oFile.write('%s\n' % text)

def _resolveBarcode(candidate, bcDict):
    """Return (barcode, corrected) for an observed barcode: an exact hit, a unique correction, or (None, 0)."""
    if candidate in bcDict:
        return candidate, 0
    fixed = correctBC(candidate, bcDict)
    if fixed is not None:
        return fixed, 1
    return None, 0


def parseBarcodeAndUmi(r1, bcDict):
    """Extract BC1, BC2 and the UMI from an R1 read.

    The read sequence is split around the W1 spacer: the BC1_LEN bases
    immediately left of the spacer are BC1, the first BC2_LEN bases right of
    it are BC2, and the following UMI_LEN bases are the UMI. Barcodes not
    found in ``bcDict`` are error-corrected with correctBC().

    Parameters:
      r1     -- read object exposing a ``sequence`` attribute
      bcDict -- dict whose keys are the valid barcodes

    Returns a list:
      [bc1, bc2, umi, w1Missing, bc1Trunc, bc1Valid, bc1Corrected,
       bc2Trunc, bc2Valid, bc2Corrected, w1Mismatch]
    where bc1/bc2/umi are strings or None and the remaining entries are 0/1
    flags. ``bc2Trunc`` is set when the sequence after the spacer is too short
    to contain the UMI. BC2 is only looked for when BC1 was found, and the UMI
    only when BC2 was found.
    """
    bc1 = None
    bc2 = None
    umi = None
    w1Missing = 0
    w1Mismatch = 0
    bc1Trunc = 0
    bc1Valid = 0
    bc2Trunc = 0
    bc2Valid = 0
    bc1Corrected = 0   # indicate if BC1 or BC2 were error-corrected
    bc2Corrected = 0

    seq = r1.sequence
    if seq is not None:
        # split to remove the spacer:
        fields = seq.split(W1_SPACER)

        # Anything other than exactly one exact spacer occurrence falls back to
        # the approximate search (this also covers reads containing the exact
        # spacer more than once, which are then flagged as W1 mismatch too).
        if len(fields) != 2:
            loc = patMatch(W1_SPACER, seq, MAX_W1MISMATCH)
            if loc is not None:
                fields = [seq[:loc], seq[loc + len(W1_SPACER):]]
                w1Mismatch = 1     # W1 mismatch
            else:
                w1Missing = 1      # W1 spacer not found

        if not w1Missing:
            left, right = fields

            ## Left barcode:
            if len(left) >= BC1_LEN:
                bc1, bc1Corrected = _resolveBarcode(left[-BC1_LEN:], bcDict)
                if bc1 is not None:
                    bc1Valid = 1
            else:
                # indicate that the left sequence was too short
                bc1Trunc = 1

            ## Right barcode:
            if bc1 is not None:
                bc2_tmp = right[:BC2_LEN]
                # a truncated BC2 can never be valid, so do not try to correct it
                if len(bc2_tmp) == BC2_LEN:
                    bc2, bc2Corrected = _resolveBarcode(bc2_tmp, bcDict)
                    if bc2 is not None:
                        bc2Valid = 1

            ## UMI:
            if bc2 is not None:
                if len(right) >= BC2_LEN + UMI_LEN:
                    umi = right[BC2_LEN:BC2_LEN + UMI_LEN]
                else:
                    bc2Trunc = 1
    # A read without a sequence yields the all-default result. Other errors are
    # deliberately not swallowed here, so programming mistakes surface instead
    # of silently turning every read into an invalid one.

    return [bc1, bc2, umi, w1Missing, bc1Trunc, bc1Valid, bc1Corrected, bc2Trunc, bc2Valid, bc2Corrected, w1Mismatch]

def fastqWrite(f, r, rName):
    """Write read ``r`` to the open file ``f`` as a fastq record named ``rName``.

    The record's own name is replaced by ``rName``; its ``comment``,
    ``sequence`` and ``quality`` attributes are written unchanged. The
    "name comment" title is repeated on the '+' line.
    """
    title = '%s %s' % (rName, r.comment)
    f.write('@' + title + '\n')          # header: read name and comment
    f.write('%s\n' % r.sequence)         # the sequence
    f.write('+' + title + '\n')          # separator line repeating the header text
    f.write('%s\n' % r.quality)          # the quality string
    return

#def run(basename, outdir):
def _loadValidBarcodes(bcFile):
    """Load the valid barcodes from a tab-delimited file into a dict (barcode -> 0)."""
    valid = {}
    with open(bcFile, 'r') as fIn:
        for line in fIn:
            # skip the header line:
            if line.startswith('well'):
                continue
            column = line.rstrip('\r\n').split('\t')[0]
            bc = str(Seq(column).reverse_complement())
            # entries of 8-11 bases are stored under their last 8 bases;
            # anything else (including blank lines) is ignored
            if 8 <= len(bc) <= 11:
                valid.setdefault(bc[-8:], 0)
    return valid


def _validR2Path(f2, outdir):
    """Build the output fastq path: R2 file name with '_valid_' before its last '_' field and '.gz' removed."""
    name = os.path.basename(f2)
    head, sep, tail = name.rpartition('_')
    if sep:
        name = head + '_valid_' + tail
    # strip ".gz" from the file name only, never from the directory part
    return os.path.join(outdir, name.replace(".gz", ""))


def run(args):
    """Tag R2 reads with the barcodes/UMI parsed from the paired R1 reads.

    ``args`` must provide:
      i -- path of the R1 fastq file (the R2 path is derived by replacing
           '_R1' with '_R2')
      o -- output directory, or None to use the directory of the R1 file
      b -- path of the valid barcode list file

    For every read pair whose R1 read yields BC1, BC2 and an N-free UMI, the
    R2 read is written to the "_valid_" fastq file under the name
    '<R2 name>:<BC1><BC2>:<UMI>'. Summary statistics are printed and the
    per-barcode read counts are written to '<R1 file name>_bcCounts.txt' in
    the output directory.

    Raises ValueError if the derived output fastq path is one of the inputs.
    """
    # local import; closing() needs nothing from the readers beyond close()
    from contextlib import closing

    basename = args.i       # forward reads file with barcode and UMIs
    outdir = args.o         # output directory
    if outdir is None:
        # if not provided, put the output file in the same directory as the input
        outdir = os.path.dirname(basename)

    # load the valid barcode dictionary:
    bcSet = _loadValidBarcodes(args.b)

    # construct the names of the input files and of the output fastq file:
    f1 = basename
    f2 = basename.replace('_R1', '_R2')
    f2out = _validR2Path(f2, outdir)
    if os.path.abspath(f2out) in (os.path.abspath(f1), os.path.abspath(f2)):
        # opening the output for writing would truncate the input
        raise ValueError('output file %s would overwrite an input fastq file' % f2out)

    rCount = 0
    bcCounts = {}  # reads written per combined barcode (BC1+BC2)
    bcValid = 0    # valid barcodes
    w1Missing_all = 0
    w1Mismatch_all = 0
    bc1Trunc_all = 0
    bc1Valid_all = 0
    bc1Corrected_all = 0
    bc2Trunc_all = 0
    bc2Valid_all = 0
    bc2Corrected_all = 0
    countMod = 100000   # progress message interval (reads)

    # the with-block closes all three files even if processing fails part-way
    with closing(pysam.Fastqfile(f1)) as fq1, \
            closing(pysam.Fastqfile(f2)) as fq2, \
            open(f2out, 'w') as fq2out:
        # loop over all reads:
        while True:
            try:
                r1 = next(fq1)     # R1 read
                r2 = next(fq2)     # R2 read
            except StopIteration:
                break      # last item
            except Exception:
                # best effort: report, stop reading and still emit the summary
                print('pysam.Fastqfile iterator error.')
                break

            rCount += 1           # read counter
            if not rCount % countMod:
                print('read %d' % rCount)

            # get the barcode from the R1 read
            [bc1, bc2, umi, w1Missing, bc1Trunc, bc1Valid, bc1Corrected,
             bc2Trunc, bc2Valid, bc2Corrected, w1Mismatch] = parseBarcodeAndUmi(r1, bcSet)

            ## update counts:
            w1Missing_all += w1Missing
            w1Mismatch_all += w1Mismatch
            bc1Trunc_all += bc1Trunc
            bc1Valid_all += bc1Valid
            bc1Corrected_all += bc1Corrected
            bc2Trunc_all += bc2Trunc
            bc2Valid_all += bc2Valid
            bc2Corrected_all += bc2Corrected

            if bc1 is None or bc2 is None or umi is None:
                continue
            # only save if UMI is valid:
            if umi.count('N') != 0:
                continue

            # new R2 read name with barcode and UMI (no ':' between BC1 and BC2)
            rName = '%s:%s%s:%s' % (r2.name, bc1, bc2, umi)
            # write the read out to the new R2 fastq file:
            fastqWrite(fq2out, r2, rName)

            ## update counts for this barcode:
            bc = '%s%s' % (bc1, bc2)
            bcValid += 1
            bcCounts[bc] = bcCounts.get(bc, 0) + 1

    def pct(count):
        # percentage of all reads; 0 when nothing was read (avoids ZeroDivisionError)
        if not rCount:
            return 0.0
        return 100.0 * count / rCount

    # print counts:
    print('Total reads: %d' % rCount)
    print('Barcodes: valid: %d, W1 missing: %d, W1 mismatch: %d, BC1 truncated: %d, BC1 valid: %d, UMI truncated: %d, BC2 valid: %d'
          % (bcValid, w1Missing_all, w1Mismatch_all,
             bc1Trunc_all, bc1Valid_all,
             bc2Trunc_all, bc2Valid_all))
    print('          BC1 corrected: %d, BC2 corrected: %d' % (bc1Corrected_all, bc2Corrected_all))
    print('Barcodes: valid: %6.3f%%, W1 missing: %6.3f%%, W1 mismatch: %6.3f%%, BC1 truncated: %6.3f%%, BC1 valid: %6.3f%%, UMI truncated: %6.3f%%, BC2 valid: %6.3f%%'
          % (pct(bcValid), pct(w1Missing_all), pct(w1Mismatch_all),
             pct(bc1Trunc_all), pct(bc1Valid_all),
             pct(bc2Trunc_all), pct(bc2Valid_all)))
    print('          BC1 corrected: %6.3f%%, BC2 corrected: %6.3f%%' % (pct(bc1Corrected_all), pct(bc2Corrected_all)))

    ## Save the barcode counts to an output file:
    outFile = os.path.join(outdir, os.path.basename('%s_bcCounts.txt' % basename))
    print('Output counts file: %s' % outFile)
    with open(outFile, 'w') as fOut:
        fOut.write('BC\tcount\n')
        for bc in bcCounts.keys():
            fOut.write('%s\t%d\n' % (bc, bcCounts[bc]))

    # NOTE: only per-barcode counts are saved; no per-read list is kept and no
    # count-threshold filtering is applied here.

    return

    
if __name__ == "__main__":
    # script entry point: parse the command line, then process the reads
    run(get_args())

    # # process input arguments with getopt:
    # if len(sys.argv)==3:
    #     run(sys.argv[1], sys.argv[2])       # default: input baseFile name, output directory
    # else:
    #     print "usage error: splitFastqByBarcode.py.py <input file basename> <output directory>"

