import numpy as np
from scipy.stats.mstats import gmean
from statsmodels.nonparametric.smoothers_lowess import lowess

import warnings
# the following imports are only needed for writing counts output files directly
# Optional dependencies: these are only needed by the functions that read or
# write counts files. If any of them is unavailable the module still loads and
# a warning is emitted instead.
try:
    import sys
    import glob
    sys.path.append('/project/jcreminslab/dagillis_projects/5C_peak_filter/'
                    'stand_alone_modules')
    import mat2counts as m2c
    from lib5c.parsers.counts import load_counts
except Exception:
    # `except Exception` (rather than a bare `except`) keeps this best-effort
    # while still letting KeyboardInterrupt / SystemExit propagate.
    warnings.warn('Import incomplete; only use the methods global_mean '
                  'and lowess_smooth_expected')


def get_flattened_values(region_mats, logged=True):
    '''
    Merges the data from one triangle (diagonal included) of each region.

    Entries ``mat[i, j]`` with ``i >= j`` are read, i.e. the lower triangle.
    For symmetric matrices this holds the same values as the upper triangle.

    Parameters
    ----------
    region_mats : dictionary of np.ndarray
        Dictionary containing the observed matrix data of each region.
        The keys are regions.

    logged : boolean
        Pass False for unlogged data to keep zeros out of the returned
        data. This is done to prevent the geometric mean from returning zero.

    Returns
    ----------
    counts_at_dist : dictionary of lists
        Dictionary containing list of all observed counts that occurred at
        each distance. Keys are the integer values of each distance. A
        distance that was visited but had no usable value maps to an empty
        list.

    data : np.ndarray
        Array with one row per retained data point.
        First column is the distance. Second column is the count.
        The shape is (0, 2) when no data point was retained.
    '''
    counts_at_dist = {}
    data = []

    # .values() works on both Python 2 and Python 3 dictionaries
    for mat in region_mats.values():
        for i in range(len(mat)):
            for j in range(i + 1):
                dist = i - j
                # register the distance even if no usable value is found
                bucket = counts_at_dist.setdefault(dist, [])
                value = mat[i, j]
                if np.isnan(value):
                    continue
                if not logged and value == 0:
                    continue
                bucket.append(value)
                data.append((dist, value))

    # reshape keeps the two-column layout even when `data` is empty
    return counts_at_dist, np.asarray(data).reshape(-1, 2)


def get_means_at_dist(counts_at_dist, logged=True):
    '''
    Compute the mean of the counts observed at each distance.

    Parameters
    ----------
    counts_at_dist : dictionary of lists
        Dictionary containing list of all observed counts that occurred at
        each distance. Keys are the integer values of each distance.

    logged : boolean
        Pass True if values are logged to calculate the arithmetic mean
        in log space. Pass False to calculate the geometric mean instead.

    Returns
    ----------
    means : dictionary of floats
        Dictionary containing the mean at each distance.
        Keys are integer distances. A distance whose list is empty, or
        contains only zeros, is assigned 0.

    '''
    means = {}
    for dist, values in counts_at_dist.items():
        # the check must look at the values of *this* distance, not at the
        # dictionary as a whole
        if len(values) and any(values):
            if logged:
                means[dist] = np.mean(values)
            else:
                means[dist] = gmean(values)
        else:
            means[dist] = 0
    return means


def global_mean(region_mats, logged=True):
    '''
    Compute and return the 1D empirical global mean.

    Parameters
    ----------
    region_mats : dictionary of np.ndarray
        Dictionary containing the observed matrix data of each region.
        The keys are regions.

    logged : boolean
        Pass True if values are logged to calculate the arithmetic mean
        in log space.

    Returns
    ----------
    means : np.array
        Numpy array of the mean at each distance, ordered by increasing
        distance.
    '''
    counts_at_dist, _ = get_flattened_values(region_mats, logged=logged)
    means = get_means_at_dist(counts_at_dist, logged=logged)
    # iterate over sorted distances so that the position in the output does
    # not depend on the dictionary's iteration order
    return np.asarray([means[dist] for dist in sorted(means)])


