import numpy as np
import matplotlib
# Select the non-interactive Agg backend (file output only). This is done
# before pylab / pyplot are imported below, as matplotlib requires.
matplotlib.use('Agg')
# Tick-label font size applied to every figure produced by this module.
matplotlib.rc('xtick', labelsize=11.5)
matplotlib.rc('ytick', labelsize=11.5)
# Optional serif font override, currently disabled:
#font = {
#  'family' : 'serif',
#  'serif'  : ['times new roman'],
#  'size'   : 9,
#}
#matplotlib.rc('font', **font)
import pylab
import matplotlib.pyplot as pyplot
import matplotlib.ticker as ticker
from matplotlib.ticker import FuncFormatter
import datetime
from matplotlib import dates
from matplotlib import cm

import os
import sys
import logging

import multiprocessing as mp

from itertools import izip
from collections import Counter
from itertools import chain as it__chain
from itertools import repeat as it__repeat

from scipy.stats.mstats import mquantiles

logger = logging.getLogger(__name__)

# maintain plot count
# Module-wide counter. It is embedded in output file names (e.g. 'fig3.png')
# and, in some functions, used as the pyplot figure number. plot, plot2D,
# plotBars and plotHistograms increment it before drawing.
plotCount = 0

# Resolution (dots per inch) passed to savefig for PNG output.
dpiRes = 200

# wrap for memory leak issues within matplotlib
def plotHistograms(counter_list, userParams):
  """Render a histogram figure in a short-lived child process.

  The drawing itself (__plotHistograms__) runs in a separate process so that
  any memory matplotlib holds on to is released when that process exits.
  Blocks until the child has finished.

  counter_list -- list of (name, data) pairs forwarded to __plotHistograms__.
  userParams   -- dict of option overrides forwarded to __plotHistograms__.
  """
  global plotCount
  # Bump the counter in the parent; with fork-based multiprocessing the
  # child inherits this value and uses it in the output file name.
  # NOTE(review): on spawn-based platforms the child re-imports this module,
  # so its plotCount is not the parent's value -- confirm if that matters.
  plotCount += 1

  # Use the module-level function as the target instead of a nested closure:
  # a closure cannot be pickled, which breaks Process when the platform
  # spawns (rather than forks) child processes.
  pp = mp.Process(
    target=__plotHistograms__,
    args=(counter_list, userParams),
  )
  pp.daemon = True
  pp.start()
  pp.join()

  # An exception in the child only shows up as a non-zero exit code; log it
  # so a failed plot does not go unnoticed.
  if pp.exitcode != 0:
    logger.error(
      'histogram worker for plot {0} exited with code {1}'.format(
        plotCount, pp.exitcode))


def plot(
  line_list,
  userParams,
):
  """Draw one or more line series on a single set of axes and save a PNG.

  line_list  -- iterable of (name, xs, ys) triples; each becomes one line
                of width 2, labelled `name` in the legend.
  userParams -- dict overriding the defaults 'xlabel', 'ylabel', 'fname'
                and 'title' (a falsy title is not drawn).

  Increments plotCount and writes '<fname><plotCount>.png' at dpiRes DPI.
  """
  global plotCount
  plotCount += 1

  options = dict(
    xlabel='x value',
    ylabel='y value',
    fname='fig',
    title=None,
  )
  options.update(userParams)

  fig = pyplot.figure()
  # reserve the bottom 40% of the figure below the axes
  fig.subplots_adjust(bottom=0.4)
  axes = fig.add_subplot(111)

  title = options['title']
  if title:
    pyplot.title(title)
  pyplot.xlabel(options['xlabel'])
  pyplot.ylabel(options['ylabel'])

  for series_name, x_vals, y_vals in line_list:
    axes.plot(x_vals, y_vals, label=series_name, linewidth=2)
  pyplot.legend()

  out_path = '%s%d.png' % (options['fname'], plotCount)
  logger.info('saving {0}'.format(out_path))
  pyplot.savefig(out_path, dpi=dpiRes)

  # pyplot keeps figures alive until they are cleared and closed
  fig.clf()
  pyplot.clf()
  pyplot.close()
  
