import sys
sys.path.insert(0,'../modules2016/stage_1')

import numpy as np
import glob
from generate_primermap import generate_primermap
from load_primer_counts import load_primer_counts
from make_expected_matrix_fragment import make_expected_matrix_fragment
from polyfit_model_fragments import polyfit_model_fragments
from joint_express_normalize import joint_express_normalize
from write_counts import write_counts

def main():
    # load data
    print('loading data')
    trimmed_primermap = generate_primermap('output/primers_augmented_trimmed.bed')
    counts = glob.glob('output/qnormed_counts/*_conditional_qnormed.counts')
    qnormed_counts_superdict ={counts[i].split('/')[2].split('.')[0].split('_conditional')[0]:
                                         load_primer_counts(counts[i],trimmed_primermap)
                                         for i in range(len(sorted(counts)))}
    #reshape
    counts_list = []
    for rep in sorted(qnormed_counts_superdict.keys()):
        counts_list.append(qnormed_counts_superdict[rep])

    # precompute fragment expected model
    expected_counts = [make_expected_matrix_fragment(counts_list[i],trimmed_primermap,polyfit_model_fragments,is_global=False)
                        for i in range(len(counts_list))]


    # reshape
    reshaped_observed_counts = {region: [counts_list[i][region] for i in range(len(counts_list))]
                                for region in counts_list[0].keys()}
    reshaped_expected_counts = {region: [expected_counts[i][region] for i in range(len(counts_list))]
                                for region in counts_list[0].keys()}

    #joint normalize
    print('joint express normalizing')
    raw_balanced_counts,_,_ = joint_express_normalize(reshaped_observed_counts,
                                                                    reshaped_expected_counts)

    # reshape
    joint_expressed_counts = [{region: raw_balanced_counts[region][i] for region in counts_list[0].keys()}
                      for i in range(len(counts_list))]
    # write output
    for i in range(len(joint_expressed_counts)):
        write_counts(joint_expressed_counts[i],
                    'output/expressed_counts/%s_joint_express.counts'%(sorted(qnormed_counts_superdict.keys())[i]),trimmed_primermap)

if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not when imported.
    main()
