import numpy as np
import scipy.stats as stats

from function_util import parallelize_regions


@parallelize_regions
def extract_queried_counts(regional_counts, regional_primermap):
    """
    Restrict a square counts matrix to the contacts queried by the 5C assay.

    Starting from a square, symmetric counts matrix containing fragment-level
    contact information, return a non-square, non-symmetric matrix where the
    forward fragments (fragments to which 5C primers were designed to the +
    strand of the genome) sit in the rows of the matrix while the reverse
    fragments sit in the columns. This restricts the input matrix to only the
    pairwise contacts that were actually queried by the 5C assay.

    Parameters
    ----------
    regional_counts : np.ndarray
        The classic square, symmetric counts matrix for this region. Row and
        column ``i`` must correspond to ``regional_primermap[i]``.
    regional_primermap : List[Dict[str, Any]]
        The primermap describing the fragments in this region. It must contain a
        'strand' metadata key so that ``regional_primermap[i]['strand']`` is
        ``'+'`` when the fragment was targeted by a forward-direction primer and
        ``'-'`` otherwise. This information is necessary to distinguish the
        forward and reverse fragments in the region. Any strand value other
        than ``'+'`` is treated as a reverse fragment.

    Returns
    -------
    np.ndarray
        The queried counts matrix, as described above. It is a new float array
        of shape ``(n_forward, n_reverse)`` whose ``[i, j]`` entry is the count
        between the ``i`` th forward fragment and the ``j`` th reverse
        fragment, both in primermap order. Either dimension may be zero.
    """
    # split the fragment indices by primer orientation, preserving the
    # primermap order within each group
    forward_indices = [index
                       for index, fragment in enumerate(regional_primermap)
                       if fragment['strand'] == '+']
    reverse_indices = [index
                       for index, fragment in enumerate(regional_primermap)
                       if fragment['strand'] != '+']

    # np.ix_ builds an open mesh from the two index lists, so one fancy-index
    # selection pulls out the whole forward-by-reverse submatrix (including
    # the degenerate case where either list is empty) without a Python loop
    queried_counts = np.asarray(regional_counts)[np.ix_(forward_indices,
                                                        reverse_indices)]

    # fancy indexing already returned a fresh array, so only convert when the
    # input was not floating point to begin with
    return queried_counts.astype(float, copy=False)