def plot2D(
  xs,
  ys,
  userParams,
):
  """Scatter-plot ys against xs and save the result as a PNG.

  xs, ys     -- point coordinates, passed straight to Axes.scatter.
  userParams -- dict overriding the defaults 'xlabel', 'ylabel', 'fname'.

  Increments plotCount and writes '<fname><plotCount>.png' at dpiRes DPI.
  """
  global plotCount
  plotCount += 1

  options = dict(xlabel='x value', ylabel='y value', fname='fig')
  options.update(userParams)

  fig = pyplot.figure()
  # reserve the bottom 40% of the figure below the axes
  fig.subplots_adjust(bottom=0.4)
  axes = fig.add_subplot(111)
  pyplot.xlabel(options['xlabel'])
  pyplot.ylabel(options['ylabel'])
  axes.scatter(xs, ys)

  out_path = '%s%d.png' % (options['fname'], plotCount)
  logger.info('saving {0}'.format(out_path))
  pyplot.savefig(out_path, dpi=dpiRes)

  # pyplot keeps figures alive until they are cleared and closed
  fig.clf()
  pyplot.clf()
  pyplot.close()
  
def plotBars(
  cats_list,
  dist_list,
  userParams,
):
  """Draw a grouped bar chart (one group per category) and save it as a PNG.

  cats_list  -- ordered list of category keys; also used as tick labels.
  dist_list  -- list of (name, counter) pairs; counter[cat] is the bar height
                of that series for category `cat` (for a collections.Counter
                a missing category counts as 0; a plain dict raises KeyError).
                Each pair becomes one colored series in the legend.
  userParams -- dict overriding the defaults below. With 'vertical' True the
                category axis is vertical (horizontal bars via barh).
                'width' is the bar thickness and 'margin' the offset of the
                first bar from its category tick, both in category-axis units.

  Increments plotCount and writes '<fname><plotCount>.png' at dpiRes DPI.
  """
  global plotCount
  plotCount += 1

  # Every series is given an explicit color below, so no color cycle is
  # configured. Setting one through pyplot.gca() before the figure exists
  # would create/touch a figure that is never closed, and setting the
  # *default* cycle would alter every later plot in the process.

  params = {
    'xlabel' : 'x value',
    'ylabel' : 'y value',
    'fname' : 'fig',
    'vertical' : False,
    'large' : False,
    'width' : 0.2,
    'margin' : 0.1,
  }
  params.update(userParams)

  vertical = params['vertical']
  width = params['width']
  margin = params['margin']
  numSeries = len(dist_list)

  fig = pyplot.figure()
  fig.subplots_adjust(bottom=0.4)
  ax = fig.add_subplot(111)

  # one bar container per series, used to build the legend
  barPlot_list = []
  names_list = []
  ind = np.arange(len(cats_list))
  for (i, (name, counter)) in enumerate(dist_list):
    names_list.append(name)
    val_list = [counter[cat] for cat in cats_list]
    # shift series i by i bar widths so the series sit side by side
    positions = ind + margin + (width * i)
    color = cm.jet(1. * i / numSeries)
    if vertical:
      bars = ax.barh(
        positions,
        val_list,
        height=width,
        align='center',
        color=color,
      )
    else:
      bars = ax.bar(
        positions,
        val_list,
        width=width,
        align='center',
        color=color,
      )
    barPlot_list.append(bars)

  ax.ticklabel_format(style='plain')

  if vertical:
    ax.set_yticks(ind)
    ax.set_yticklabels(cats_list)
    ax.yaxis.grid(color='gray', linestyle='dashed')
  else:
    ax.set_xticks(ind)
    ax.set_xticklabels(
      cats_list,
      rotation=90,
      ha='left'
    )
    ax.xaxis.grid(color='gray', linestyle='dashed')

  # the first bar of each series stands in for the series in the legend
  barRef_list = [bars[0] for bars in barPlot_list]
  pyplot.legend(
    barRef_list,
    names_list,
    ncol=4, loc='upper center',
    bbox_to_anchor=[0.5, 1.1],
    columnspacing=1.0, labelspacing=0.0,
    handletextpad=0.0, handlelength=1.5,
    fancybox=True, shadow=True,
    prop={'size':6}
  )

  # 'xlabel' always names the category axis and 'ylabel' the value axis,
  # so they swap places when the bars are horizontal
  if vertical:
    pyplot.ylabel(params['xlabel'])
    pyplot.xlabel(params['ylabel'])
  else:
    pyplot.xlabel(params['xlabel'])
    pyplot.ylabel(params['ylabel'])

  ax.yaxis.label.set_size(8)
  ax.xaxis.label.set_size(8)

  fname = '%s%d.png' % (params['fname'], plotCount)
  logger.info('saving {0}'.format(fname))
  pyplot.savefig(fname, dpi=dpiRes)
  # NOTE need these to free up memory
  fig.clf()
  pyplot.clf()
  pyplot.close(fig)

