import sys
sys.path.insert(0, '../modules2016')

import numpy as np
import glob
from load_pixelmap import load_pixelmap
from load_binned_counts import load_binned_counts
from propagate_nans import propagate_nans
from donut_filt import donut_filt
from lower_left_filt import lower_left_filt
from write_counts import write_counts

def main():
    """Compute and write donut-style expected counts for every replicate.

    Steps, in order:

    1. Load the pixelmap and, for each replicate, the observed counts and the
       global expected counts (one file per replicate in each directory).
    2. Propagate NaNs between each observed/expected pair.
    3. Apply ``donut_filt`` per region and write the result.
    4. Apply ``lower_left_filt`` per region and write the result.
    5. Take the element-wise ``np.fmax`` of the two and write the result.

    Input and output locations are hard-coded relative paths under
    ``output/``; the function takes no arguments and returns nothing.

    Observed and expected files are paired by position after sorting both
    file lists, so this assumes both directories contain the same replicates
    with names that sort identically.
    """
    import os  # local import: only needed here for replicate-name parsing

    # load data
    print('loading data')
    pixelmap = load_pixelmap('output/trimmed_4kb_pixelmap.bed')

    # glob.glob returns paths in arbitrary order.  Sort both lists once so
    # that observed file i, expected file i and reps[i] all refer to the
    # same replicate.
    obs_files = sorted(glob.glob(
        'output/Fig_3AK_4AB_5A_6BCG_Sup_Fig_3/*_16kb_4kb_observed.counts'))
    exp_files = sorted(glob.glob(
        'output/counts/*_16kb_4kb_global_expected.counts'))

    observed_counts = [load_binned_counts(path, pixelmap=pixelmap)
                       for path in obs_files]
    expected_counts = [load_binned_counts(path, pixelmap=pixelmap)
                       for path in exp_files]

    # replicate name = file name up to the first '.', minus the '_16kb...' tail
    reps = [os.path.basename(path).split('.')[0].split('_16kb')[0]
            for path in obs_files]

    # propagate nans between observed and expected counts dicts
    print('propagating nans')
    for i in range(len(observed_counts)):
        observed_counts[i], expected_counts[i] = propagate_nans(
            observed_counts[i], expected_counts[i])

    # Regions are taken from the first replicate and reused for all of them
    # (empty when no observed files were found, in which case nothing is
    # computed or written below).
    regions = list(observed_counts[0]) if observed_counts else []

    # get regional expected donut
    print('calculating donut expected')
    donut_expected = [
        {region: donut_filt(obs[region], exp[region]) for region in regions}
        for obs, exp in zip(observed_counts, expected_counts)]

    # write output
    for rep, counts in zip(reps, donut_expected):
        write_counts(
            counts,
            'output/Fig_3C_Sup_Fig_4/%s_16kb_4kb_donut_expected.counts' % rep,
            pixelmap)

    # calculate lower left donut
    print('calculating lower left donut expected')
    lower_left_expected = [
        {region: lower_left_filt(obs[region], exp[region])
         for region in regions}
        for obs, exp in zip(observed_counts, expected_counts)]

    # write output
    for rep, counts in zip(reps, lower_left_expected):
        write_counts(
            counts,
            'output/Fig_3D_Sup_Fig_5/%s_16kb_4kb_lower_left_donut_expected.counts' % rep,
            pixelmap)

    # calculate max(donut,ll); np.fmax ignores a NaN when the other operand
    # is a number, so a value is kept wherever either filter produced one
    print('calculating max of donut and lower left donut expected')
    max_donut_ll = [
        {region: np.fmax(donut[region], lower_left[region])
         for region in regions}
        for donut, lower_left in zip(donut_expected, lower_left_expected)]

    # write output
    for rep, counts in zip(reps, max_donut_ll):
        write_counts(
            counts,
            'output/Fig_3E_Sup_Fig_6/%s_16kb_4kb_max_donut_ll_expected.counts' % rep,
            pixelmap)

    # Disabled step, kept for reference.
    # NOTE(review): labelled "observed/expected ratio" but it computes a
    # difference (np.subtract) -- confirm the intent before re-enabling.
    #
    # print('calculating observed/expected ratio')
    # obs_over_exp = [{region: np.subtract(observed_counts[i][region],max_donut_ll[i][region])
    #                  for region in max_donut_ll[0].keys()}
    #                  for i in range(len(max_donut_ll))]
    # for i in range(len(obs_over_exp)):
    #     write_counts(obs_over_exp[i],'output/counts/%s_16kb_4kb_obs_over_max_donut_ll.counts'%reps[i],pixelmap)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
