import pdb

import numpy as np
import pylab as PL
import scipy as SP

def average_y(x, y, z, s2, dist=15000):
    """Infer posterior cluster means mu(y) by a windowed, responsibility-weighted average.

    For each probe l, every probe whose location lies strictly within
    ``dist`` of ``x[l]`` contributes its value ``y``, weighted by its
    responsibility for each cluster.

    @param x length-L location of probe
    @param y length-L probe value
    @param z LxC probe-to-cluster responsibilities (averaging weights)
    @param s2 length-C cluster variance; only its length (C) is used here
    @param dist half-width of the averaging window, in x units
    @return LxC array: entry [l, c] is the z[:, c]-weighted mean of y over the
        window around x[l], or 0 where that window has zero total weight for c
    """
    x = np.asarray(x)
    y = np.asarray(y, dtype=float)
    z = np.asarray(z, dtype=float)
    res = np.zeros([len(x), len(s2)])

    for l in range(len(x)):  # O(L^2) overall; x does not need to be sorted
        # Nearby probes; includes probe l itself whenever dist > 0.
        I = np.where(np.abs(x[l] - x) < dist)[0]
        weights = z[I]
        n = weights.sum(axis=0)  # effective number of probes per cluster
        total = (y[I] * weights.T).sum(axis=1)
        # A cluster with zero total weight would give 0/0 = NaN here, and a
        # single NaN mean propagates into every later responsibility/variance
        # update; report 0 for such clusters instead.
        res[l] = np.divide(total, n, out=np.zeros_like(total), where=n != 0)
    return res


def update_z(mu, s2, y):
    """Infer cluster responsibilities from cluster means, variances and data.

    Each probe gets a posterior probability for each Gaussian cluster. No
    mixing weights are applied, so clusters are compared by likelihood alone.

    @param mu LxC cluster means (one mean per probe and cluster)
    @param s2 length-C cluster variance
    @param y length-L data
    @return LxC responsibilities; each row sums to 1
    """
    mu = np.asarray(mu, dtype=float)
    s2 = np.asarray(s2, dtype=float)
    y = np.asarray(y, dtype=float)
    # Gaussian log-likelihood of every probe under every cluster (L x C); the
    # -0.5*log(2*pi) constant is omitted because it cancels on normalisation.
    loglik = -0.5 * np.log(s2) - 0.5 * (y[:, None] - mu) ** 2 / s2
    # Shift each row so its best cluster has log-weight 0: exp() then cannot
    # overflow and at least one entry per row is exactly 1.
    loglik -= loglik.max(axis=1, keepdims=True)
    weights = np.exp(loglik)
    return weights / weights.sum(axis=1, keepdims=True)


def update_s2(mu, y, Z):
    """Infer cluster variances from means, data and responsibilities.

    Each variance is the responsibility-weighted sum of squared residuals,
    plus a prior variance ``priors2[i]`` counted as ``priorN[i]``
    pseudo-observations, divided by the total weight plus that pseudo-count.
    Cluster 2 has a zero pseudo-count, so its estimate is the plain weighted
    mean squared residual. At most three clusters are supported; more raise
    IndexError.

    @param mu LxC cluster means
    @param y length-L data
    @param Z LxC cluster responsibilities
    @return length-C variance estimates
    """
    mu = np.asarray(mu, dtype=float)
    y = np.asarray(y, dtype=float)
    Z = np.asarray(Z, dtype=float)
    N = len(y)
    # Prior parameters for an inverse gamma distribution: prior variance and
    # its pseudo-count per cluster (the counts scale with the data size N).
    priors2 = [0.3**2, 0.1**2, 0]
    priorN = [N / 10., N / 3., 0]

    res = np.zeros(mu.shape[1])
    for i in range(len(res)):  # for each cluster
        w = Z[:, i]
        weighted_ss = (w * (y - mu[:, i]) ** 2).sum()
        res[i] = (priors2[i] * priorN[i] + weighted_ss) / (w.sum() + priorN[i])
    return res


def classify_probes(loc, val, plot=False, threshold=0.9, n_iter=20):
    """Classify probes into three clusters by alternating parameter updates.

    Clusters 0 and 1 have their mean pinned to 0 and differ only in variance.
    Cluster 2 has a location-dependent mean obtained by windowed averaging
    (average_y). Starting from uniform responsibilities, each iteration
    updates responsibilities (update_z), then means (average_y), then
    variances (update_s2).

    @param loc length-L probe locations
    @param val length-L probe values
    @param plot if true, plot the probes whose responsibility exceeds
        threshold (per cluster), the cluster-2 mean and a zero line, then
        call PL.show()
    @param threshold responsibility cutoff, used only for plotting
    @param n_iter number of update iterations
    @return (mu, s2, Z, post): LxC means, length-C variances,
        LxC responsibilities, and post, which is always None
    """
    x, y, N = np.array(loc), np.array(val), len(loc)
    Z = np.ones([N, 3]) / 3.
    s2 = np.array([0.2**2, 0.1**2, 0.1**2])
    mu = average_y(x, y, Z, s2)  # initial means from uniform responsibilities
    mu[:, 0:2] = 0  # clusters 0 and 1 are pinned to mean 0

    for _ in range(n_iter):
        Z = update_z(mu, s2, y)  # update responsibilities
        mu = average_y(x, y, Z, s2)  # update means
        mu[:, 0:2] = 0
        s2 = update_s2(mu, y, Z)  # update variances

    # No posterior is computed; the slot keeps the 4-tuple return shape.
    post = None
    if plot:
        for c in range(3):
            I = np.where(Z[:, c] > threshold)[0]
            PL.plot(x[I], y[I], ".", markersize=12, alpha=0.1 + 0.3 * (c == 2))
        PL.plot(x, mu[:, 2], linewidth=2)
        PL.plot([min(x), max(x)], [0, 0], 'k--', linewidth=1)
        PL.show()

    return mu, s2, Z, post


def infer_mean(loc, val, threshold=0.9, n_iter=20, return_all=False, savefile=None):
    """Estimate the location-dependent mean of cluster 2 for a set of probes.

    Runs classify_probes without plotting and extracts the cluster-2 column.

    @param loc length-L probe locations
    @param val length-L probe values
    @param threshold forwarded to classify_probes
    @param n_iter number of update iterations
    @param return_all if true, also return the cluster-2 variance and the
        full LxC responsibilities
    @param savefile unused; accepted for interface compatibility
    @return mu[:, 2], or (mu[:, 2], s2[2], Z) when return_all is true
    """
    means, variances, resp, _ = classify_probes(loc, val, threshold=threshold, n_iter=n_iter)
    cluster_mean = means[:, 2]
    if not return_all:
        return cluster_mean
    return cluster_mean, variances[2], resp