def plotScatters(
  data_list,
  userParams,
):
  """Overlay several scatter series on one set of axes and save the figure.

  data_list  -- list of (name, points_list) pairs; points_list is a sequence
                of (x, y) pairs.
  userParams -- dict overriding the defaults below. Any of 'xmin', 'xmax',
                'ymin', 'ymax' left as None keeps matplotlib's automatic
                limit. A key equal to a series name may map to a dict with
                'alpha', 'color' and 'marker' entries styling that series.

  Increments plotCount and saves '<fname>_<plotCount>.png' (at dpiRes DPI)
  or '<fname>_<plotCount>.svg', depending on params['save'].

  Raises ValueError if params['save'] is neither 'png' nor 'svg'.
  """
  global plotCount
  # Advance the counter like the other plot functions, so consecutive calls
  # with the same fname write distinct files instead of overwriting.
  plotCount += 1

  params = {
    'xlabel' : 'x value',
    'ylabel' : 'y value',
    'xlog' : False,
    'ylog' : False,
    'fname' : 'fig',
    'xmin'  : None,
    'ymin'  : None,
    'xmax'  : None,
    'ymax'  : None,
    'save' : 'png',
    'title' : None,
  }
  params.update(userParams)

  # Check the output format before any figure exists, so a bad value fails
  # fast instead of leaving an unclosed figure behind.
  saveFormat = params['save']
  if saveFormat not in ('png', 'svg'):
    raise ValueError('unsupported save format: {0!r}'.format(saveFormat))

  # Always start a fresh figure: pyplot.figure(num) would return an
  # already-open figure that has the same number.
  fig = pyplot.figure()
  sub = fig.add_subplot(111)
  # show only the left and bottom spines
  sub.spines['top'].set_visible(False)
  sub.spines['right'].set_visible(False)
  sub.xaxis.set_ticks_position('bottom')
  sub.yaxis.set_ticks_position('left')

  pyplot.xlabel(params['xlabel'])
  pyplot.ylabel(params['ylabel'])
  if params['title']:
    pyplot.title(params['title'])

  # Default colors are evenly spaced grey levels ('0.0' is black). max(..., 1)
  # avoids dividing by zero when data_list is empty.
  cstep = 1.0 / max(len(data_list), 1)
  for (i, (name, points_list)) in enumerate(data_list):
    xs = [p[0] for p in points_list]
    ys = [p[1] for p in points_list]
    kwargs = {
      'alpha' : 0.5,
      'c' : str(cstep * i),
      'marker' : 'o',
    }
    # per-series style override keyed by the series name
    if name in params:
      style = params[name]
      kwargs['alpha']  = style['alpha']
      kwargs['c']      = style['color']
      kwargs['marker'] = style['marker']
    sub.scatter(
      xs,
      ys,
      label=name,
      **kwargs
    )

  # apply explicit limits, keeping matplotlib's automatic value otherwise
  xmin, xmax, ymin, ymax = pyplot.axis()
  if params['xmin'] is not None: xmin = params['xmin']
  if params['xmax'] is not None: xmax = params['xmax']
  if params['ymin'] is not None: ymin = params['ymin']
  if params['ymax'] is not None: ymax = params['ymax']
  pyplot.axis((xmin, xmax, ymin, ymax))

  if params['xlog']:
    sub.set_xscale('log')
  if params['ylog']:
    sub.set_yscale('log')

  # loc=2 is 'upper left'
  pyplot.legend(loc=2, labelspacing=0.5, prop={'size':10})

  fname = '{0}_{1}.{2}'.format(params['fname'], plotCount, saveFormat)
  if saveFormat == 'png':
    pyplot.savefig(fname, dpi=dpiRes, bbox_inches='tight')
  else:
    pyplot.savefig(fname, bbox_inches='tight')

  logger.info('saving {0}'.format(fname))
  # NOTE need these to free up memory
  fig.clf()
  pyplot.clf()
  pyplot.close(fig)

