import glob
import os
import sys

sys.path.insert(0, '../modules2016')

import numpy as np
from generate_primermap import generate_primermap
from load_primer_counts import load_primer_counts
from generate_pixelmap_from_primermap import generate_pixelmap_from_primermap
from fragment_to_bin_filter import fragment_to_bin_filter
from arithmetic_mean_filter import arithmetic_mean_filter
from check_neighborhood_non_nan import check_neighborhood_non_nan
from wipe_unsmoothable_columns import wipe_unsmoothable_columns
from write_counts import write_counts
from write_pixelmap import write_pixelmap
from write_pixelmap import write_pixelmap

def main():
    """Bin and smooth per-replicate joint-express counts into 4 kb bins.

    Pipeline:
      1. Build the trimmed primer map from
         ``input/primers_augmented_trimmed.bed`` and load every
         ``input/*_joint_express.counts`` file against it.
      2. Build a 4000 bp pixel map for each region present in the first
         replicate's counts and write it to
         ``output/trimmed_4kb_pixelmap.bed``.
      3. For each replicate, log2(x + 1)-transform each region's counts and
         bin them with an arithmetic-mean filter (``neighborhood_radius=8000``).
      4. Pass each replicate through ``wipe_unsmoothable_columns`` (with
         16000) and write one ``<rep>_16kb_4kb_observed.counts`` file per
         replicate.

    Replicate names are the input file names truncated at ``_joint``
    (e.g. ``rep1_joint_express.counts`` -> ``rep1``).

    Raises:
        FileNotFoundError: if no ``input/*_joint_express.counts`` files exist.
    """
    # load data
    print('loading data')
    trimmed_primermap = generate_primermap('input/primers_augmented_trimmed.bed')

    # glob returns files in arbitrary order; sort so replicates are always
    # processed (and reported) in the same order.
    count_files = sorted(glob.glob('input/*_joint_express.counts'))
    if not count_files:
        raise FileNotFoundError(
            'no input/*_joint_express.counts files found')

    trimmed_joint_express_counts = [
        load_primer_counts(path, trimmed_primermap) for path in count_files
    ]

    # Use basename rather than splitting on '/' so this works regardless of
    # the platform path separator or directory depth.
    reps = [
        os.path.basename(path).split('.')[0].split('_joint')[0]
        for path in count_files
    ]

    # Regions are taken from the first replicate; the later steps assume
    # every replicate covers these same regions.
    regions = list(trimmed_joint_express_counts[0].keys())

    # establish bin mapping for all regions
    print('establishing pixelmap')
    pixelmap = {
        region: generate_pixelmap_from_primermap(
            trimmed_primermap[region], 4000, region_name=region)
        for region in regions
    }

    # write output
    write_pixelmap(pixelmap, 'output/trimmed_4kb_pixelmap.bed')

    # bin fragment level data
    print('binning and smoothing data')
    filter_function = arithmetic_mean_filter()
    # +1 before log2 keeps zero counts finite (log2(1) == 0); assumes the
    # per-region counts support elementwise arithmetic (numpy arrays?) -- TODO confirm
    binned_counts = [
        {
            region: fragment_to_bin_filter(
                np.log2(rep_counts[region] + 1),
                filter_function,
                pixelmap[region],
                trimmed_primermap[region],
                neighborhood_radius=8000,
            )
            for region in regions
        }
        for rep_counts in trimmed_joint_express_counts
    ]

    # wipe empty bins
    print('wiping empty bins')
    binned_counts = [
        wipe_unsmoothable_columns(rep_binned, trimmed_primermap, pixelmap, 16000)
        for rep_binned in binned_counts
    ]

    # write output: one file per replicate, paired by position with reps
    for rep, rep_binned in zip(reps, binned_counts):
        write_counts(
            rep_binned,
            'output/Fig_3AK_4AB_5A_6BCG_Sup_Fig_3/%s_16kb_4kb_observed.counts' % rep,
            pixelmap,
        )


if __name__ == "__main__":
    main()
