import sys
import numpy as np
#from flatten_counts_single_region_geometric import flatten_counts_single_region

def get_obs_colorscale_across_reps(counts_superdict,min_function=np.min,
                                   max_function=lambda x : np.percentile(x, 98)):
    """
    Calculates a colorscale for every region, averaged across replicates.

    For each region, ``min_function`` and ``max_function`` are applied to the
    finite entries of every replicate's counts matrix. The resulting
    per-replicate lower limits are averaged with ``np.mean`` to give the
    region's lower limit, and likewise for the upper limits.

    Parameters
    ----------
    counts_superdict : Dict[str, Dict[str, np.ndarray]]
        A superdict where the outer keys are the replicate names, the inner keys
        are region names, and the values are the corresponding observed counts
        matrix. The set of regions is taken from the first replicate, so every
        replicate must contain at least those regions.
    min_function : function
        The function that will compute the min for each replicate. It is called
        with the 1-D array of finite counts. The default is np.min
    max_function : function
        The function that will compute the max for each replicate. It is called
        with the 1-D array of finite counts. The default is a lambda expression
        wrapping np.percentile that returns the 98th percentile.

    Returns
    -------
    Dict[str, tuple]
        The keys are the region names and the values are
        ``(lower limit, upper limit)`` tuples of the colorscale, each limit
        being the mean of the per-replicate values. An empty dict is returned
        when ``counts_superdict`` is empty.

    """
    # Nothing to scale; also avoids looking up a "first replicate" that does
    # not exist.
    if not counts_superdict:
        return {}

    # Materialise the keys as lists: dict views are not indexable on
    # Python 3, while list(d) works on both Python 2 and Python 3.
    reps = list(counts_superdict)
    regions = list(counts_superdict[reps[0]])

    colorscales = {}
    for region in regions:
        rep_mins = []
        rep_maxes = []
        for rep in reps:
            counts = counts_superdict[rep][region]
            # Drop NaN/inf entries once and reuse the result for both limits.
            finite_counts = counts[np.isfinite(counts)]
            rep_mins.append(min_function(finite_counts))
            rep_maxes.append(max_function(finite_counts))
        colorscales[region] = (np.mean(rep_mins), np.mean(rep_maxes))

    return colorscales