def __plotHistograms__(
  counter_list,
  userParams,
):
  """Draw overlaid histograms for several data series and save the figure.

  Normally run in a child process by plotHistograms, which is also what
  advances plotCount; this function only reads it for the file name.

  counter_list -- list of (name, data) pairs. data is either a mapping
                  (dict/Counter) of value -> integer frequency, or a plain
                  sequence of values. Empty entries are skipped.
  userParams   -- dict overriding the defaults below. 'numBins' is passed to
                  linspace/logspace as the number of bin EDGES, so it yields
                  numBins - 1 bins. A key equal to a series name may map to a
                  dict with 'alpha' and 'color' entries styling that series.

  Saves '<fname>_<plotCount>.png' (at dpiRes DPI) or
  '<fname>_<plotCount>.svg', depending on params['save'].

  Raises ValueError if params['save'] is neither 'png' nor 'svg', or if every
  series is empty.
  """
  params = {
    'xlabel' : 'x value',
    'pct' : False,
    'ylabel' : 'freq',
    'fname' : 'fig',
    'xmin'  : None,
    'xlog': False,
    'ylog': False,
    'numBins' : 10,
    'save' : 'png',
    'normed' : False,
    'title' : None,
    'rescale' : False,
  }
  params.update(userParams)
  numBins = params['numBins']

  # Check the output format before any figure exists.
  saveFormat = params['save']
  if saveFormat not in ('png', 'svg'):
    raise ValueError('unsupported save format: {0!r}'.format(saveFormat))

  def toValueList(data):
    # Expand a value->frequency mapping into a flat list in which each value
    # is repeated `frequency` times; any other sequence is used as-is.
    if isinstance(data, dict):
      vals_list = []
      for (val, freq) in data.items():
        vals_list.extend(freq * [val])
      return vals_list
    return data

  # skip over empty series
  nonEmpty = [(name, data) for (name, data) in counter_list if len(data) > 0]
  if not nonEmpty:
    raise ValueError('no non-empty data series to plot')

  # Shared bin edges spanning every series. Iterating a mapping yields its
  # keys, so min()/max() see the observed values in both input forms.
  lo = min(min(data) for (_, data) in nonEmpty)
  hi = max(max(data) for (_, data) in nonEmpty)
  if params['xlog']:
    bins = np.logspace(np.log10(lo), np.log10(hi), numBins)
  else:
    bins = np.linspace(lo, hi, numBins)

  # Always start a fresh figure: pyplot.figure(num) would return an
  # already-open figure that has the same number.
  fig = pyplot.figure(figsize=(4,4))
  sub = fig.add_subplot(111)
  # show only the left and bottom spines
  sub.spines['top'].set_visible(False)
  sub.spines['right'].set_visible(False)
  sub.xaxis.set_ticks_position('bottom')
  sub.yaxis.set_ticks_position('left')

  pyplot.xlabel(params['xlabel'])
  pyplot.ylabel(params['ylabel'])
  if params['title']:
    pyplot.title(params['title'])

  for (name, data) in nonEmpty:
    kwargs = {
      'alpha' : 0.5,
    }
    # per-series style override keyed by the series name
    if name in params:
      kwargs['alpha'] = params[name]['alpha']
      kwargs['color'] = params[name]['color']
    sub.hist(
      toValueList(data),
      bins,
      log=params['ylog'],
      normed=params['normed'],
      linewidth=1,
      label=name,
      histtype='bar',
      **kwargs
    )

  if params['xlog']:
    sub.set_xscale('log')

  # compare against None so an explicit xmin of 0 is honoured
  if params['xmin'] is not None:
    pyplot.xlim(xmin=params['xmin'])

  if params['pct']:
    def to_percent(value, position):
      # Render a fraction as a whole-number percentage, e.g. 0.25 -> '25%'.
      s = str(int(100 * value))
      # The percent symbol needs escaping in latex
      if matplotlib.rcParams['text.usetex']:
        return s + r'$\%$'
      return s + '%'

    sub.xaxis.set_major_formatter(FuncFormatter(to_percent))

  # loc=2 is 'upper left'
  pyplot.legend(loc=2, labelspacing=0.5, prop={'size':10})

  for item in ([sub.xaxis.label, sub.yaxis.label] + sub.get_xticklabels() + sub.get_yticklabels()):
    item.set_fontsize(10)

  fname = '{0}_{1}.{2}'.format(params['fname'], plotCount, saveFormat)
  if saveFormat == 'png':
    pyplot.savefig(fname, dpi=dpiRes, bbox_inches='tight')
  else:
    pyplot.savefig(fname, bbox_inches='tight')

  logger.info('saving {0}'.format(fname))
  # NOTE need these to free up memory
  fig.clf()
  pyplot.clf()
  pyplot.close(fig)

