# -*- coding: utf-8 -*-
"""
Created on Wed Jun 22 15:13:59 2016

@author: Owner
"""

import os

import pdb

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats

import sys
sys.path.insert(0, '../modules2016')

from load_pixelmap import load_pixelmap
from load_binned_counts import load_binned_counts
from make_compiled_list_for_2i_thresholding import \
make_compiled_list_for_2i_thresholding
from classify_2i_interaction_scores import \
classify_2i_interaction_scores
from reshape_2i_loops_for_clustering import \
reshape_2i_loops_for_clustering
from make_annotationmaps import make_annotationmaps
from Calculating_CTCF_Orientation_Enrichments import ORIENTATION_count_near_intersections
from Calculating_CTCF_Orientation_Enrichments import compute_p_values

def main():
    """Classify 2i loops and write CTCF-orientation contingency tables.

    Steps, in order:

    1. Make sure the ``output`` directory exists.
    2. Load the 4kb pixelmap and the six binned interaction-score files
       from ``input/``.
    3. Compile the scores, classify loops, and reshape the result.
    4. Build the annotation maps from ``annotations`` and re-key the four
       merged CTCF tracks under short names.
    5. Count CTCF orientations near loops (``margin=1``) and compute
       p-values for the ``es_npc`` and ``background`` loop types.
    6. Write one ``label<TAB>value`` line per statistic and loop type to
       ``output/NPC_SERUM_Contingency_tables_w_PVals.txt``.

    Takes no arguments and returns ``None``. All paths are relative to the
    current working directory.
    """
    # prepare output directory
    directory = 'output'
    if not os.path.exists(directory):
        os.makedirs(directory)

    pixelmap = load_pixelmap('input/trimmed_4kb_pixelmap.bed')

    # Every counts file shares the same suffix; only the sample prefix varies.
    counts_template = ('input/%s_16kb_4kb_obs_over_exp_int_scores'
                       '_diag_cluster_trimmed.counts')
    sample_prefixes = (
        'JBR2_ES1_S3_GH_5C_12_11_2015',
        'JBR2_ES2_S2_GH_5C_12_11_2015',
        'pNPC1_S4_10_24_2015',
        'pNPC2_S3_10_24_2015',
        'v652i1_S2_10_24_2015',
        'v652i2_S1_10_24_2015',
    )
    interaction_scores = [
        load_binned_counts(counts_template % prefix, pixelmap=pixelmap)
        for prefix in sample_prefixes
    ]

    print('compiling data')
    compiled_interaction_scores = make_compiled_list_for_2i_thresholding(
        interaction_scores, pixelmap, distance_span='start-to-start')

    # classify loops
    print('classifying loops')
    raw_loops = classify_2i_interaction_scores(
        compiled_interaction_scores,
        background_thresh=-10 * np.log2(0.8),
        sig_thresh=-10 * np.log2(0.165),
        border_constit_thresh=-10 * np.log2(0.0625),
        constit_thresh=-10 * np.log2(0.0625),
        difference_thresh=30.2)

    # reshape loops for clustering
    print('reshaping loops dict')
    loops = reshape_2i_loops_for_clustering(raw_loops,
                                            interaction_scores[0].keys())

    annotationmap = make_annotationmaps(pixelmap, directory='annotations',
                                        add_wildcard=True)

    # Re-key the merged CTCF tracks under short names, dropping the long
    # file-derived keys so only the short names remain in the map.
    ctcf_renames = (
        ('Merged_CTCF_Multiple_Consensus',
         'MERGEDNPCplusSERUMCTCF_more_than_once_intersecting_consensus_6_20_2016'),
        ('Merged_CTCF_Reverse',
         'MERGEDNPCplusSERUMCTCF_negative_strand_intersecting_consensus_6_20_2016'),
        ('Merged_CTCF_NoConsensus',
         'MERGEDNPCplusSERUMCTCF_not_intersecting_consensus_6_20_2016'),
        ('Merged_CTCF_Forward',
         'MERGEDNPCplusSERUMCTCF_positive_strand_intersecting_consensus_6_20_2016'),
    )
    for short_name, long_name in ctcf_renames:
        annotationmap[short_name] = annotationmap[long_name]
        del annotationmap[long_name]

    CTCF_Count_Dicts = ORIENTATION_count_near_intersections(
        loops, annotationmap, 'Merged_CTCF_Forward', 'Merged_CTCF_Reverse',
        margin=1)
    # Only the count entries are used below; the odd-indexed entries
    # (presumably the matching loop lists -- confirm against
    # Calculating_CTCF_Orientation_Enrichments) are not needed here.
    count_orientation_intersection_dict = CTCF_Count_Dicts[0]
    count_loops_w_Conflicting_CTCF_dict = CTCF_Count_Dicts[2]
    count_loops_w_OneSide_CTCF_dict = CTCF_Count_Dicts[4]
    count_loops_w_NO_CTCF_dict = CTCF_Count_Dicts[6]
    count_loops_w_NO_Consensus_dict = CTCF_Count_Dicts[8]
    count_loops_w_Multiple_Consensus_dict = CTCF_Count_Dicts[10]
    count_loops_w_Unique_CTCF_dict = CTCF_Count_Dicts[12]

    loop_types = ['es_npc', 'background']
    margin_one_p_value_dict = compute_p_values(
        count_orientation_intersection_dict, loops, loop_types)

    # For each orientation and loop type, keep the smaller of the 'less' and
    # 'greater' p-values.
    pvals = {}
    for orientation in margin_one_p_value_dict:
        pvals[orientation] = {}
        for loop_type in loop_types:
            p_values = margin_one_p_value_dict[orientation][loop_type]
            pvals[orientation][loop_type] = min(p_values['less'],
                                                p_values['greater'])

    report_path = os.path.join(directory,
                               'NPC_SERUM_Contingency_tables_w_PVals.txt')
    # 'with' closes the file even if a lookup below raises.
    with open(report_path, 'w') as output:
        for loop_type in loop_types:
            loops_by_region = loops[loop_type]
            total_loops = sum(1 for region in loops_by_region
                              for _ in loops_by_region[region])
            orientation_totals = dict(
                (orientation,
                 count_orientation_intersection_dict[orientation]['total'][loop_type])
                for orientation in ('Convergent', 'Divergent',
                                    'Same_Forward', 'Same_Reverse'))
            table = [
                ('Loop Type', loop_type),
                ('Total Loop Number', total_loops),
                ('Number of Loops with No CTCF',
                 count_loops_w_NO_CTCF_dict[loop_type]),
                ('Number of Loops with No Consensus CTCF',
                 count_loops_w_NO_Consensus_dict[loop_type]),
                ('Number of Loops with Multiple Consensus CTCF',
                 count_loops_w_Multiple_Consensus_dict[loop_type]),
                ('Number of Loops with CTCF on Only One Side',
                 count_loops_w_OneSide_CTCF_dict[loop_type]),
                ('Number of Loops with Conflicting CTCF Orientation(s)',
                 count_loops_w_Conflicting_CTCF_dict[loop_type]),
                ('Number of Loops with Unique CTCF Orientation(s)',
                 count_loops_w_Unique_CTCF_dict[loop_type]),
                ('Number of Total Convergent CTCF Loops',
                 orientation_totals['Convergent']),
                ('Number of Total Divergent CTCF Loops',
                 orientation_totals['Divergent']),
                ('Number of Total Same_Forward CTCF Loops',
                 orientation_totals['Same_Forward']),
                ('Number of Total Same_Reverse CTCF Loops',
                 orientation_totals['Same_Reverse']),
                ('Convergent Enrichment PValue',
                 pvals['Convergent'][loop_type]),
                ('Divergent Enrichment PValue',
                 pvals['Divergent'][loop_type]),
                ('Same Forward Enrichment PValue',
                 pvals['Same_Forward'][loop_type]),
                ('Same Reverse Enrichment PValue',
                 pvals['Same_Reverse'][loop_type]),
            ]
            # file.write instead of the Python-2-only ``print >> output`` so
            # the report is produced under both Python 2 and Python 3.
            for label, value in table:
                output.write("%s\t%s\n" % (label, str(value)))

# Run the pipeline only when this file is executed as a script, so that
# importing the module does not trigger the analysis.
if __name__ == '__main__':
    main()
