import numpy as np

from flatten_counts_single_region_geometric import flatten_counts_single_region
from kr_balance import kr_balance
from balance_matrix import balance_matrix
from gmean import gmean


def kr_balance_counts(counts, max_iter=3000, retain_scale=True):
    """
    Balance a counts matrix with KR and return the balanced matrix and extras.

    Runs ``kr_balance`` to obtain a bias array, applies it with
    ``balance_matrix``, optionally rescales the result, and finally marks
    every position that was non-finite in the input as NaN in the output.

    Parameters
    ----------
    counts : np.ndarray
        The counts matrix to balance.
    max_iter : int
        Maximum number of KR iterations; forwarded to ``kr_balance``.
    retain_scale : bool
        When True, multiply the balanced matrix by
        ``gmean(flattened counts) * len(counts)``, and multiply the bias by
        the square root of that same factor. This is a heuristic that brings
        the output back toward the scale of the input.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray]
        ``(balanced, bias, residual)``. ``balanced`` is the balanced
        (and possibly rescaled) matrix, with NaN wherever ``counts`` was
        NaN or infinite. ``bias`` is the first column of the bias array
        returned by ``kr_balance``; presumably that array is an (n, 1)
        column vector (verify against ``kr_balance``). ``residual`` is the
        second value returned by ``kr_balance``, passed through unchanged.
    """
    bias, residual = kr_balance(counts, max_iter=max_iter)
    normalized = balance_matrix(counts, bias)

    if retain_scale:
        # Temporary heuristic: rescale the KR output using the geometric
        # mean of the flattened counts times the matrix dimension.
        flat_counts = np.array(flatten_counts_single_region(counts))
        scale = gmean(flat_counts) * len(counts)
        normalized *= scale
        # NOTE(review): the bias is multiplied (not divided) by sqrt(scale).
        # Whether this stays consistent with `normalized` depends on how
        # balance_matrix applies the bias -- confirm against balance_matrix.
        bias *= np.sqrt(scale)

    # Propagate non-finite input entries (NaN/inf) as NaN in the output.
    invalid = ~np.isfinite(counts)
    normalized[invalid] = np.nan

    return normalized, bias[:, 0], residual