def plot2DKDE(xs, ys, kde, userParams_map):
  """Plot a 2-D density estimate as a heat map with the samples overlaid.

  xs, ys         -- 1-D sequences of sample coordinates (converted with
                    np.asarray, so plain lists work too).
  kde            -- object with a pdf_samples(points) method. It is called with
                    a (2, 10000) array of grid points; its result is
                    transposed and reshaped to the 100 x 100 grid (presumably
                    one density value per point -- confirm with the caller).
  userParams_map -- dict overriding the defaults 'xlabel', 'ylabel', 'fname'
                    and 'title'.

  Writes '<fname>.png' (no plotCount suffix) at dpiRes DPI.
  """
  params = {
    'xlabel' : 'x',
    'ylabel' : 'y',
    'fname' : 'fig',
    'title' : None,
  }
  params.update(userParams_map)

  xs = np.asarray(xs)
  ys = np.asarray(ys)
  xmin, xmax = xs.min(), xs.max()
  ymin, ymax = ys.min(), ys.max()

  # evaluate the density on a regular 100 x 100 grid spanning the data
  X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
  grid = np.vstack([X.ravel(), Y.ravel()])
  scores = kde.pdf_samples(grid)
  Z = np.reshape(scores.T, X.shape)

  fig = pyplot.figure()
  ax = fig.add_subplot(111)
  # Z is indexed [x, y]; rotating by 90 degrees puts x along the image
  # columns with y increasing upwards, matching `extent`
  ax.imshow(
    np.rot90(Z),
    cmap=pyplot.cm.gist_earth_r,
    extent=[xmin, xmax, ymin, ymax],
    aspect='auto',
  )
  ax.plot(xs, ys, 'k.', markersize=2)
  ax.set_xlim([xmin, xmax])
  ax.set_ylim([ymin, ymax])

  pyplot.xlabel(params['xlabel'])
  pyplot.ylabel(params['ylabel'])
  if params['title']:
    pyplot.title(params['title'])

  fname = '{0}.png'.format(params['fname'])
  logger.info('saving {0}'.format(fname))
  pyplot.savefig(fname, dpi=dpiRes)
  # free the figure's memory
  fig.clf()
  pyplot.clf()
  pyplot.close(fig)

def plot1DKDE(xs, kde, userParams_map):
  """Plot a 1-D density estimate over a normalized histogram of the samples.

  xs             -- 1-D sequence of samples (converted with np.asarray).
  kde            -- object with a pdf_samples(points) method; it is called
                    with 1000 evenly spaced points spanning [min(xs), max(xs)]
                    and its result is plotted against them.
  userParams_map -- dict overriding the defaults 'xlabel' ('x'), 'fname'
                    ('fig') and 'title' (None, i.e. no title).

  Writes '<fname>.png' (no plotCount suffix) at dpiRes DPI.
  """
  params = {
    'xlabel' : 'x',
    'fname' : 'fig',
    'title' : None,
  }
  params.update(userParams_map)

  xs = np.asarray(xs)
  xmin = xs.min()
  xmax = xs.max()

  grid = np.linspace(xmin, xmax, 1000)
  scores = kde.pdf_samples(grid)

  fig = pyplot.figure()
  ax = fig.add_subplot(111)
  # estimated density
  ax.plot(grid, scores)
  # histogram of the samples, normalized so it is on the density's scale
  ax.hist(xs, 30, fc='gray', histtype='stepfilled', alpha=0.3, normed=True)
  ax.set_xlim([xmin, xmax])

  pyplot.xlabel(params['xlabel'])
  if params['title']:
    pyplot.title(params['title'])

  fname = '{0}.png'.format(params['fname'])
  logger.info('saving {0}'.format(fname))
  pyplot.savefig(fname, dpi=dpiRes)
  # free the figure's memory
  fig.clf()
  pyplot.clf()
  pyplot.close(fig)

