/*
 * Copyright (c) 2014 Daniel Kang
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

#include <math.h>
#include <stdlib.h>

#include <new>
#include <string>
#include <thread>

#include "./GradArgs.h"
#include "../util/util.h"

using std::thread;

// Constructs a GradArgs from the parsed program options (forwarded to the
// Loader base class) and copies the `yk` member from an existing instance.
// `other` is dereferenced unconditionally, so it must not be null.
GradArgs::GradArgs(const GradArgs* other, const variables_map& vm)
    : Loader(vm) {
  const GradArgs& source = *other;
  this->yk = source.yk;
}


/**
 * Computes the negated objective over the half-open position range [s, e).
 *
 * For every i in [0, e - s) the term  c * t - exp(t)  is accumulated, where
 * c = counts[s + i] and t = explx0[i] - x0[1] when genome[s + i] < 0, or
 * t = explx0[i] - x0[0] otherwise. When `pen` is true, an L1 penalty
 * eta_eff * sum_j |yall[j]| is subtracted, with eta scaled by the fraction of
 * the [start, end) region that [s, e) covers.
 *
 * @param explx0 values indexed from 0; element i is paired with position s + i.
 * @param s      first position (inclusive) used to offset counts and genome.
 * @param e      last position (exclusive).
 * @param pen    whether to include the L1 penalty on yall.
 * @return the negated sum, so that callers can treat the problem as a
 *         minimization.
 */
double GradArgs::feval_explx0(ftype *explx0, const int64_t s, const int64_t e, const bool pen) {
  const int64_t SIZE = e - s;
  const ftype *c = counts + s;
  const gtype *genome_a = genome + s;

  double feval = 0;
  #pragma omp parallel for reduction(+:feval)
  for (int64_t i = 0; i < SIZE; i++) {
    // Positions with a negative genome code use the second offset in x0.
    const double offset = (genome_a[i] < 0) ? x0[1] : x0[0];
    const double tmp = explx0[i] - offset;
    feval += c[i] * tmp - exp(tmp);
  }

  if (pen) {
    // Subtract the coordinates as integers *before* converting to ftype.
    // Casting `end` alone and subtracting afterwards loses precision (and can
    // cancel to zero) when ftype is narrower than the coordinate values.
    const double eta_eff = eta * SIZE / static_cast<ftype>(end - start);

    // Accumulate the L1 norm with a '+' reduction (the '-' reduction operator
    // is deprecated in recent OpenMP versions) and apply the penalty once.
    double l1_norm = 0;
    #pragma omp parallel for reduction(+:l1_norm)
    for (size_t i = 0; i < yall.size(); i++)
      l1_norm += fabs(yall[i]);
    feval -= eta_eff * l1_norm;
  }

  // We return -feval since we want to cast the problem as minimization.
  return -feval;
}


// HIST_SWAP(x, y, NORM): x receives y / NORM while y receives the old value of x.
// SWAP(x, y): exchanges x and y.
// Both expand to a single do { ... } while (0) statement WITHOUT a trailing
// semicolon, so the caller's own ';' terminates it and the macros are safe to
// use in unbraced if/else bodies.
#define HIST_SWAP(x, y, NORM) do { ftype tmp = (x); (x) = (y) / (NORM); (y) = tmp; } while (0)
#define SWAP(x, y) do { ftype tmp = (x); (x) = (y); (y) = tmp; } while (0)
double GradArgs::heldout_feval(const bool swap_hist, const int64_t EPOCH_SIZE) {
  // double t1 = my_time();

  if (swap_hist) {
    #pragma omp parallel for
    for (size_t j = 0; j < yall.size(); j++)
      HIST_SWAP(yall[j], yall_history[j], EPOCH_SIZE);
    for (size_t j = 0; j < x0.size(); j++)
      HIST_SWAP(x0[j], x0_hist[j], EPOCH_SIZE);
  }

  ftype *explx0 = (ftype*) calloc(ALLOC_SIZE(end - start), sizeof(ftype));
  // For flanks
  explx0 += 2*K;

  sum_explx0(explx0 + K, start + K, end - 2*K);
  double feval = feval_explx0(explx0, start, end - 2*K, false);

  if (swap_hist) {
    #pragma omp parallel for
    for (size_t j = 0; j < yall.size(); j++)
      SWAP(yall_history[j], yall[j]);
    for (size_t j = 0; j < x0.size(); j++)
      SWAP(x0_hist[j], x0[j]);
  }

  free(explx0 - 2*K);

  // fprintf(stderr, "heldout score: %f, eta: %f, x0: %f, x0nt: %f, time: %f\n",
  //         feval, eta, x0, x0nt, (my_time() - t1)/1000.0);
  return feval;
}


// Convenience overload: evaluates the held-out objective with the current
// parameters, i.e. without swapping in the history (the epoch size is unused
// on that path, so 0 is passed).
double GradArgs::heldout_feval() {
  const bool use_history = false;
  const int64_t epoch_size = 0;
  return heldout_feval(use_history, epoch_size);
}
