import numpy as np
import sys
from scipy.stats.mstats import gmean
from statsmodels.nonparametric.smoothers_lowess import lowess
sys.path.append('/project/jcreminslab/dagillis_projects/5C_peak_filter/'
                'stand_alone_modules')
import mat2counts as m2c
from lib5c.parsers.counts import load_counts


def generate_offdiagonals(matrix):
    '''
    Returns all of the elements from each offdiagonal, where an
    offdiagonal d is all points (i, j) in the matrix such that (j - i) = d.
    Non-finite values (anything rejected by np.isfinite, e.g. NaN) are
    removed from each offdiagonal. If no finite value is left at a distance,
    that distance is represented by the single-element placeholder [nan] so
    that the output always has one entry per distance.

    Parameters
    ----------
    matrix : np.ndarray
        Matrix to get offdiagonals from

    Returns
    ----------
    off_diagonals : list of lists
        Each element is a list containing all finite values from one
        offdiagonal (or [nan] if there were none).
        These lists are in order of distance (j - i), starting at 0 (the
        main diagonal), with len(matrix) entries in total.
    '''
    off_diagonals = []
    for distance in range(len(matrix)):
        odiag = np.diag(matrix, distance)
        finite_values = [elem for elem in odiag if np.isfinite(elem)]
        if not finite_values:
            # use the lowercase spelling: the np.NaN alias was removed in
            # NumPy 2.0, while np.nan exists in every NumPy version
            finite_values = [np.nan]
        off_diagonals.append(finite_values)
    return off_diagonals


def regional_mean(offdiagonals, logged=True):
    '''
    Compute and return the 1D empirical mean at each distance.

    Parameters
    ----------
    offdiagonals : list of lists
        Each element is a list containing all values from each offdiagonal.
        These lists should be in order of distance (j - i).

    logged : boolean
        Pass True if the values are already logged; the arithmetic mean
        (np.mean) is then taken at each distance. Pass False to take the
        geometric mean (gmean) of the values instead.

    Returns
    ----------
    means : np.array
        Numpy array with one mean per offdiagonal, in the same order as
        the input.
    '''
    # pick the averaging function once, then apply it to every distance
    average = np.mean if logged else gmean
    return np.array([average(values) for values in offdiagonals])


def lowess_smooth_expected(offdiagonals, logged=True, insertion=None):
    '''
    Compute the 1D expected value at each distance: the empirical mean at
    short distances, joined to a lowess fit of the individual values at
    longer distances.

    Parameters
    ----------
    offdiagonals : list of lists
        Each element is a list containing all values from each offdiagonal.
        These lists should be in order of distance (j - i).

    logged : boolean
        Pass True if values are logged to calculate the arithmetic mean
        in log space. Pass False to use the geometric mean instead.

    insertion : float
        Fraction of the 1D global mean at which to begin lowess smoothing.
        The fraction is applied to the number of distances whose mean is
        finite. If None, the function will use the default values. Note that
        the default insertion points differ for logged and unlogged data
        (1/3 and 1/2, respectively). This is directly descended from the Hb
        and A pipelines.

    Returns
    ----------
    joined_fit : np.array
        Numpy array with one value per distance. Distances up to and
        including the insertion index hold the empirical mean; later
        distances hold the lowess fitted value, or the empirical mean when
        the fit has no entry for that distance.
    '''
    # record region size
    region_size = len(offdiagonals)

    # get means of each off-diagonal
    if logged:
        means = np.array([np.mean(offdiagonal)
                          for offdiagonal in offdiagonals])
    else:
        means = np.array([gmean(offdiagonal) for offdiagonal in offdiagonals])

    # determine location at which to begin lowess smoothing
    if insertion is None:
        insertion = 1/3. if logged else 0.5
    finite_mean_count = int(np.sum(np.isfinite(means)))
    key_index = int(finite_mean_count * insertion)

    # the fit is only ever used for distances strictly beyond key_index;
    # when there are none, the answer is just the means and lowess can be
    # skipped altogether
    if key_index >= region_size - 1:
        return np.array(means, dtype=float)

    # one (distance, count) pair for every interaction
    distances = np.array([distance
                          for distance, offdiagonal in enumerate(offdiagonals)
                          for _ in offdiagonal], dtype=float)
    counts = np.array([entry
                       for offdiagonal in offdiagonals
                       for entry in offdiagonal], dtype=float)

    # include data from the key index onward
    from_key_index = distances >= key_index
    if not np.any(from_key_index):
        return np.array(means, dtype=float)

    fit = lowess(counts[from_key_index], distances[from_key_index],
                 frac=0.8, it=3)

    # each row of the fit is (distance, fitted value); keep the first fitted
    # value seen for every distance beyond the key index
    smoothed = {}
    for distance, fitted in fit:
        if distance > key_index:
            smoothed.setdefault(int(distance), fitted)

    # start from the means and overwrite the distances that have a fit
    joined_fit = np.array(means, dtype=float)
    for i in range(key_index + 1, region_size):
        if i in smoothed:
            joined_fit[i] = smoothed[i]

    return joined_fit


