import sys
sys.path.insert(0, '../modules2016')

import numpy as np
from load_pixelmap import load_pixelmap
from load_binned_counts import load_binned_counts
from propagate_nans import propagate_nans
from donut_filt import donut_filt
from lower_left_filt import lower_left_filt
from write_counts import write_counts

def _counts_path(rep, kind):
    """Return the ``output/counts`` path of the ``_16kb_4kb_<kind>`` counts file for *rep*."""
    return 'output/counts/%s_16kb_4kb_%s.counts' % (rep, kind)


def _per_region(fn, first, second):
    """Apply ``fn(a[region], b[region])`` region-by-region for each replicate.

    *first* and *second* are parallel lists (one dict per replicate) mapping
    region -> counts. Each replicate's regions are taken from its own dict in
    *first*, so replicates need not share an identical region set.

    Returns a list of dicts (one per replicate) mapping region -> fn result.
    """
    return [{region: fn(a[region], b[region]) for region in a}
            for a, b in zip(first, second)]


def _write_all(per_rep_counts, reps, kind, pixelmap):
    """Write one ``_16kb_4kb_<kind>`` counts file per replicate via write_counts."""
    for rep, counts in zip(reps, per_rep_counts):
        write_counts(counts, _counts_path(rep, kind), pixelmap)


def main():
    """Compute local expected models and observed-vs-expected counts per replicate.

    For every replicate in ``reps`` this:
      1. loads the ``observed`` and ``global_expected`` counts files onto the
         pixelmap from ``output/trimmed_4kb_pixelmap.bed``;
      2. propagates NaNs between observed and expected (``propagate_nans``);
      3. computes and writes a donut expected (``donut_filt``);
      4. computes and writes a lower-left donut expected (``lower_left_filt``);
      5. computes and writes the element-wise ``np.fmax`` of those two;
      6. combines observed with that max (``np.subtract``) and writes it.

    Every input and output lives in ``output/counts/`` and is named
    ``<rep>_16kb_4kb_<kind>.counts``.
    """
    # Replicates to process. All input and output paths are derived from this
    # list, so enabling a replicate only requires uncommenting its name here.
    reps = [
        # 'ScrKD-48Hr-ZC3C1',
        # 'ScrKD-48Hr-ZC3C3',
        'ScrKD-78Hr-ZC3C1',
        'ScrKD-78Hr-ZC3C2',
        # 'YY1KD-48Hr-ZC3C1',
        # 'YY1KD-48Hr-ZC3C2',
        'YY1KD-78Hr-ZC3C1',
        'YY1KD-78Hr-ZC3C2',
    ]

    # load data
    print('loading data')
    pixelmap = load_pixelmap('output/trimmed_4kb_pixelmap.bed')
    observed_counts = [
        load_binned_counts(_counts_path(rep, 'observed'), pixelmap=pixelmap)
        for rep in reps]
    expected_counts = [
        load_binned_counts(_counts_path(rep, 'global_expected'),
                           pixelmap=pixelmap)
        for rep in reps]

    # propagate nans between observed and expected counts dicts
    print('propagating nans')
    for i, (obs, exp) in enumerate(zip(observed_counts, expected_counts)):
        observed_counts[i], expected_counts[i] = propagate_nans(obs, exp)

    # regional donut expected
    print('calculating donut expected')
    donut_expected = _per_region(donut_filt, observed_counts, expected_counts)
    _write_all(donut_expected, reps, 'donut_expected', pixelmap)

    # regional lower-left donut expected
    print('calculating lower left donut expected')
    lower_left_expected = _per_region(lower_left_filt, observed_counts,
                                      expected_counts)
    _write_all(lower_left_expected, reps, 'lower_left_donut_expected', pixelmap)

    # element-wise max of the two models; np.fmax prefers the non-NaN value
    # when exactly one side is NaN
    print('calculating max of donut and lower left donut expected')
    max_donut_ll = _per_region(np.fmax, donut_expected, lower_left_expected)
    _write_all(max_donut_ll, reps, 'max_donut_ll_expected', pixelmap)

    # observed vs. max(donut, lower-left) expected
    print('calculating observed/expected ratio')
    # NOTE(review): this stage is labelled an observed/expected *ratio* and the
    # output is named obs_over_..., but the arithmetic is a subtraction. That
    # is only a ratio if these counts are log-scale -- confirm against the
    # upstream counts files before switching to np.divide.
    obs_over_exp = _per_region(np.subtract, observed_counts, max_donut_ll)
    _write_all(obs_over_exp, reps, 'obs_over_max_donut_ll', pixelmap)
                

if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not when imported.
    main()
