import sys
import numpy as np
import logging
from collections import Counter

from ..common.util import loadPickle, writePickle 
from ..common.math import getKDE
from ..common import plotter
from ..common import util

# module-level logger, named after this module via __name__
logger = logging.getLogger(__name__)

#--------------------------------------------------------------------------
# cloud model
#--------------------------------------------------------------------------
class CloudModel(object):
  """
  Density model over cloud features, conditioned on barcode type (btype).

  For every btype seen during training the model keeps:
    - the fraction of training clouds having that btype,
    - the per-feature mean and standard deviation used to standardize the
      feature vector built by __getFeatVec__,
    - a 2D density estimate (built by getKDE) over the standardized
      features.
  """

  # barcode types in use; these are exactly the labels getBtype() returns
  #   'bothbar' : lbarValid and rbarValid
  #   'pbar'    : exactly one of lbarValid / rbarValid
  #   'nobar'   : neither
  # FIXME old barcode types.  A four label scheme ('1-bar', '2-bar',
  # '0-bar', plus 'm-bar' for clouds with numMixedBar > 0) used to be
  # defined in this class as well, but it was shadowed by these definitions
  # and never took effect.
  allBtypes = [
    'bothbar',
    'pbar',
    'nobar',
  ]

  # feature name -> index used when picking values out of a feature vector
  # NOTE(review): this order (numReads, size) is the reverse of the column
  # order produced by __getFeatVec__ (size, numReads).  train() and
  # getProb() both index through this map, so the density is fit and
  # evaluated with the same axis order, but the axes named 'size' and
  # 'numReads' in those methods actually hold the other feature.  Left
  # untouched so that previously pickled models keep scoring identically;
  # confirm before changing.
  featIdx_map = {
    'numReads' : 0,
    'size' : 1,
  }

  def __init__(self):
    """Create an empty, untrained model."""

    # barcode type frequencies
    self.btypeFreq_map = {}

    # densities of normalized features
    self.btypeF_map = {}

    # standardization params
    self.btypeMu_map = {}
    self.btypeStd_map = {}

    # saved computed values, keyed on (btype, size, numReads)
    self.probCache_map = {}

  #--------------------------------------------------------------------------
  # scoring
  #--------------------------------------------------------------------------
  def getProb(self, cloudFeats):
    """
    Score a cloud: btype frequency times the density at its features.

    cloudFeats must expose size, numReads, lbarValid and rbarValid.  The
    features are standardized with the mean / std stored for the cloud's
    btype before the density is evaluated.  Results are memoized in
    probCache_map keyed on (btype, raw size, raw numReads).

    Returns the product of the btype frequency and the value returned by
    the density's pdf(); a result equal to 0 is replaced by
    sys.float_info.min.  A btype missing from the model's maps fails on
    the map lookup.
    """

    btype = CloudModel.getBtype(cloudFeats)
    k = (btype, cloudFeats.size, cloudFeats.numReads)
    if k in self.probCache_map:
      return self.probCache_map[k]

    # standardize size, numReads features
    mu  = self.btypeMu_map[btype][:2]
    std = self.btypeStd_map[btype][:2]
    stdfeat_vec = (self.__getFeatVec__(cloudFeats) - mu) / std

    # picked through featIdx_map exactly as train() does when it builds
    # the density input, so both sides agree on the axis order (see the
    # note on featIdx_map regarding the names)
    size     = stdfeat_vec[CloudModel.featIdx_map['size']]
    numReads = stdfeat_vec[CloudModel.featIdx_map['numReads']]

    # get densities conditioned on barcode type
    density = self.btypeF_map[btype]['P_c (size, numReads)']
    prob = self.btypeFreq_map[btype] * density.pdf([size, numReads])

    # underflow case: never hand back an exact zero
    if prob == 0:
      prob = sys.float_info.min

    # save in cache
    self.probCache_map[k] = prob
    return prob

  #--------------------------------------------------------------------------
  # feature helpers
  #--------------------------------------------------------------------------
  @staticmethod
  def __getFeatVec__(cloudFeats):
    """Return the cloud's features as np.array([size, numReads])."""
    return np.array([
      cloudFeats.size,
      cloudFeats.numReads,
    ])

  @staticmethod
  def __getVidx__(k):
    """
    Return the featIdx_map index for feature name k.

    k must be 'numReads' or 'size' (the keys of featIdx_map); anything
    else trips the assertion.
    """
    assert k in CloudModel.featIdx_map
    return CloudModel.featIdx_map[k]

  # FIXME old barcode types
  @staticmethod
  def getBtype(cloudFeats):
    """
    Classify a cloud by which of its end barcodes are valid.

    Reads cloudFeats.lbarValid and cloudFeats.rbarValid and returns
    'bothbar' when both are set, 'pbar' when exactly one is set and
    'nobar' when neither is.
    """
    lbarValid = cloudFeats.lbarValid
    rbarValid = cloudFeats.rbarValid
    if lbarValid and rbarValid:
      return 'bothbar'
    elif lbarValid ^ rbarValid:
      return 'pbar'
    else:
      assert not (lbarValid or rbarValid)
      return 'nobar'

  #--------------------------------------------------------------------------
  # training
  #--------------------------------------------------------------------------
  def train(self, cloudFeats_list):
    """
    Fit the per-btype standardization, density and frequency.

    cloudFeats_list is a sequence of objects accepted by getBtype() and
    __getFeatVec__().  For every btype in allBtypes that has at least one
    cloud, the btype's feature rows are standardized column-wise, a 2D
    density is fit with getKDE, and the mean, std, density and frequency
    are stored on the model.  Btypes without clouds are skipped with a
    warning and get no entries.

    Side effects: ensures a 'plots' directory via util.mkdir_p, runs the
    fitting inside util.cd('plots'), and calls plotter.plot2DKDE once per
    fitted btype.
    """
    logger.info('training cloud model on {0} clouds'.format(
      len(cloudFeats_list)))

    # classify every cloud once, then derive one boolean mask per btype
    # (comprehensions rather than map/lambda so the result is a real
    # per-cloud mask on both python 2 and python 3)
    btype_list = [CloudModel.getBtype(cf) for cf in cloudFeats_list]
    maskVec_list = [
      (btype, np.array([b == btype for b in btype_list], dtype=bool))
      for btype in CloudModel.allBtypes
    ]

    plotsDir_path = 'plots'
    util.mkdir_p(plotsDir_path)
    with util.cd(plotsDir_path):

      numReadsIdx = CloudModel.featIdx_map['numReads']
      sizeIdx     = CloudModel.featIdx_map['size']

      # standardize features and estimate densities
      Mall = np.array([self.__getFeatVec__(f) for f in cloudFeats_list])
      for btype, mask_vec in maskVec_list:
        if np.count_nonzero(mask_vec) == 0:
          logger.warning('empty btype {0} during training'.format(btype))
          continue

        # standardize all features for each btype
        raw = Mall[mask_vec,:]
        mu = np.mean(raw, axis=0)
        std = np.std(raw, axis=0)
        M = (raw - mu) / std

        # rows of the density input are chosen through featIdx_map, the
        # same way getProb() builds its query point
        data = np.vstack([M[:,sizeIdx], M[:,numReadsIdx]])
        f_map = {}
        f_map['P_c (size, numReads)'] = getKDE(data, bwselect=False)

        # save btype densities
        self.btypeF_map[btype] = f_map

        # save btype normalizations
        self.btypeMu_map[btype] = mu
        self.btypeStd_map[btype] = std

        # save btype freq
        self.btypeFreq_map[btype] = 1. * len(M) / len(Mall)

        plotter.plot2DKDE(
          M[:,sizeIdx],
          M[:,numReadsIdx],
          f_map['P_c (size, numReads)'],
          {
            'ylabel' : 'size',
            'xlabel' : 'num reads',
            'fname'  : '{0}__size_nr_kde'.format(btype),
            'title'  : '{0} kde'.format(btype),
          }
        )

  #--------------------------------------------------------------------------
  # saving and loading models from pickle
  #--------------------------------------------------------------------------
  @classmethod
  def load(cls, info_map):
    """
    Build a model from a mapping shaped like getPickleData()'s result.

    info_map must contain 'btypeFreq_map', 'btypeF_map', 'btypeMu_map'
    and 'btypeStd_map'.  The probability cache is not part of info_map
    and starts out empty.
    """
    model = cls()
    model.btypeFreq_map = info_map['btypeFreq_map']
    model.btypeF_map    = info_map['btypeF_map'   ]
    model.btypeMu_map   = info_map['btypeMu_map'  ]
    model.btypeStd_map  = info_map['btypeStd_map' ]
    return model

  def getPickleData(self):
    """
    Return the trained state as a dict that load() accepts.

    The values are the model's own maps, not copies; the probability
    cache is left out.
    """
    info_map = {
      'btypeFreq_map' : self.btypeFreq_map,
      'btypeF_map'    : self.btypeF_map,
      'btypeMu_map'   : self.btypeMu_map,
      'btypeStd_map'  : self.btypeStd_map,
    }
    return info_map

