from motevo_stuff import *
from fitting_beta import *
from Bio import SeqIO
from concatenate_motifs import concatenate
import subprocess


def arguments():
    """
    Build the command-line parser and parse the current command line.

    Every option is optional, takes a string value and is stored under its
    own attribute of the returned namespace:
        -w / --wm           -> WM
        -t / --trainseq     -> trainSeq
        -s / --testseq      -> testSeq
        -o / --outdir       -> outdir
        -g / --genome       -> GENOME
        -m / --motevo_path  -> motevo_path
    An option that is not given is left as None.
    """
    import argparse

    # (short flag, long flag, destination attribute)
    option_table = (
        ('-w', '--wm', 'WM'),
        ('-t', '--trainseq', 'trainSeq'),
        ('-s', '--testseq', 'testSeq'),
        ('-o', '--outdir', 'outdir'),
        ('-g', '--genome', 'GENOME'),
        ('-m', '--motevo_path', 'motevo_path'),
    )

    cli = argparse.ArgumentParser(description='Fits beta and prior, and also run calculate the enrichment scores')
    for short_flag, long_flag, attribute in option_table:
        cli.add_argument(short_flag, long_flag,
                         action="store", dest=attribute, type=str)
    return cli.parse_args()


def sum_of_posteriors_foreground_regions(fname):
    """
    Sum the posteriors listed in a sites file, separately per foreground region.

    Every non-blank line is split on whitespace; the third column is read as
    the posterior and the last column as the sequence name. A name containing
    '_reg' followed by at least one digit is treated as a foreground region and
    its posteriors are accumulated under that name. Posteriors of all other
    sequences are added to a single background total.

    Blank lines (for instance a trailing empty line) are skipped instead of
    raising IndexError.

    Returns a tuple (posteriors, total_bg_posterior) where posteriors maps
    foreground region name -> summed posterior and total_bg_posterior is a
    float.
    """
    # Compiled once: the pattern is loop-invariant. The raw string avoids the
    # invalid '\d' escape of a plain string literal.
    foreground_name = re.compile(r'_reg\d+')

    posteriors = {}
    total_bg_posterior = 0.
    with open(fname) as file_handler:
        for line in file_handler:
            row = line.split()
            if not row:
                continue
            posterior = float(row[2])
            # split() already removed surrounding whitespace from the name
            region = row[-1]
            if foreground_name.search(region):
                posteriors[region] = posteriors.get(region, 0.0) + posterior
            else:
                total_bg_posterior += posterior
    return posteriors, total_bg_posterior

def calculate_enrichment_scores(siteFile, beta, length, res_filename):
    """
    siteFile = sites file of motevo on testPool
    length = dict of lengths of fg and bg seqs (testPool)
    beta = fitted on trainingPool
    res_filename = not used by this function; kept so existing callers
                   keep working

    For every fg peak i this computes:
    -> r_i = log( (n_i + l_i*beta) / (n_b + l_b*beta) )
    where:
    - n_i is the summed posterior of peak i in siteFile (0 if it has no sites)
    - n_b is the average number of sites per bg seq
    - l_b is the average length of a bg seq
    Then:
    -> enrichment_score = exp( 1/F * sum_i( r_i ) )
    -> loglikratio = sum_i( r_i )

    Returns a tuple (scores, n_b) where scores is a dict with keys
    'mean' (the enrichment score), 'std' (exp of the standard deviation of
    the r_i) and 'LL_ratio' (the sum of the r_i).

    Raises ZeroDivisionError when `length` holds no background sequence,
    because the background averages are then undefined.
    """

    sites, tot_bg_sites = sum_of_posteriors_foreground_regions(siteFile)

    n_f = [] #n_f = n_i + l_i*beta
    B = 0.   # number of background sequences
    tot_bg_length = 0.

    for r_i, l_i in length.items():
        if r_i in sites:
            n_f.append(sites[r_i] + l_i*beta)
        elif re.search('reg', r_i): # i.e. it is a fg sequence without sites
            n_f.append(l_i*beta)
        else:
            B += 1
            tot_bg_length += l_i
    n_f = np.array(n_f)

    if not B:
        # Same exception type the bare division would raise, but with an
        # explanation of what is wrong with the input.
        raise ZeroDivisionError(
            'no background sequences in the length table: cannot compute '
            'the average background site count and length')

    n_b = tot_bg_sites/B
    l_b = tot_bg_length/B

    denominator = np.log(n_b + (l_b*beta))

    ll_ratio_i = np.log(n_f) - denominator

    return {'mean':np.exp(np.mean(ll_ratio_i)),
            'std':np.exp(np.std(ll_ratio_i)),
            'LL_ratio':np.sum(ll_ratio_i)}, n_b

def cleanup(infiles):
    """
    To delete the files in the scratch directory

    infiles = iterable of file paths to remove.

    Mirrors `rm -f`: a path that does not exist is silently ignored, and
    empty / None entries are skipped. Any other failure (e.g. the path is a
    directory, or permission is denied) is reported on stdout and the
    remaining files are still processed.

    The files are removed with os.remove rather than through a shell command,
    so paths containing quotes or other shell metacharacters are handled
    safely.

    Always returns 0.
    """
    import errno
    import os

    for a_file in infiles:
        if not a_file:
            continue
        try:
            os.remove(a_file)
        except OSError as err:
            if err.errno == errno.ENOENT:
                # already gone: nothing to clean up
                continue
            print(err)
            print('Problem with these files: %s' % (infiles,))
    return 0


