/*
 * Copyright (c) 2014 Daniel Kang
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

#include <float.h>
#include <math.h>

#include <algorithm>
#include <array>
#include <random>
#include <thread>
#include <utility>
#include <vector>

#include "./GradArgs.h"
#include "../util/util.h"

using std::array;
using std::pair;
using std::thread;
using std::vector;

// Adds the ygrad rows selected by the genome codes at positions [s, e) into
// explx0.  explx0 is indexed relative to s; position i writes to the window
// explx0[i - K, i - K + (2*K / RESOL) * RESOL), so the caller must supply
// writable padding on both sides of the range.
void GradArgs::sum_explx0_grad_worker(ftype *explx0, const int64_t s, const int64_t e) {
  const int64_t num_pos = e - s;
  const int64_t num_bins = 2*K / RESOL;
  const gtype *seq = genome + s;

  for (int64_t pos = 0; pos < num_pos; pos++) {
    const int code = seq[pos];
    // Negative genome codes contribute nothing.
    if (code < 0) continue;

    const int row = code * KALLOC;
    for (int64_t bin = 0; bin < num_bins; bin++) {
      const ftype weight = ygrad[row + bin];
      // Each stored value is replicated over RESOL consecutive outputs.
      const int64_t first = pos - K + bin * RESOL;
      for (int64_t r = 0; r < RESOL; r++)
        explx0[first + r] += weight;
    }
  }
}
// Runs sum_explx0_grad_worker on NUM_THREADS threads over the range that
// starts K positions before s.  The range is split into 2 * NUM_THREADS chunks
// of equal length; the even-numbered chunks are processed in the first pass
// and the odd-numbered ones in the second, with every thread joined between
// the passes.
// NOTE(review): presumably the alternation keeps concurrently processed chunks
// far enough apart that their output windows do not overlap -- confirm that
// the chunk length is always large enough relative to K.
void GradArgs::sum_explx0_grad(ftype *buf, const int64_t s, const int64_t e) {
  const int64_t total_len = (e - s) + (3*K);
  const int64_t chunk_len = total_len / NUM_THREADS / 2;
  const size_t num_workers = NUM_THREADS;

  for (int pass = 0; pass < 2; pass++) {
    vector<thread> pool;
    pool.reserve(num_workers);

    for (size_t w = 0; w < num_workers; w++) {
      // Chunk handled by worker w in this pass (even in pass 0, odd in pass 1).
      const int64_t chunk = static_cast<int64_t>(2*w) + pass;
      const int64_t lo = -K + chunk_len * chunk;
      const int64_t hi = -K + chunk_len * (chunk + 1);
      pool.emplace_back(&GradArgs::sum_explx0_grad_worker, this, buf + lo,
                        s + lo, s + hi);
    }

    for (thread &th : pool)
      th.join();
  }
}
// Accumulates into buf the contributions for positions [s, e).
// First calls ascending_k with yall as input buffer and ygrad as the other
// buffer (presumably this fills ygrad from yall -- confirm against
// ascending_k), then lets sum_explx0_grad add ygrad into buf in parallel.
void GradArgs::sum_explx0(ftype *buf, const int64_t s, const int64_t e) {
  auto params = yall.data();
  auto expanded = ygrad.data();
  ascending_k(KBIG, 1, params, expanded);

  sum_explx0_grad(buf, s, e);
}


// Converts the summed values in explx0[0, e - s) in place into
// counts - exp(value - x0), where x0[1] is used at positions whose genome code
// is negative and x0[0] everywhere else.  Afterwards the K2 elements directly
// before and directly after that range are set to zero.
void GradArgs::slave_explx0(ftype *explx0, const int64_t s, const int64_t e) {
  const int64_t num_pos = e - s;
  const gtype *seq = genome + s;
  const ftype *observed = counts + s;

  #pragma omp parallel for
  for (int64_t pos = 0; pos < num_pos; pos++) {
    const int which = (seq[pos] < 0) ? 1 : 0;
    explx0[pos] = observed[pos] - exp(explx0[pos] - x0[which]);
  }

  // Clear both K2-wide margins around the processed range.
  for (int64_t off = 1; off <= K2; off++) {
    explx0[-off] = 0;
    explx0[num_pos + off - 1] = 0;
  }
}


// Compare-exchange on the local array `d`: swaps d[x] and d[y] when
// d[y] < d[x].
#define SWAP(x, y) if (d[y] < d[x]) { std::swap(d[x], d[y]); }

// Sorts the five values d[0..4] in ascending order using a fixed sequence of
// nine compare-exchange steps (no data-dependent loop).
static __inline__ void sort5(ftype *d) {
  SWAP(0, 1); SWAP(3, 4);
  SWAP(2, 4);
  SWAP(2, 3);
  SWAP(0, 3);
  SWAP(0, 2); SWAP(1, 4);
  SWAP(1, 3);
  SWAP(1, 2);
}
// Sorts the four values d[0..3] in ascending order using a fixed sequence of
// five compare-exchange steps.
static __inline__ void sort4(ftype *d) {
  // Order each pair.
  SWAP(0, 1); SWAP(2, 3);
  // Move the overall minimum to d[0] and the overall maximum to d[3].
  SWAP(0, 2); SWAP(1, 3);
  // Order the two middle values.
  SWAP(1, 2);
}
// Shifts weight between adjacent k-mer levels of yk, from level KBIG - 1 down
// to level 1.  For each k-mer of length k and each of its KALLOC slots, the
// median m of the four level-(k+1) entries (kmer + b << 2k, b = 0..3) and the
// negated level-k entry is added to the level-k entry and subtracted from the
// four level-(k+1) entries; m is the shift that minimizes the summed absolute
// values of those five entries.
// Finally the median of the four 1-mer entries at slot K (mean of the two
// middle values) is subtracted from those entries and returned.
double l1_minimizer(ftype **yk) {
  for (int k = KBIG - 1; k >= 1; k--) {
    const int num_kmers = 1 << (2*k);
    ftype *const parent = yk[k];
    ftype *const children = yk[k+1];

    #pragma omp parallel for if (k > 4)
    for (int kmer = 0; kmer < num_kmers; kmer++) {
      for (int j = 0; j < KALLOC; j++) {
        ftype &up = parent[kmer * KALLOC + j];
        ftype *down[4];
        ftype vals[5];
        for (int b = 0; b < 4; b++) {
          down[b] = &children[(kmer + (b << (2*k))) * KALLOC + j];
          vals[b] = *down[b];
        }
        vals[4] = -up;

        sort5(vals);
        const ftype median = vals[2];

        up += median;
        for (int b = 0; b < 4; b++)
          *down[b] -= median;
      }
    }
  }

  // Centre the four 1-mer entries at slot K on their median.
  ftype *const ones = yk[1];
  ftype sorted_ones[4];
  for (int b = 0; b < 4; b++)
    sorted_ones[b] = ones[b * KALLOC + K];
  sort4(sorted_ones);

  const ftype median = (sorted_ones[1] + sorted_ones[2]) / 2;
  for (int b = 0; b < 4; b++)
    ones[b * KALLOC + K] -= median;

  return median;
}


// Performs one parameter update from the per-position values in explx0 for
// genome positions [s, e).
//
//   explx0   values for positions [s, e), indexed relative to s.  The buffer
//            is modified in place and is read at negative indices and past
//            e - s, so it must be padded on both sides.
//   eps_eff  step-size multiplier for this update.
//
// Steps: (1) update x0 and its squared-gradient accumulator x0_ada,
// (2) replace explx0 by RESOL-wide forward sums, (3) accumulate a gradient
// into a temporary buffer and pass it through descending_k, (4) update the
// accumulator yall_ada, (5) take a soft-thresholded step on yall,
// (6) re-centre via l1_minimizer and (7) add sum(|yall|) to `penalty`.
void GradArgs::compute_gradient(ftype *explx0, const ftype eps_eff,
                                const int64_t s, const int64_t e) {
  // Temporary gradient storage; yk_ada_tmp is a table of per-level pointers
  // into it.  The table is released with free() at the end of the function
  // (presumably RENAME allocates it -- confirm against its definition).
  ArrayF<YALL_SIZE, 0> yall_ada_tmp;
  ftype **yk_ada_tmp = RENAME(yall_ada_tmp.data(), KBIG, KALLOC);

  int64_t SIZE = e - s;
  const gtype *genome_a = genome + s;

  // Update x0: dx0_ collects -explx0 over positions with a negative genome
  // code, dx0nt over all other positions.
  double dx0[2], dx0_ = 0, dx0nt = 0;
  #pragma omp parallel for reduction(+:dx0_, dx0nt)
  for (int64_t i = 0; i < SIZE; i++) {
    if (genome_a[i] < 0)
      dx0_ -= explx0[i];
    else
      dx0nt -= explx0[i];
  }
  dx0[0] = dx0nt; dx0[1] = dx0_;

  // Accumulate the squared gradient and scale the step by
  // 1 / sqrt(accumulator + 1).
  for (size_t i = 0; i < x0.size(); i++) {
    x0_ada[i] += dx0[i] * dx0[i];
    x0[i] += dx0[i] * eps_eff / sqrt(x0_ada[i] + 1);
  }

  // In-place forward sum: explx0[i] becomes explx0[i] + ... +
  // explx0[i + RESOL - 1].  The loop ascends, so every element read at
  // i + j (j > 0) still holds its original value.
  for (int64_t i = -K; i <= SIZE + K; i++) {
    ftype gtmp = 0;
    for (int64_t j = 0; j < RESOL; j++)
      gtmp += explx0[i + j];
    explx0[i] = gtmp;
  }

  // Update the other parameters.
  // For every position with a non-negative genome code k1, add the summed
  // values at i - K, i - K + RESOL, ... to row k1 of the level-KBIG gradient.
  // NOTE: positions sharing the same k1 write to the same row; unless
  // SAFE_THREADING is defined these += updates are not synchronized between
  // the OpenMP threads.
  ftype* ygrad_tmp = yk_ada_tmp[KBIG];
  #pragma omp parallel for
  for (int64_t i = -K; i <= SIZE + K; i++) {
    int k1 = genome_a[i];
    if (k1 < 0) continue;

    int ind = k1 * KALLOC;
    for (int64_t j = 0, yind = i - K; j < 2*K/RESOL; j++, yind += RESOL)
#if defined(SAFE_THREADING)
      #pragma omp atomic
#endif
      ygrad_tmp[ind + j] += explx0[yind];
  }
  // Presumably propagates the level-KBIG gradient to the lower levels --
  // confirm against descending_k.
  descending_k(yall_ada_tmp.data(), KBIG, 1, KALLOC);

  // One accumulator entry per row of KALLOC values: add the row's summed
  // squared gradient.
  #pragma omp parallel for
  for (size_t i = 0; i < yall_ada.size(); i++)
    for (int64_t j = 0, ind = i * KALLOC; j < KALLOC; j++, ind++)
      yall_ada[i] += yall_ada_tmp[ind] * yall_ada_tmp[ind];

  // Scale eta by the ratio of this batch's length to (end - start).
  const ftype eta_eff = eta * SIZE / ((ftype)end - start);
  // Gradient step followed by soft-thresholding: magnitudes are shrunk by
  // eta_eff * step_size and values that would cross zero are set to zero.
  #pragma omp parallel for
  for (size_t i = 0; i < yall.size(); i++) {
    const ftype step_size = eps_eff / sqrt(yall_ada[i / KALLOC] / KALLOC + 1);
    const ftype primal = yall[i] + yall_ada_tmp[i] * step_size;
    const ftype soft_tr = fabs(primal) - eta_eff * step_size;
    if (soft_tr <= 0)
      yall[i] = 0;
    else
      yall[i] = sign(primal) * soft_tr;
  }

  // l1_minimizer subtracts the returned median from the 1-mer entries of yk;
  // the same amount is subtracted from x0[0].
  double x0_median = l1_minimizer(yk);
  x0[0] -= x0_median;

  // Compute the penalty: add the sum of absolute values of yall to the
  // running `penalty` member.
  double tmp = 0;
  #pragma omp parallel for reduction(+:tmp)
  for (size_t i = 0; i < yall.size(); i++)
    tmp += fabs(yall[i]);
  penalty += tmp;

  free(yk_ada_tmp);
}


// Clears both history buffers (the ones update_history() accumulates into):
// x0_hist for x0 and yall_history for yall.
void GradArgs::clear_history() {
  x0_hist.clear();
  yall_history.clear();
}


// Adds the current parameters element-wise to the history buffers:
// yall into yall_history (in parallel) and x0 into x0_hist.
void GradArgs::update_history() {
  const size_t num_y = yall.size();
  #pragma omp parallel for
  for (size_t idx = 0; idx < num_y; idx++)
    yall_history[idx] += yall[idx];

  const size_t num_x0 = x0.size();
  for (size_t idx = 0; idx < num_x0; idx++)
    x0_hist[idx] += x0[idx];
}


// Trial step: evaluates the minibatch [s, s + MINIBATCH) after one gradient
// update with step size eps_eff, then restores yall and x0 to their values
// from before the call.
//
// Returns the value reported by feval_explx0 for the updated parameters.
//
// Side effects visible in this file: yall_history is used as scratch storage
// and is cleared on exit, and compute_gradient() still adds to `penalty`.
// dump_ada()/load_ada() bracket the update (presumably saving and restoring
// the accumulators x0_ada / yall_ada -- confirm against their definitions).
double GradArgs::minibatch_feval(int64_t s, ftype eps_eff) {
  int64_t e = s + MINIBATCH;
  ArrayF<ALLOC_SIZE(MINIBATCH), K2> explx0;
  ArrayF<2, 0> x0_old;

  // yall_history is unused at this point, so it serves as the backup of yall.
  yall_history.copy(yall);
  x0_old.copy(x0);

  dump_ada();

  // Same sequence as a regular minibatch step, but with eps_eff.
  sum_explx0(explx0.data(), s, e);
  slave_explx0(explx0.data(), s, e);
  compute_gradient(explx0.data(), eps_eff, s, e);

  // Reset the work buffer before recomputing the sums with the new parameters.
  explx0.clear();
  load_ada();

  sum_explx0(explx0.data(), s, e);
  double feval = feval_explx0(explx0.data(), s, e, true);

  // Restore the parameters saved above.
  yall.swap(yall_history);
  x0.copy(x0_old);

  yall_history.clear();

  return feval;
}


void GradArgs::minibatch(int64_t s, int iter) {
  // double t1 = my_time();

  (void) iter;
  int64_t e = s + MINIBATCH;
  ArrayF<ALLOC_SIZE(MINIBATCH), K2> explx0;

  sum_explx0(explx0.data(), s, e);
  slave_explx0(explx0.data(), s, e);
  compute_gradient(explx0.data(), eps, s, e);
  update_history();

  epoch++;

  // fprintf(stderr, "minibatch time: %f, x0: %f\n", (my_time() - t1)/1000.0, x0);
}


// Chooses the minibatch offset used by find_eps(): the start of the
// MINIBATCH-wide window with the largest total of `counts`.  Does nothing if
// eps_offset has already been set (i.e. is not -1).  The first window with the
// maximal sum wins.
void GradArgs::set_eps_offset() {
  if (eps_offset != -1) return;

  // Sum of the first window, counts[0, MINIBATCH).
  double window_sum = 0;
  for (int64_t i = 0; i < MINIBATCH; i++) window_sum += counts[i];

  int64_t best_start = 0;
  double best_sum = window_sum;

  // Slide the window one position at a time.  On entry window_sum covers
  // [i - 1, i - 1 + MINIBATCH); dropping counts[i - 1] and adding
  // counts[i + MINIBATCH - 1] makes it cover [i, i + MINIBATCH).
  for (int64_t i = 1; i < end - MINIBATCH - 1; i++) {
    window_sum -= counts[i - 1];
    window_sum += counts[i + MINIBATCH - 1];
    if (window_sum > best_sum) {
      best_sum = window_sum;
      best_start = i;
    }
  }

  eps_offset = best_start;
  fprintf(stderr, "eps offset: %ld\n", eps_offset);
}


// Searches for a step size on the minibatch at eps_offset.
//
// Starting from 5.0 / 1.5, the candidate step size is divided by 1.5 each
// round and evaluated with minibatch_feval().  The search continues while the
// objective does not increase from one candidate to the next and is still
// above the baseline fv0 (objective with step size 0).  Non-finite objectives
// are treated as DBL_MAX.
//
// Returns the candidate with the lowest objective (ties go to the smaller
// step size).
double GradArgs::find_eps() {
  set_eps_offset();

  vector<pair<double, double> > vals;
  double cur_eps = 5.0;
  double fv_past = DBL_MAX;
  double fv = DBL_MAX;

  double fv0  = minibatch_feval(eps_offset, 0);
  fprintf(stderr, "%f\n", fv0);

  // do/while so that at least one candidate is always evaluated: with a NaN,
  // +inf or DBL_MAX baseline the loop condition is false from the start, and
  // an empty `vals` would make the min_element dereference below undefined.
  // For any other baseline the condition is true on first entry, so the
  // sequence of evaluations is the same as with a plain while loop.
  do {
    fv_past = fv;
    cur_eps /= 1.5;
    fv = minibatch_feval(eps_offset, cur_eps);
    fprintf(stderr, "eps, fv: %f %f\n", cur_eps, fv);
    if (!std::isfinite(fv))
      fv = DBL_MAX;
    vals.push_back(std::make_pair(fv, cur_eps));
  } while (fv_past >= fv && fv > fv0);  // negated comparison to avoid nan issues

  pair<double, double> best = *std::min_element(vals.begin(), vals.end());
  fprintf(stderr, "eps, fv: %f %f\n", best.second, best.first);
  return best.second;
}


// Returns true when every element of yall and of x0 is finite (neither NaN
// nor infinite).  Scanning stops at the first non-finite value.
bool GradArgs::params_valid() {
  bool all_finite = true;

  for (size_t idx = 0; all_finite && idx < yall.size(); idx++)
    all_finite = std::isfinite(yall[idx]);

  for (size_t idx = 0; all_finite && idx < x0.size(); idx++)
    all_finite = std::isfinite(x0[idx]);

  return all_finite;
}


// Runs the training loop for one outer iteration `iter`.
//
//   heldout       object whose heldout_feval() scores each pass.
//   iter          outer iteration index; used to seed the shuffle and passed
//                 to dump_params()/load_params().
//   prev_heldout  held-out score to compare the first pass against.
//
// Each pass shuffles the minibatch offsets, runs NUM_MB minibatch steps and
// scores the result.  A pass whose score is non-finite, whose parameters are
// not finite, or whose score exceeds 10x the last accepted score triggers a
// restart: parameters are reloaded and eps is halved.  Restarts do not count
// as passes; a failing pass after more than 4 restarts ends the loop.
// The loop otherwise runs for at least 15 passes, continues until the penalty
// has decreased on two passes, and stops at 50 passes at the latest.
//
// Returns the most recently accepted held-out score (prev_heldout if no pass
// was accepted).
double GradArgs::worker(GradArgs *heldout, const int iter, const double prev_heldout) {
  std::mt19937 gen(epoch + iter);

  update_eps(find_eps());

  // Number of whole minibatches available after removing 6 * K positions from
  // [start, end).
  const int64_t NUM_EPOCHS = (end - start - 6 * K) / MINIBATCH;
  // Number of minibatch steps actually taken.
  const int NUM_MB = std::min((int) NUM_EPOCHS, (int) 500);
  // fpast[0] is always the most recently accepted held-out score.
  vector<double> fpast{prev_heldout, prev_heldout};

  penalty = 0;

  // Candidate minibatch start positions: consecutive MINIBATCH-sized blocks
  // beginning 3 * K in, each rounded down to a multiple of 16.
  vector<int64_t> offsets;
  for (int i = 0; i < NUM_EPOCHS; i++) {
    int64_t offset = i * MINIBATCH + 3 * K;
    offset -= offset % 16;  // For alignment purposes.
    offsets.push_back(offset);
  }

  dump_params(fpast[0], iter);
  // Cleared after the first restart; from then on eps is never increased.
  bool run_find_eps = true;
  int num_restarts = 0, pen_inversions = 0;
  for (int i = 0; (i < 15 || pen_inversions < 2) && i < 50; i++) {
    double t1 = my_time();
    clear_history();

    double prev_pen = penalty;
    penalty = 0;

    // Only the first NUM_MB of the shuffled offsets are used in this pass.
    std::shuffle(offsets.begin(), offsets.end(), gen);
    for (int j = 0; j < NUM_MB; j++)
      minibatch(offsets[j], epoch);

    const double heldout_score = heldout->heldout_feval(true, NUM_MB);
    if (heldout_score / 10 > fpast[0] || !std::isfinite(heldout_score) || !params_valid()) {
      fprintf(stderr, "Triggering restart of inner iteration.\n");
      if (num_restarts > 4) {
        fprintf(stderr, "More than 4 restarts, killing current eta run.\n");
        break;
      }
      // Roll back to the last dumped parameters and retry this pass with half
      // the step size; i-- cancels the loop increment.
      load_params(iter);
      update_eps(eps * 0.5);
      i--;
      num_restarts++;
      run_find_eps = false;
      continue;
    }

    fpast.insert(fpast.begin(), heldout_score);

    // Replace the parameters by the accumulated history buffers.
    // NOTE(review): update_history() adds yall/x0 once per minibatch, so these
    // buffers hold sums; whether they are normalized before this swap is not
    // visible in this file -- confirm.
    yall.swap(yall_history);
    x0.swap(x0_hist);

    dump_params(fpast[0], iter);

    fprintf(stderr, "iter: %d held: %f eta: %f x0nt: %f, x0: %f, penalty: %f time: %f\n",
            i, fpast[0], eta, x0[0], x0[1], penalty / NUM_MB, (my_time()-t1)/1000.0);

    // Try a 1.5x larger step size on the reference minibatch and keep it if
    // it beats the zero-step baseline.  Both evaluations run even when
    // run_find_eps is false; they go through compute_gradient(), which adds
    // to `penalty`.
    double fv0  = minibatch_feval(eps_offset, 0);
    double fv_try  = minibatch_feval(eps_offset, eps * 1.5);
    if (fv_try < fv0 && run_find_eps) {
      update_eps(eps * 1.5);
    }

    // Count passes on which the penalty went down.
    pen_inversions += prev_pen > penalty;
  }

  load_params(iter);

  fprintf(stderr, "finished iteration: %d \n", iter);

  return fpast[0];
}