def lowess_smooth_expected(region_mats, logged=True, insertion=None):
    '''
    Compute and return the lowess smoothed empirical global mean.

    Distances up to and including the insertion index keep their empirical
    mean; larger distances take the value of a lowess fit through all data
    points at or beyond the insertion index.

    Parameters
    ----------
    region_mats : dictionary of np.ndarray
        Dictionary containing the observed matrix data of each region.
        The keys are regions.

    logged : boolean
        Pass True if values are logged to calculate the arithmetic mean
        in log space.

    insertion : float
        Fraction of the 1D global mean at which to begin lowess smoothing.
        If None, the function will use the default values. Note that the
        default insertion points differ for logged and unlogged data (1/3
        and 1/2, respectively). This is directly descended from the Hb and A
        pipelines.

    Returns
    ----------
    joined_fit : np.array
        Numpy array containing the empirical mean (short distances) joined
        with the lowess smoothed values (long distances), indexed by
        distance.
    '''
    counts_at_dist, data = get_flattened_values(region_mats, logged=logged)
    means = get_means_at_dist(counts_at_dist, logged=logged)

    # one entry per distance; the code below indexes `means` with the
    # integers 0 .. global_size - 1
    global_size = len(means)

    # determine location at which to begin lowess smoothing
    if insertion is None:
        insertion = 1 / 3. if logged else 0.5
    key_index = int(global_size * insertion)

    # the head of the joined fit is the plain empirical mean; clamp so that
    # an insertion point at the very end cannot index past the array
    joined_fit = np.zeros(global_size)
    head_end = min(key_index + 1, global_size)
    for dist in range(head_end):
        joined_fit[dist] = means[dist]
    if head_end >= global_size:
        # nothing left to smooth
        return joined_fit

    # fit on the data from the key index onward, then keep the first fitted
    # value reported for every distance strictly beyond the key index
    data = np.asarray(data)
    fit_by_dist = {}
    if data.ndim == 2 and len(data):
        tail = data[data[:, 0] >= key_index]
        if len(tail):
            fit = lowess(tail[:, 1], tail[:, 0], frac=0.8, it=3)
            for dist, value in fit:
                if dist > key_index:
                    fit_by_dist.setdefault(int(dist), value)

    for dist in range(head_end, global_size):
        if dist in fit_by_dist:
            joined_fit[dist] = fit_by_dist[dist]
        else:
            # no data point at this distance entered the fit
            warnings.warn('no lowess value for distance {}; falling back to '
                          'the empirical mean'.format(dist))
            joined_fit[dist] = means[dist]
    return joined_fit


def expected_matrix(matrix_dic, logged=True, insertion=None):
    '''
    Compute and return the expected matrix from the lowess smoothed empirical
    global mean.

    Parameters
    ----------
    matrix_dic : dictionary of np.ndarray
        Observed matrices, passed through to lowess_smooth_expected.

    logged : boolean
        Passed through to lowess_smooth_expected.

    insertion : float
        Passed through to lowess_smooth_expected.

    Returns
    ----------
    exp_mat : np.ndarray
        Square symmetric matrix whose entry [i, j] is the 1D expected value
        at distance abs(i - j). Its side length equals the length of the 1D
        expected.
    '''
    one_d_exp = np.asarray(
        lowess_smooth_expected(matrix_dic,
                               logged=logged, insertion=insertion),
        dtype=float)

    # entry [i, j] looks up the expected value at distance |i - j|
    idx = np.arange(len(one_d_exp))
    return one_d_exp[np.abs(idx[:, np.newaxis] - idx)]