def fittingParameters(WM, trainingPool, trainingLength, outdir, genome, motevo_path):
    """
    Fit the prior and beta on the training pool.

    Runs run_motevo on trainingPool without a prior file (minposterior
    0.0001), reads the prior with extract_priors from the prior file it
    produced and fits beta with fit_beta from the sites file. The sites and
    params files are deleted afterwards; the prior file is kept and its path
    is handed back to the caller.

    Returns a tuple ({'prior': prior, 'beta': beta}, priorFilename).
    """
    motevo_outputs = run_motevo(WM, trainingPool, outdir, genome, motevo_path,
                                priorFile=None, minposterior=0.0001)
    # the fourth value returned by run_motevo is not needed here
    sites_path, prior_path, params_path, _unused = motevo_outputs

    fitted = {}
    fitted['prior'] = extract_priors(prior_path)
    fitted['beta'] = fit_beta(sites_path, outdir, trainingLength)

    cleanup([sites_path, params_path])
    return fitted, prior_path


def calculateEnrichmetScores(WM, testPool, testLength, params, outdir, genome, priorFile, motevo_path):
    """
    Compute the enrichment scores of the motif(s) on the test pool.

    Runs run_motevo on testPool with the given priorFile (minposterior
    0.0001) and feeds the resulting sites file, together with params['beta']
    and testLength, to calculate_enrichment_scores. The motif name is the
    basename of the fourth value returned by run_motevo.

    Afterwards the sites, prior and params files of this run are deleted, and
    so is the priorFile that was passed in.

    Returns a tuple (scores, motifName, mean_bg_sites).
    """
    sites_path, run_prior_path, params_path, wm_path = run_motevo(
        WM, testPool, outdir, genome, motevo_path,
        priorFile=priorFile, minposterior=0.0001)

    motifName = os.path.basename(wm_path)
    fitted_beta = params['beta']
    score_path = os.path.join(outdir, '%s.enrichment_score' % motifName)
    scores, mean_bg_sites = calculate_enrichment_scores(
        sites_path, fitted_beta, testLength, score_path)

    scratch_files = [sites_path, run_prior_path, params_path, priorFile]
    cleanup(scratch_files)
    return scores, motifName, mean_bg_sites


def lengthOfWM(WMfiles):
    """
    Return the length of the shortest weight matrix among WMfiles.

    WMfiles = iterable of paths to weight-matrix files.

    The length of a matrix is the number of lines in its file that start
    with an integer followed by at least two whitespace-separated numeric
    fields (the matrix rows).

    A file that cannot be opened or read is reported on stdout, together with
    its 1-based position in WMfiles, and is skipped. If no file could be
    read (or WMfiles is empty) np.inf is returned.
    """
    # Raw string: same pattern as before, without relying on invalid
    # string escapes. Compiled once because it is loop-invariant.
    matrix_row = re.compile(r'^\d+\s+[(\.)0-9]+\s+[(\.)0-9]+\s+')

    minLength = np.inf
    for i, WMfile in enumerate(WMfiles, 1):
        try:
            with open(WMfile) as inf:
                length = sum(1 for line in inf if matrix_row.search(line))
        except IOError:
            print('Couldnt find WMfile: %i' % i)
            print(WMfile)
            print(WMfiles)
            continue
        if length < minLength:
            minLength = length
    return minLength


def lengthOfSequences(trainingPool, testPool, WMfile):
    """
    Compute the effective length of every sequence in both FASTA pools.

    WMfile is passed straight to lengthOfWM, which gives the shortest matrix
    length. For each FASTA record the effective length is
    (sequence length - shortest matrix length) * 2, stored under the record
    id (with a leading '>' removed, should there be one).
    NOTE(review): the factor 2 presumably accounts for both strands - confirm.

    Returns a tuple (trainingLength, testLength) of dicts mapping
    sequence id -> effective length.
    """
    wmLength = lengthOfWM(WMfile)

    def effective_lengths(fasta_path):
        # one dict per pool, filled record by record
        table = {}
        with open(fasta_path) as handle:
            for entry in SeqIO.parse(handle, 'fasta'):
                doubled = (len(entry.seq) - wmLength)*2
                table[re.sub('^>', '', entry.id)] = doubled
        return table

    trainingLength = effective_lengths(trainingPool)
    testLength = effective_lengths(testPool)
    return trainingLength, testLength


def main():
    """
    Command-line entry point.

    Parses the arguments, fits prior and beta on the training sequences,
    computes the enrichment scores on the test sequences and writes a single
    tab-separated line to <outdir>/<motifName>.results holding: the WM
    name(s), the enrichment score mean, std and log-likelihood ratio, the
    fitted beta, the prior and the mean number of background sites.

    Prints 'No WM given.' and does nothing else when -w is missing or holds
    no WM name.
    """
    args = arguments()
    # -w may be omitted, in which case argparse leaves it as None. split()
    # without an argument also drops the empty names that repeated, leading
    # or trailing spaces would otherwise produce.
    args.WM = (args.WM or '').split()  # list of WMs, in case we run for more than one WM
    if not args.WM:
        print('No WM given.')
        return

    trainingLength, testLength = lengthOfSequences(args.trainSeq, args.testSeq, args.WM)
    params, priorFile = fittingParameters(args.WM, args.trainSeq, trainingLength,
                                          args.outdir, args.GENOME, args.motevo_path)
    enrichmentScores, motifName, mean_bg_sites = calculateEnrichmetScores(
        args.WM, args.testSeq, testLength, params,
        args.outdir, args.GENOME, priorFile, args.motevo_path)

    resFilename = os.path.join(args.outdir, motifName + '.results')
    with open(resFilename, 'w') as outf:
        outf.write('\t'.join([
                    '\t'.join(args.WM),
                    str(enrichmentScores['mean']),
                    str(enrichmentScores['std']),
                    str(enrichmentScores['LL_ratio']),
                    str(params['beta']),
                    str(params['prior']),
                    str(mean_bg_sites)
                    ]) + '\n')
    
    
    
# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
