import sys
sys.path.insert(0,'../modules2016/stage_1')

import numpy as np
import glob
from copy import deepcopy
from generate_primermap import generate_primermap
from load_primer_counts import load_primer_counts
from generate_pixelmap_from_primermap import generate_pixelmap_from_primermap
from trim_primers import trim_primers
from trim_matrix import trim_matrix
from write_counts import write_counts
from write_primermap import write_primermap

def main():
    """Trim low-count primers and excluded regions, then write the results.

    Pipeline (all paths are relative to the current working directory):

    1. Build the primer map from ``input/primers_augmented.bed`` and load
       every ``input/*.counts`` file into ``counts_superdict``, keyed by
       replicate name (the file name up to its first ``.``).
    2. For each region, call ``trim_primers`` (``min_sum=10.0``,
       ``min_frac=None``), which returns the trimmed primer list and the
       index set later passed to ``trim_matrix`` for that region.
    3. Drop the regions listed in ``excluded_regions`` from the primer map.
    4. Trim each replicate's per-region counts with ``trim_matrix`` using
       the indices from step 2; excluded regions are skipped.
    5. Write ``output/trimmed_counts/<rep>_trimmed.counts`` per replicate
       and ``output/primers_augmented_trimmed.bed`` once.
    """
    import os  # local import: only main() needs it

    # Regions removed from the output altogether.
    excluded_regions = frozenset(
        ['Nanog-V2', 'Nestin', 'Olig1-Olig2', 'Oct4', 'gene-desert'])

    # load data
    print('loading data')
    primermap = generate_primermap('input/primers_augmented.bed')
    counts_superdict = {}
    for counts_file in sorted(glob.glob('input/*.counts')):
        # basename() works regardless of directory depth or path separator,
        # unlike the previous split('/')[1].
        rep = os.path.basename(counts_file).split('.')[0]
        counts_superdict[rep] = load_primer_counts(counts_file, primermap)

    print('trimming bad primers')
    trimmed_primermap = {}
    trimmed_indices = {}
    for region in primermap:
        trimmed_primermap[region], trimmed_indices[region] = trim_primers(
            primermap[region],
            {rep: counts_superdict[rep][region] for rep in counts_superdict},
            min_sum=10.0,
            min_frac=None)

    # Remove Nestin, Olig, Oct4, Nanog and gene-desert regions. The
    # trim_primers indices of the kept regions must stay untouched: the
    # counts matrices are trimmed with them below, and overwriting them
    # would leave the counts out of sync with the trimmed primer map.
    trimmed_primermap = {region: primers
                         for region, primers in trimmed_primermap.items()
                         if region not in excluded_regions}
    print(trimmed_primermap)

    # NOTE: a hand-curated removal of individual primers (Nanog-V2 primers
    # numbered > 206; Olig1-Olig2 primers 152, 193 and 219) was previously
    # kept here, disabled. Both regions are now excluded wholesale above.

    print('trimming counts')
    trimmed_counts_superdict = {rep: {} for rep in counts_superdict}
    for rep, region_counts in counts_superdict.items():
        for region, matrix in region_counts.items():
            if region in excluded_regions:
                # The region has no primer map anymore, so keep no counts.
                continue
            trimmed_counts_superdict[rep][region] = trim_matrix(
                matrix, trimmed_indices[region])

    # write output
    for rep in counts_superdict:
        write_counts(trimmed_counts_superdict[rep],
                     'output/trimmed_counts/%s_trimmed.counts' % rep,
                     trimmed_primermap)
    # The trimmed primer map is identical for every replicate: write it once.
    write_primermap(trimmed_primermap, 'output/primers_augmented_trimmed.bed',
                    extra_column_names=['length', '% GC'])

if __name__ == '__main__':
    # Run the trimming pipeline only when executed directly, not on import.
    main()
