#!/usr/bin/env python
from sys import argv
import numpy as np
# Fixed seed: the same pair of inputs always drops the same entries.
np.random.seed(1234)


def size_match(reference, target):
    """Zero out random nonzero entries of ``target`` until its sum equals ``reference``'s.

    The number of entries to clear is ``int(target.sum()) - int(reference.sum())``.
    They are picked uniformly at random (via the global ``np.random`` state)
    among the nonzero positions of ``target``. Arrays of any dimensionality
    >= 1 are supported.

    NOTE(review): the sums can only end up equal if every nonzero entry of
    ``target`` is exactly 1 -- presumably both inputs are 0/1 masks; confirm
    against the files this script is run on.

    Args:
        reference: array whose sum is the number of ones to keep.
        target: array to thin out. It is modified in place.

    Returns:
        ``target`` (the same object), after the selected entries were zeroed.

    Raises:
        ValueError: if ``target`` sums to less than ``reference`` (nothing can
            be removed to reach the goal), or if the sums still differ after
            zeroing the selected entries.
    """
    num_ones = int(reference.sum())
    num_ones_to_zero_out = int(target.sum()) - num_ones
    if num_ones_to_zero_out < 0:
        # A negative slice bound would silently select the wrong number of
        # entries, so refuse up front and leave ``target`` untouched.
        raise ValueError(
            "target sums to {0}, which is fewer than the {1} of the reference; "
            "cannot size-match by removing entries".format(
                num_ones + num_ones_to_zero_out, num_ones))

    nonzero_indices = np.nonzero(target)
    shuffle_index = np.random.permutation(np.arange(len(nonzero_indices[0])))
    # Slicing the permutation keeps an integer dtype even when nothing is
    # selected; an empty float array would be rejected as an index.
    chosen = shuffle_index[:num_ones_to_zero_out]
    # One index array per axis, so this works for any number of dimensions.
    target[tuple(axis_idx[chosen] for axis_idx in nonzero_indices)] = 0

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if reference.sum() != target.sum():
        raise ValueError(
            "size matching failed: reference sums to {0} but target sums to {1}".format(
                reference.sum(), target.sum()))
    return target


def main(args):
    """Load two ``.npy`` arrays, size-match the second to the first, and save it.

    ``args`` has the layout of ``sys.argv``:
    ``[script, array1_path, array2_path, out_array_path]``; extra entries are
    ignored.

    Example paths:
        array1 = "protein_coding.filtered_by_expression.flank_windows.intersect.enhancer_extended.npy"
        array2 = "protein_coding.filtered_by_expression.flank_windows.intersect.dhss_extended_no_EP300_no_CTCF.npy"
        out_array = "protein_coding.filtered_by_expression.flank_windows.intersect.dhss_extended_no_EP300_no_CTCF_size_matched_to_enhancers.npy"
    """
    if len(args) < 4:
        raise SystemExit("usage: {0} ARRAY1.npy ARRAY2.npy OUT_ARRAY.npy".format(
            args[0] if args else "script"))

    array1_path, array2_path, out_array = args[1], args[2], args[3]
    array1 = np.load(array1_path)
    array2 = np.load(array2_path)
    np.save(out_array, size_match(array1, array2))


if __name__ == "__main__":
    main(argv)