def expected_dictionary(obs_dic, logged=True, insertion=None):
    '''
    Compute and return a dictionary of global expected matrices for each
    region.

    Parameters
    ----------
    obs_dic : dictionary of np.ndarray
        Observed matrices keyed by region.

    logged : boolean
        Passed through to expected_matrix.

    insertion : float
        Passed through to expected_matrix.

    Returns
    ----------
    exp_dic : dictionary of np.ndarray
        For every region, the top-left corner of the global expected matrix
        cut to the shape of that region's observed matrix.
    '''
    exp_mat = expected_matrix(obs_dic, logged=logged, insertion=insertion)
    # create dictionary of expected matrices
    exp_dic = {}
    for region, obs_mat in obs_dic.items():
        n_rows, n_cols = obs_mat.shape
        reg_exp_mat = np.zeros((n_rows, n_cols))
        # the assignment fails loudly if the region is larger than the
        # global expected matrix instead of silently truncating
        reg_exp_mat[...] = exp_mat[:n_rows, :n_cols]
        exp_dic[region] = reg_exp_mat
    return exp_dic


def make_counts_file(obs_file, outpath=None):
    '''
    Creates the corresponding global expected counts file given an
    observed counts file.

    Parameters
    ----------
    obs_file : str
        Path of the observed counts file to load.

    outpath : str
        Path of the expected counts file to write. If None, it is derived
        from obs_file by replacing 'obs' with 'global_exp'.
    '''
    obs_dic = load_counts(obs_file)
    # single-argument print calls behave the same on Python 2 and Python 3
    print('generating expected')
    exp_dic = expected_dictionary(obs_dic)
    if outpath is None:
        # NOTE(review): this replaces every occurrence of 'obs' in the path,
        # directory names included - confirm that this is intended
        outpath = obs_file.replace('obs', 'global_exp')
    print('saving counts file')
    print('outpath: {}'.format(outpath))
    m2c.dic2counts(exp_dic, outpath)


def obs_list_to_counts_files(obs_list, outpaths=None):
    '''
    Makes expected counts for each observed file in the input list.

    Parameters
    ----------
    obs_list : list of str
        Paths of the observed counts files.

    outpaths : list of str
        Output path for each entry of obs_list, in the same order. If None,
        every output path is derived by make_counts_file.
    '''
    if outpaths is None:
        outpaths = [None for _ in obs_list]
    total = float(len(obs_list))
    for i, obs_file in enumerate(obs_list):
        # progress report; a single formatted argument prints identically on
        # Python 2 and Python 3
        print('{} {} % done'.format(obs_file, i / total * 100))
        make_counts_file(obs_file, outpaths[i])


def write_expected_out(obs_file, outpath, ty='lowess'):
    '''
    Writes the 1D expected of an observed counts file as a flattened
    expected file.

    The format is similar to, but not quite the same as .scaled files. Its
    purpose is to save space and allow easier import as 1D expected files.

    Format of output txt file:
    # Region 1
    value_at_0_loci_distance
    value_at_1_loci_distance
    # Region 2
    etc.

    Parameters
    ----------
    obs_file : str
        Path of the observed counts file to load.

    outpath : str
        Path of the text file to write. If None, it is derived from
        obs_file by replacing 'obs.counts' with 'global_exp.txt'.

    ty : str
        'lowess' selects the lowess smoothed expected; any other value
        selects the plain global mean.
    '''
    obs_dic = load_counts(obs_file)

    # pick the 1D expected model, then evaluate it with default settings
    compute_expected = (lowess_smooth_expected if ty == 'lowess'
                        else global_mean)
    exp = compute_expected(obs_dic)

    target = outpath
    if target is None:
        target = obs_file.replace('obs.counts', 'global_exp.txt')
    m2c.flatexp2text(exp, obs_dic, target)


def main():
    '''
    Creates the global expected counts file for every observed counts file
    matching the path prefix given as the first command line argument.

    The files are found by appending '*obs.counts' to that argument and
    expanding the result as a glob pattern.

    WARNING: Use of this main function is deprecated.
    It is presently kept only for compatibility reasons (2I paper,
    specifically). This should not be relied on for future use.
    '''
    pattern = sys.argv[1] + '*obs.counts'
    obs_list_to_counts_files(glob.glob(pattern))


# Run the (deprecated) batch conversion only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