def expected_matrix(matrix, logged=True, insertion=None):
    '''
    Compute and return the expected matrix from the lowess smoothed empirical
    regional mean given an observed matrix.

    Parameters
    ----------
    matrix : np.ndarray
        Observed matrix. Its offdiagonals are passed to
        lowess_smooth_expected to obtain one expected value per distance.

    logged : boolean
        Forwarded to lowess_smooth_expected.

    insertion : float
        Forwarded to lowess_smooth_expected.

    Returns
    ----------
    exp_mat : np.ndarray
        Matrix with the same shape as the input in which every entry (i, j)
        holds the expected value for distance abs(j - i).
    '''
    one_d_exp = lowess_smooth_expected(generate_offdiagonals(matrix),
                                       logged=logged, insertion=insertion)
    exp_mat = np.zeros(matrix.shape)
    size = len(matrix)
    for distance, exp in enumerate(one_d_exp):
        # fill the whole offdiagonal and its mirror image in one assignment
        rows = np.arange(size - distance)
        exp_mat[rows, rows + distance] = exp
        exp_mat[rows + distance, rows] = exp
    return exp_mat


def expected_klf(obs_mat, bin_step, logged=True, insertion=None,
                 ty='lowess'):
    '''
    Returns a dictionary with each Klf4 regional expected.
    bin_step should be given in base pairs (i.e. 4kb should be 4000).

    The observed matrix is split at a fixed cutoff of 472000 base pairs.
    The top-left square block of the matrix up to that cutoff is treated as
    one section, and the values of every offdiagonal that are not part of
    that block are treated as a second section. A 1D expected is computed
    for each section separately.

    Parameters
    ----------
    obs_mat : np.ndarray
        Observed matrix for the Klf4 region.

    bin_step : int
        Bin size in base pairs.

    logged : boolean
        Forwarded to lowess_smooth_expected / regional_mean.

    insertion : float
        Forwarded to lowess_smooth_expected (ignored unless ty is 'lowess').

    ty : str
        'lowess' to use lowess_smooth_expected; any other value uses
        regional_mean.

    Returns
    ----------
    klf_out : dict
        'Klf4_1' maps to the 1D expected of the block before the cutoff,
        'Klf4_2' maps to the 1D expected of the remaining values.
    '''

    # this is the location of the cutoff in base pairs
    cut_off = 472000

    # additional one is to account for the first bin (0 distance);
    # floor division keeps this an integer so it can be used as a slice
    # bound regardless of Python version or a float bin_step
    num_bins = int(cut_off // bin_step) + 1

    sub_klf4 = obs_mat[:num_bins, :num_bins]
    klf_out = {}
    sub_diagonals = generate_offdiagonals(sub_klf4)
    if ty == 'lowess':
        klf1 = lowess_smooth_expected(sub_diagonals, logged=logged,
                                      insertion=insertion)
    else:
        klf1 = regional_mean(sub_diagonals, logged=logged)
    klf_out['Klf4_1'] = klf1

    full_diagonals = generate_offdiagonals(obs_mat)

    remaining_diagonals = []
    for i, full in enumerate(full_diagonals):
        if i < len(sub_diagonals):
            # count only the finite values: an all-NaN sub-diagonal is
            # stored as the one-element placeholder [nan], which must not
            # be counted as a real observation to skip
            sub_finite = sum(1 for value in sub_diagonals[i]
                             if np.isfinite(value))
            remaining = full[sub_finite:]
        else:
            remaining = full
        if len(remaining) == 0:
            # same placeholder generate_offdiagonals uses for a distance
            # with no finite values
            remaining = [np.nan]
        remaining_diagonals.append(remaining)
    # get remaining diagonal values
    if ty == 'lowess':
        klf2 = lowess_smooth_expected(remaining_diagonals, logged=logged,
                                      insertion=insertion)
    else:
        klf2 = regional_mean(remaining_diagonals, logged=logged)

    klf_out['Klf4_2'] = klf2

    return klf_out


def expected_dictionary(obs_dic, logged=True, insertion=None, bin_step=None):
    '''
    Compute and return a dictionary of global expected matrices for each
    region.

    Parameters
    ----------
    obs_dic : dict
        Maps region name to its observed matrix.

    logged : boolean
        Forwarded to expected_matrix / expected_klf.

    insertion : float
        Forwarded to expected_matrix / expected_klf.

    bin_step : int
        Bin size in base pairs. Required (it is forwarded to expected_klf
        for the 'Klf4' region).

    Returns
    ----------
    exp_dic : dict
        Maps region name to its expected matrix. The 'Klf4' region does not
        appear under its own name; instead each section returned by
        expected_klf is expanded into a square matrix and stored under that
        section's key.

    Raises
    ----------
    Exception
        If bin_step is None.
    '''
    # create dictionary of expected matrices
    if bin_step is None:
        raise Exception('Bin_Step value must be specified')
    exp_dic = {}
    for region in obs_dic:
        if region == 'Klf4':
            k_1d_exp_dic = expected_klf(obs_dic[region], bin_step,
                                        logged=logged, insertion=insertion)
            for section, k_exp in k_1d_exp_dic.items():
                # expand the 1D expected into a symmetric matrix where
                # entry (i, j) is the expected value at distance abs(j - i)
                size = len(k_exp)
                exp_mat = np.zeros((size, size))
                for distance, exp in enumerate(k_exp):
                    rows = np.arange(size - distance)
                    exp_mat[rows, rows + distance] = exp
                    exp_mat[rows + distance, rows] = exp
                exp_dic[section] = exp_mat
        else:
            # expected_matrix already returns a fresh array with the same
            # shape as the observed matrix
            exp_dic[region] = expected_matrix(obs_dic[region],
                                              logged=logged,
                                              insertion=insertion)
    return exp_dic


def write_expected_out(obs_file, outpath, logged=True, insertion=None,
                       bin_step=None, ty='lowess'):
    '''
    Writes the expected out as a flattened expected file.
    The format of this is similar to, but not quite the same as .scaled files.
    The purpose of this format is to save space and allow easier import as
    1D expected files.

    Format of output txt file:
    # Region 1
    value_at_0_loci_distance
    value_at_1_loci_distance
    # Region 2
    etc.

    Note: Writes out the two different Klf4 distances as two "different"
    regions (Klf4_1 and Klf4_2 where Klf4_1 is the TAD)

    Parameters
    ----------
    obs_file : str
        Path of the observed counts file, read with load_counts.

    outpath : str or None
        Path to write to. If None, it is derived from obs_file by replacing
        'obs.counts' with 'regional_exp.txt'. If obs_file does not contain
        'obs.counts', '.regional_exp.txt' is appended to obs_file instead so
        that the input file is never used as the output path.

    logged : boolean
        Forwarded to the 1D expected functions.

    insertion : float
        Forwarded to lowess_smooth_expected (only used when ty is 'lowess').

    bin_step : int
        Bin size in base pairs. Required.

    ty : str
        'lowess' to use lowess_smooth_expected; any other value uses
        regional_mean.

    Raises
    ----------
    Exception
        If bin_step is None.
    '''
    if bin_step is None:
        raise Exception('Bin_Step value must be specified')

    obs_dic = load_counts(obs_file)
    exp = {}
    for region in obs_dic:
        if region == 'Klf4':
            exp.update(expected_klf(obs_dic[region], bin_step,
                                    logged=logged, insertion=insertion,
                                    ty=ty))
            continue

        offdiagonals = generate_offdiagonals(obs_dic[region])
        if ty == 'lowess':
            one_d_exp = lowess_smooth_expected(offdiagonals, logged=logged,
                                               insertion=insertion)
        else:
            one_d_exp = regional_mean(offdiagonals, logged=logged)
        exp[region] = one_d_exp

    if outpath is None:
        outpath = obs_file.replace('obs.counts', 'regional_exp.txt')
        if outpath == obs_file:
            # the replacement did nothing; writing here would overwrite the
            # observed counts file, so derive a distinct name instead
            outpath = obs_file + '.regional_exp.txt'
    m2c.flatexp2text_regional(exp, outpath)
