import numpy as np
from scipy.ndimage.filters import generic_filter

from flag_high_spatial_outliers import flag_high_spatial_outliers
from function_util import parallelize_regions


@parallelize_regions
def remove_high_spatial_outliers(counts, size=5, fold_threshold=8.0,
                                 overwrite_value='nan'):
    """
    Convenience function for removing high spatial outliers from counts
    matrices.

    Outliers are located with ``flag_high_spatial_outliers`` and then
    overwritten in place according to ``overwrite_value``.

    Parameters
    ----------
    counts : np.ndarray
        The matrix to remove outliers from. It is modified in place.
    size : int
        The size of the window to look in around each element when deciding if
        it is an outlier. Should be an odd integer. Also used as the window
        size for the local median when ``overwrite_value='median'``.
    fold_threshold : float
        Elements will be flagged as outliers if they are greater than this
        number or greater than this many times the local median (as estimated
        using the window size in ``size``).
    overwrite_value : {'nan', 'zero', 'median'}
        The value to overwrite elements flagged as outliers with. ``'median'``
        replaces each outlier with the NaN-ignoring median of the ``size``
        window around it, computed from ``counts`` before any overwriting.

    Returns
    -------
    np.ndarray
        The input matrix (the same object passed as ``counts``) with all
        spatial outliers overwritten.

    Raises
    ------
    ValueError
        If ``overwrite_value`` is not one of ``'nan'``, ``'zero'`` or
        ``'median'``.
    """
    # Validate up front so a typo fails loudly instead of silently leaving
    # the outliers in place.
    if overwrite_value not in ('nan', 'zero', 'median'):
        raise ValueError(
            "overwrite_value must be one of 'nan', 'zero' or 'median', "
            "got %r" % (overwrite_value,))

    # flag outliers; flagged elements are marked with the value 1
    # NOTE(review): assumes the helper returns an array shaped like counts.
    flagged_indices = flag_high_spatial_outliers(
        counts, size=size, fold_threshold=fold_threshold)
    outlier_mask = flagged_indices == 1

    # overwrite outliers
    # NOTE: writing NaN into an integer-dtype array raises in numpy; callers
    # wanting 'nan' should pass a floating-point matrix.
    if overwrite_value == 'nan':
        counts[outlier_mask] = np.nan
    elif overwrite_value == 'zero':
        counts[outlier_mask] = 0
    else:  # 'median'
        # Local median of the original (not yet overwritten) counts, ignoring
        # NaNs already present in the window.
        median_array = generic_filter(counts, np.nanmedian, size=size)
        counts[outlier_mask] = median_array[outlier_mask]

    return counts
