import os
import sys
import random
import json
import numpy as np
from itertools import tee, izip, combinations, chain
import datetime
import cPickle as pickle
#import pickle
import re
import tempfile
import psutil

import subprocess

import meminfo
import commands
import logging

# Module-level logger; memusage() and jsonLoadASCII() below report through it.
logger = logging.getLogger(__name__)

#--------------------------------------------------------------------------
# os
#--------------------------------------------------------------------------
def memusage(tag=None):
  """Log the current process's memory usage at INFO level.

  Emits one record on the module logger of the form
  ``'<tag> rss:<N>mb, vmem:<N>mb'`` where the numbers are psutil's resident
  set size (``rss``) and virtual memory size (``vms``) in MiB.

  Args:
    tag: optional label prefixed to the line to identify the call site;
      ``None`` is rendered as an empty string.
  """
  def inmb(nbytes):
    # bytes -> MiB (2**20 bytes), as a float
    return 1. * nbytes / 2 ** 20

  # Named ``mem`` rather than ``meminfo`` so it does not shadow the imported
  # ``meminfo`` module inside this function.
  mem = psutil.Process(os.getpid()).memory_info()
  logger.info('{0} rss:{1:4.4f}mb, vmem:{2:4.4f}mb'.format(
    tag if tag is not None else '',
    inmb(mem.rss),
    inmb(mem.vms),
  ))

def mkdir_p(dir):
  """Create directory ``dir`` and any missing parents, like ``mkdir -p``.

  Succeeds silently if ``dir`` already exists as a directory.

  Raises:
    OSError: if the directory cannot be created, or ``dir`` exists but is
      not a directory.
  """
  # Try first and inspect afterwards: checking ``isdir`` before ``makedirs``
  # races with any other process creating the same directory in between.
  try:
    os.makedirs(dir)
  except OSError:
    if not os.path.isdir(dir):
      raise

def mktmpdir(prefix='tmp'):
  """Create a new, uniquely named directory inside the current working directory.

  Args:
    prefix: leading part of the generated directory name.

  Returns:
    The path of the new directory, as returned by ``tempfile.mkdtemp``.
    The caller is responsible for removing it.
  """
  return tempfile.mkdtemp(prefix=prefix, dir='.')

class cd(object):
  """Context manager that temporarily changes the working directory.

  Usage::

    with cd('/some/dir') as ctx:
      ...  # cwd is /some/dir; ctx.savedPath is the previous cwd
    # cwd restored here, even if the block raised

  Attributes:
    newPath: directory to switch to on entry.
    savedPath: working directory captured at entry time (set by __enter__).
  """
  def __init__(self, newPath):
    self.newPath = newPath

  def __enter__(self):
    # Capture cwd at entry (not construction) so we restore whatever was
    # current when the ``with`` block began.
    self.savedPath = os.getcwd()
    os.chdir(self.newPath)
    return self

  def __exit__(self, etype, value, traceback):
    os.chdir(self.savedPath)
    # Falsy return: exceptions raised inside the block propagate.
    return False

def concatFiles(input_list, output_path):
  """Concatenate the files in ``input_list``, in order, into ``output_path``.

  ``output_path`` is created or truncated. Contents are copied as raw bytes,
  so line endings and non-UTF-8 data are preserved exactly.

  Args:
    input_list: iterable of paths to read, in output order.
    output_path: path of the file to write.
  """
  import shutil

  # Binary mode: avoids newline translation and text decoding, and lets
  # copyfileobj move data in large chunks instead of line by line.
  with open(output_path, 'wb') as outf:
    for path in input_list:
      with open(path, 'rb') as inf:
        shutil.copyfileobj(inf, outf)

#--------------------------------------------------------------------------
# iterator manipulation
#--------------------------------------------------------------------------
def peek(gen):
  """Look at the first item of ``gen`` without losing it.

  Args:
    gen: any iterable or iterator. An iterator is advanced by one item.

  Returns:
    ``(first, it)`` where ``first`` is the first item and ``it`` is an
    iterator that yields ``first`` followed by the remaining items.

  Raises:
    StopIteration: if ``gen`` is empty.
  """
  # iter() accepts plain iterables too; the builtin next() works on both
  # Python 2 and 3, unlike the Py2-only ``gen.next()`` method.
  it = iter(gen)
  first = next(it)
  return (first, pushback(first, it))

def pushback(el, gen):
  """Return an iterator yielding ``el`` and then everything remaining in ``gen``."""
  return chain((el,), gen)

#--------------------------------------------------------------------------
# bed manipulation (FIXME put this somewhere else)
#--------------------------------------------------------------------------
def writeBedFile(
  regions_map,
  out_fname,
  score_func=len,
  append=False,
):
  """Write regions as tab-separated BED-style lines (per the original comment, realigned regions).

  Args:
    regions_map: mapping of chromosome name -> object whose ``getregions()``
      yields ``(begin, end, val_list)`` triples.
    out_fname: path of the output file.
    score_func: callable applied to each ``val_list`` to produce the score
      column (defaults to ``len``).
    append: if True, append to ``out_fname`` without a header line;
      otherwise truncate it and write the ``#chrom...`` header first.

  Chromosomes are written in sorted order with surrounding whitespace
  stripped; the name column is always ``-``.
  """
  header = '#chrom\tchromStart\tchromEnd\tname\tscore\n'
  line_fmt = '{0}\t{1}\t{2}\t-\t{3}\n'
  with open(out_fname, 'a' if append else 'w') as handle:
    if not append:
      handle.write(header)
    for chrom in sorted(regions_map.keys()):
      for begin, end, values in regions_map[chrom].getregions():
        row = line_fmt.format(chrom.strip(), begin, end, score_func(values))
        handle.write(row)

