import re, os
import numpy as np

"""
    Derivation and Explanation:

    New Enrichment Score (in the limit of infinite background)

    The new log-likelihood function of all the (500) peaks is:
    - L = sum_i( log(n_i + l_i * beta) - log( l_b_avg*beta + n_b_avg ) )
    n_i number of sites in peak i
    l_i length of peak i
    l_b_avg average length of bg seqs
    n_b_avg average sites in bg seqs

    The derivative with respect to beta is:
    - dL/dbeta = sum_i(  l_i / (n_i + l_i*beta) -  l_b_avg / (n_b_avg + l_b_avg*beta)) 
    - dL/dbeta = sum_i(  1 / (n_i/l_i + beta) -  1 / (n_b_avg/l_b_avg + beta))

    f_i = n_i/l_i
    f = n_b_avg/l_b_avg

    - dL/dbeta = sum_i( 1 / ( f_i + beta ) - ( 1 / (f + beta) ) )

    Multiplying by (f + beta), which is positive for f >= 0 and beta > 0 and therefore does not change the sign of the derivative:
    - dL/dbeta * (f+beta) = sum_i( (f + beta) / ( f_i + beta ) - 1 )
    - dL/dbeta * (f+beta) = sum_i( (f + beta + f_i - f_i) / ( f_i + beta ) - 1 )
    - dL/dbeta * (f+beta) = sum_i( (f - f_i) / ( f_i + beta ) )

    Find the zero of this function to obtain the maximum-likelihood (ML) solution for beta:
    - dL/dbeta * (f+beta) = - sum_i( (f_i - f) / ( f_i + beta ) )

"""

def sum_of_posteriors_foreground_regions(fname):
    """
    Sum the site posteriors per foreground region and over all other sequences.

    Every non-blank line of ``fname`` is split on whitespace. The third column
    is read as the posterior and the last column as the sequence ID. IDs that
    contain ``_reg<digits>`` are treated as foreground regions; every other ID
    is counted as a background sequence.

    Args:
        fname: path of the whitespace-delimited sites file.

    Returns:
        A tuple ``(posteriors, total_bg_posterior)``: ``posteriors`` maps each
        foreground region ID to the sum of its posteriors, and
        ``total_bg_posterior`` is the summed posterior of all remaining lines.

    Raises:
        IndexError: if a non-blank line has fewer than three columns.
        ValueError: if the third column cannot be parsed as a float.
    """
    # Raw string: '\d' inside a plain literal is an invalid escape sequence.
    # Compiled once instead of being looked up for every line.
    fg_pattern = re.compile(r'_reg\d+')

    posteriors = {}
    total_bg_posterior = 0.
    with open(fname) as file_handler:
        for line in file_handler:
            row = line.split()
            if not row:
                # Blank (e.g. trailing) lines carry no site; skip them
                # instead of failing on row[2].
                continue
            posterior = float(row[2])
            seq_id = row[-1]  # split() already removed surrounding whitespace
            if fg_pattern.search(seq_id):
                posteriors[seq_id] = posteriors.get(seq_id, 0.0) + posterior
            else:  # i.e. it is a bg seq
                total_bg_posterior += posterior

    return posteriors, total_bg_posterior


def fit_beta(siteFile, interm_dir, number_of_windows):
    """
    Find the beta at which the (rescaled) likelihood derivative
    ``-sum_i((f_i - f) / (f_i + beta))`` crosses zero.

    number_of_windows is a dictionary where the region ID is the key (the same ID as in the sitesfile) and the value is the length, where:
    every sequence has plus and minus strand, thus l_i = (len_i - len_wm)*2
    Thus, the values say on how many locations a WM can sit on a sequence.
    number_of_windows contains fg and bg seqs

    Args:
        siteFile: path of the sites file, passed to
            sum_of_posteriors_foreground_regions.
        interm_dir: not used by this function; kept for interface
            compatibility.
        number_of_windows: dict mapping sequence ID -> number of windows.

    Returns:
        1.0e-12 if the derivative is not positive at that starting value,
        10.0 if the derivative is still positive when the doubling search
        passes 10.0, otherwise the bisection midpoint once successive
        midpoints differ by less than 1.0e-12 (or the derivative is exactly 0).

    Raises:
        ZeroDivisionError: if number_of_windows contains no background
            sequence, or if a window count of 0 is given for a region that
            has sites.
    """

    binding_regions_sitecount, N_bg = sum_of_posteriors_foreground_regions(siteFile)

    # since bg seqs have the same lengths as the fg peaks, I can just take the average over all seqs
    # list(...) is needed because np.mean cannot consume a Python 3 dict view.
    l_bg_avg = np.mean(list(number_of_windows.values()))

    f_i = []
    num_bg_seqs = 0
    ij = 0  # number of regions that have an entry in the sites file (diagnostics only)
    for r_i, l_i in number_of_windows.items():
        if r_i in binding_regions_sitecount:
            f_i.append(binding_regions_sitecount[r_i] / float(l_i))
            ij += 1
        elif 'reg' in r_i:  # i.e. it is a fg sequence without any site
            f_i.append(0.0)
        else:
            num_bg_seqs += 1

    f_i = np.array(f_i)

    n_bg_avg = N_bg / num_bg_seqs
    f = n_bg_avg / l_bg_avg

    b1 = f_i - f  # independent of beta, so computed once

    def deriv(beta_i):
        return -np.sum(b1 / (f_i + beta_i))

    ## find range for beta. Go up until negative:
    beta_min = 1.0e-12
    beta_max = 10.
    lower = None
    upper = beta_min
    while deriv(upper) > .0:
        lower, upper = upper, upper * 2
        # Check the cap right after doubling: once the candidate exceeds
        # beta_max the result is beta_max whatever the derivative does
        # further out, so there is no point in doubling on.
        if upper > beta_max:
            print('For beta_max = %f still no negative derivative. Using beta_max. (%s)' % (beta_max, siteFile))
            return beta_max

    # derivative function goes from positive to negative. If derivative at beta_min is already negative or 0 the maximum of the likelihood is also there.
    if lower is None:
        diagnostics = [
            'Deriv is already negative at beta_min %g. (%s)' % (beta_min, siteFile),
            ij,
            N_bg,
            np.sum(list(binding_regions_sitecount.values())),
            f,
            np.mean(f_i),
            deriv(beta_min),
        ]
        # A single pre-joined argument prints the same text under Python 2 and 3.
        print(' '.join(str(item) for item in diagnostics))
        return beta_min

    # Bisection on [lower, upper]: the derivative is positive at lower and
    # not positive at upper.
    previous_mid = None
    while True:
        mid = np.mean([lower, upper])

        d = deriv(mid)
        if d > 0:
            lower = mid
        elif d < 0:
            upper = mid
        else:
            return mid

        if previous_mid is not None and abs(mid - previous_mid) < 1.0e-12:
            return mid
        previous_mid = mid
