from tempfile import NamedTemporaryFile
from glob import glob
import sys
import os.path
import numpy as np
import pandas as pd
import threading
import subprocess
import itertools

def run_anat(terminal_filename, network_filename, anchor):
    """Run the external ``steinprt`` executable for one anchor/network pair.

    Every line of ``terminal_filename`` is paired with ``anchor`` in a
    temporary tab-separated "set" file that is handed to the executable via
    ``-s``.  The executable is told (``-r``) to write to
    ``<network_filename>_results/<anchor>.<tissue>_<network basename>``,
    where ``tissue`` is the part of the terminal file's base name before its
    first dot.  The results directory is not created here; the caller must
    make sure it exists.

    Args:
        terminal_filename: Path of a text file with one terminal per line.
        network_filename: Path of the network file passed via ``-n``.
        anchor: Value written in the first column of every set-file line.

    Returns:
        None.  The exit status of the executable is not inspected.

    Raises:
        IOError/OSError: If the terminal file cannot be read or the
            executable cannot be started.
    """
    # Use the base name so a directory component of the terminal path (or a
    # dot inside it) cannot leak into the output file name.
    tissue = os.path.basename(terminal_filename).split('.')[0]
    out_folder = network_filename + '_results'
    output_filename = "%s/%s.%s_%s" % (out_folder, anchor, tissue,
                                       os.path.basename(network_filename))

    with open(terminal_filename, 'r') as f:
        terminals = f.readlines()

    # Text mode: the default binary mode rejects str writes on Python 3.
    set_file = NamedTemporaryFile(mode='w', prefix='set', delete=False)
    try:
        # Close the file before launching the executable so that everything
        # written is flushed to disk.
        with set_file:
            # Lines from readlines() still carry their own newline.
            set_file.writelines("%s\t%s" % (anchor, t) for t in terminals)

        # Build argv directly rather than splitting a formatted string, so
        # that paths containing spaces stay single arguments.
        args = ['/home/bnet/atiasnir/anat/steiner/steinprt',
                '-f', '.',
                '-n', network_filename,
                '-c', '0',
                '-b', '%.2f' % 0.25,
                '-s', set_file.name,
                '-r', output_filename,
                '-l', '0.75']

        # Log the command being run, one line per invocation.
        sys.stderr.write(' '.join(args) + "\n")

        subprocess.call(args)
    finally:
        # delete=False leaves cleanup to us; without this every call would
        # leave one stale file in the temp directory.
        os.remove(set_file.name)

def _get_cores():
    return int(subprocess.check_output('grep core\ id /proc/cpuinfo | wc -l', shell=True))

def pmap(func, arr, cores=None):
    """Apply ``func`` to every item of ``arr`` using several threads.

    The items are split into at most ``cores`` contiguous chunks; each chunk
    is processed by its own thread and the per-chunk results are chained
    back together in the original order.

    Args:
        func: Callable taking a single item.
        arr: Iterable of items.  Items are passed to ``func`` unchanged
            (they are not coerced into numpy arrays).
        cores: Maximum number of threads.  ``None`` (the default) means
            ``_get_cores()``, resolved at call time rather than when the
            module is imported.  Values below 1 are treated as 1.

    Returns:
        An iterator over ``func(item)`` for each item, in input order.
        Empty input yields an empty iterator.

    Raises:
        Exception: The first exception raised by ``func`` in any worker
            thread (in chunk order) is re-raised in the calling thread
            after all threads have finished.
    """
    items = list(arr)
    if not items:
        # A chunk size of 0 would make range() below fail with step == 0.
        return iter(())

    if cores is None:
        cores = _get_cores()
    cores = max(1, int(cores))

    # Ceiling division: the smallest chunk size that fits in `cores` chunks.
    chunk = -(-len(items) // cores)
    jobs = [items[start:start + chunk]
            for start in range(0, len(items), chunk)]

    class PmapThread(threading.Thread):
        """Worker that maps ``func`` over one chunk and stores the outcome."""

        def __init__(self, todo, *args, **kwargs):
            super(PmapThread, self).__init__(*args, **kwargs)
            self.todo = todo
            self.result = None
            self.error = None

        def run(self):
            try:
                # Build the list eagerly: a lazy map() would defer all the
                # work to whoever consumes the result, outside this thread.
                self.result = [func(item) for item in self.todo]
            except Exception as exc:
                # Keep the failure so the caller sees it instead of a
                # confusing "NoneType is not iterable" later on.
                self.error = exc

    threads = [PmapThread(job) for job in jobs]
    for t in threads:
        t.start()

    for t in threads:
        t.join()

    for t in threads:
        if t.error is not None:
            raise t.error

    return itertools.chain.from_iterable(t.result for t in threads)


def get_tfs(network_filename):
    """Return the distinct first-column ids of selected network rows.

    The file is read as a header-less table with the columns ``inta``,
    ``intb``, ``conf`` and ``flag``.  A row is selected when its ``conf``
    differs from 0.6 by less than 10e-4 and its ``flag`` equals 1.
    NOTE(review): presumably this combination marks regulatory edges whose
    source is a transcription factor; confirm against the network format.

    Args:
        network_filename: Path of the network file to read.

    Returns:
        Array of the unique ``inta`` values among the selected rows, in
        order of first appearance.
    """
    column_names = ('inta', 'intb', 'conf', 'flag')
    edges = pd.read_table(network_filename, header=None, names=column_names)

    conf_is_marker = (edges.conf - 0.6).abs() < 10e-4
    flag_is_set = edges.flag == 1

    selected_sources = edges.inta[conf_is_marker & flag_is_set]
    return selected_sources.unique()


# Inputs: the reference network supplies the anchors (via get_tfs) and the
# same terminal list is used for every run.
network_filename = '/home/bnet/atiasnir/data/dror/networks/H_sapiens-ppi-pdna.net'
terminal_filename = 'Tumor_colon.terminals.txt'

tfs = get_tfs(network_filename)

random_networks = glob('random/rnd_*/random_*_human.integrated.net')

# run_anat writes into "<network>_results" but does not create that folder,
# so make sure one exists next to every random network first.
for results_dir in (net + '_results' for net in random_networks):
    if os.path.exists(results_dir):
        continue
    os.mkdir(results_dir)

# One job per (anchor, random network) pair; each job tuple is
# (anchor, network), hence the swapped indices in the call below.
jobs = list(itertools.product(tfs, random_networks))
lala = pmap(lambda job: run_anat(terminal_filename, job[1], job[0]), jobs)