#--------------------------------------------------------------------------
# pickle
#--------------------------------------------------------------------------
def writePickle(fname, obj):
  """Serialize ``obj`` to the file ``fname`` using the highest pickle protocol.

  The file is created or truncated. Read it back with ``loadPickle``.
  """
  # Binary mode is required: protocols >= 1 emit raw bytes, which text mode
  # corrupts on Windows and rejects outright on Python 3. ``with`` closes the
  # file even if pickling raises.
  with open(fname, 'wb') as f:
    pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def loadPickle(fname):
  """Load and return the object pickled in the file ``fname``.

  SECURITY: unpickling can execute arbitrary code. Only call this on files
  produced by a trusted source (e.g. ``writePickle``).
  """
  # Binary mode matches the binary protocol writePickle uses; ``with``
  # guarantees the handle is closed even if unpickling fails.
  with open(fname, 'rb') as f:
    return pickle.load(f)

#--------------------------------------------------------------------------
# json
#--------------------------------------------------------------------------
def jsonLoadASCII(path):
  """Load the JSON file at ``path``, returning native byte strings instead of unicode.

  On Python 2 ``json`` returns every string as ``unicode``. This helper
  converts every string to a UTF-8-encoded native ``str``: dict keys, dict
  values, list elements, and a top-level string. On Python 3 native ``str``
  is already the text type, so strings are returned unchanged.

  Args:
    path: filesystem path of the JSON document.

  Returns:
    The decoded JSON value with all strings converted as described above.

  If the file cannot be opened or parsed, the error is logged and the
  process exits with status 1.
  """
  # ``type(u'')`` is ``unicode`` on Python 2 and ``str`` on Python 3, so the
  # encode branch below only fires on Python 2.
  text_type = type(u'')

  def _convert(data):
    # Single recursive pass over the fully parsed value. Unlike an
    # object_hook, this also reaches top-level lists/strings and does not
    # re-process nested dicts.
    if isinstance(data, dict):
      return dict((_convert(k), _convert(v)) for k, v in data.items())
    if isinstance(data, list):
      return [_convert(item) for item in data]
    if isinstance(data, text_type) and not isinstance(data, str):
      return data.encode('utf-8')
    return data

  try:
    with open(path) as f:
      return _convert(json.load(f))
  except (IOError, OSError, ValueError) as e:
    # IOError/OSError: cannot open/read; ValueError: malformed JSON
    # (including bad UTF-8, since UnicodeDecodeError subclasses ValueError).
    logger.error('json loading error on file {0}'.format(path))
    logger.error(str(e))
    sys.exit(1)

#--------------------------------------------------------------------------
# random
#--------------------------------------------------------------------------

def chooseWeightedDict(d):
  """Randomly pick one ``(key, weight)`` pair from ``d``, proportional to weight.

  Args:
    d: mapping of key -> numeric weight. Weights are assumed non-negative;
      behavior with negative weights is not meaningful.

  Returns:
    A ``(key, weight)`` tuple from ``d``'s items. Entries with weight 0 are
    never chosen.

  Raises:
    ValueError: if ``d`` is empty or its total weight is not positive.
  """
  if not d:
    raise ValueError('chooseWeightedDict: empty dict')
  # list() so positional access works on Python 3, where items() is a view.
  items = list(d.items())
  total = sum(score for _, score in items)
  if not total > 0:
    raise ValueError('chooseWeightedDict: total weight must be positive')
  key = random.uniform(0, total)
  runningTotal = 0.0
  for item in items:
    runningTotal += item[1]
    if runningTotal > key:
      return item
  # random.uniform(0, total) can return exactly ``total``, and float rounding
  # can leave runningTotal just short of it; fall back to the last entry that
  # could legitimately be chosen instead of failing.
  for item in reversed(items):
    if item[1] > 0:
      return item

def enum(*sequential, **named):
  """Create a lightweight enum class whose attributes are integer constants.

  Intended for fast integer indexing into arrays (the original note says
  this is much faster than indexing based on strings).

  Positional names are numbered 0, 1, 2, ... in order. Keyword arguments
  supply explicit values and override positional names with the same
  spelling. The class also gets a ``reverse`` attribute mapping each value
  back to its name.

  Example::

    Color = enum('RED', 'GREEN', BLUE=7)
    Color.GREEN          # 1
    Color.reverse[7]     # 'BLUE'

  Raises:
    ValueError: if a member is named ``reverse``. That name would otherwise
      be silently overwritten by the reverse map.
  """
  enums = dict((name, index) for index, name in enumerate(sequential))
  enums.update(named)
  if 'reverse' in enums:
    raise ValueError("enum(): 'reverse' is reserved and cannot be a member name")
  enums['reverse'] = dict((value, name) for name, value in enums.items())
  return type('Enum', (), enums)

def pairwise(iterable):
  """Iterate over overlapping consecutive pairs of ``iterable``.

  ``pairwise([a, b, c])`` yields ``(a, b)`` and then ``(b, c)``. Inputs with
  fewer than two items yield nothing. The result is lazy, but the first
  element of ``iterable`` is pulled as soon as this function is called.
  """
  current, ahead = tee(iterable)
  # Shift the look-ahead copy by one position; the default keeps empty input
  # from raising StopIteration here.
  next(ahead, None)
  return izip(current, ahead)

