/*
 * Copyright (c) 2014 Daniel Kang
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

#include <algorithm>
#include <string>
#include <vector>

#include "./GradArgs.h"
#include "../util/util.h"

using std::vector;


// ---------------------------------------------------------------------------
// Static state shared by every GradArgs instance.
// ---------------------------------------------------------------------------

// Flat parameter storage. yall is filled by load_params; yall_history has the
// same size and is the buffer dump_params writes out as "yall". yall_ada is
// YALL_SIZE / KALLOC long and is saved/restored by dump_ada / load_ada —
// presumably per-slot adaptive-step accumulators (TODO confirm).
ArrayF<YALL_SIZE, 0> GradArgs::yall;
ArrayF<YALL_SIZE / KALLOC, 0> GradArgs::yall_ada;
ArrayF<YALL_SIZE, 0> GradArgs::yall_history;
// Per-k pointer tables over the flat buffers above, built by RENAME (defined
// elsewhere). NOTE(review): presumably yk[k] addresses the parameters for
// k-mers of length k (k up to KBIG) with KALLOC slots per k-mer, and 1 slot
// per k-mer for yk_ada — confirm against RENAME.
// These dynamic initializers call .data() on the buffers above; statics in one
// translation unit are initialized in definition order, so the buffer
// definitions must stay before these lines.
ftype** GradArgs::yk = RENAME(yall.data(), KBIG, KALLOC);
ftype** GradArgs::yk_ada = RENAME(yall_ada.data(), KBIG, 1);
ftype** GradArgs::yk_history = RENAME(yall_history.data(), KBIG, KALLOC);

// YGRAD_SIZE-sized buffer; presumably gradient scratch space (not referenced
// in the code shown here).
ArrayF<YGRAD_SIZE, 0> GradArgs::ygrad;

// Baseline terms: init_x0_yall seeds x0[0] (from positions with genome >= 0)
// and x0[1] (from positions with genome < 0). x0_ada is saved by dump_ada;
// x0_hist is what dump_params writes out as "x0".
ArrayF<2, 0> GradArgs::x0;
ArrayF<2, 0> GradArgs::x0_ada;
ArrayF<2, 0> GradArgs::x0_hist;

// Epoch counter; the constructor resets it to 1 as well.
ftype GradArgs::epoch = 1;
// Set through update_eta / update_eps; dump_params also writes eta out.
double GradArgs::eta = 0;
double GradArgs::eps = 0;

// Not referenced in the code shown here.
double GradArgs::penalty = 0;
double GradArgs::fv_epoch = 0;


// Builds the optimizer state on top of Loader.
//  - Records the output directory from the "out_dir" option.
//  - Resets the shared epoch counter to 1.
//  - Seeds x0 via init_x0_yall().
// The time spent in this body (Loader's own constructor is not included) is
// reported on stderr.
GradArgs::GradArgs(const int64_t N, const variables_map& vm) : Loader(N, vm) {
  const double setup_begin = my_time();

  out_dir = vm["out_dir"].as<std::string>();
  epoch = 1;
  init_x0_yall();

  // NOTE(review): my_time() appears to report milliseconds, hence the
  // division by 1000 for seconds — confirm.
  const double elapsed = my_time() - setup_begin;
  fprintf(stderr, "Total setup time: %f\n", elapsed / 1000.0);
}


// Seeds both entries of x0 from the read counts over positions [start, end).
//
// Positions are split by the sign of genome[i]:
//   x0[0] = -log(reads / positions)       over positions with genome[i] >= 0
//   x0[1] = -log(reads / positions) - 1   over positions with genome[i] <  0
// Every accumulator starts at a pseudocount of 1. For non-negative counts this
// keeps both ratios away from 0/0 and log(0), even when one side is empty.
void GradArgs::init_x0_yall() {
  double pos_sites = 1, pos_reads = 1;  // genome[i] >= 0
  double neg_sites = 1, neg_reads = 1;  // genome[i] <  0

  // Single pass; each reduction variable is touched only in its own branch.
  #pragma omp parallel for reduction(+:pos_sites, pos_reads, neg_sites, neg_reads)
  for (int64_t i = start; i < end; i++) {
    if (genome[i] >= 0) {
      pos_sites += 1;
      pos_reads += counts[i];
    } else {
      neg_sites += 1;
      neg_reads += counts[i];
    }
  }

  x0[1] = -log(neg_reads / neg_sites) - 1;
  x0[0] = -log(pos_reads / pos_sites);

  fprintf(stderr, "starting x0: %f, x0nt: %f\n", x0[1], x0[0]);
}


double GradArgs::find_eta(const int ksize) {
  assert(ksize <= KBIG);
  DEFINE_CONSTS(ksize);

  vector<ftype> kmer_ct(XSIZE_ALL / KALLOC);
  ftype **kmer_ctk = RENAME(&kmer_ct[0], ksize, 1);
  vector<ftype> ckmer(XSIZE_ALL / KALLOC);
  ftype **ckmer_k = RENAME(&ckmer[0], ksize, 1);


  const int mask = (1 << (2*ksize)) - 1;
  // Generate kmer counts
  for (int64_t i = start; i < end; i++) {
    int k1 = genome[i];
    if (k1 < 0) continue;
    kmer_ctk[ksize][k1 & mask]++;
  }

  // Find the correct eta
  const double ex0 = exp(-x0[0]);
  for (int64_t i = start; i < end; i++) {
    if (genome[i] < 0) continue;
    for (int j = 0; j < RESOL; j++)
      ckmer_k[ksize][genome[i] & mask] += counts[i + j];
  }

  double max_eta = 0;
  for (int kmer = 0; kmer < (1 << (2*ksize)); kmer++) {
    double kmer_count = kmer_ctk[ksize][kmer];
    double kmer_reads = ckmer_k[ksize][kmer];
    double eta_tmp = fabs(kmer_reads - kmer_count * ex0);
    max_eta = std::max(eta_tmp, max_eta);
  }

  return max_eta;
}


// Stores a new value in the shared (static) eta.
void GradArgs::update_eta(double new_eta) {
  GradArgs::eta = new_eta;
}


// Logs the new value to stderr, then stores it in the shared (static) eps.
void GradArgs::update_eps(double new_eps) {
  fprintf(stderr, "eps set to: %f\n", new_eps);
  GradArgs::eps = new_eps;
}


// THIS DUMPS HISTORY
void GradArgs::dump_params(const ftype fval, const int iter) {
  dump_out(out_dir.c_str(), "yall", yall_history.data(), yall_history.size(), iter);
  dump_out(out_dir.c_str(), "x0", x0_hist.data(), x0_hist.size(), iter);
  const ftype eta_tmp = eta;
  dump_out(out_dir.c_str(), "eta", &eta_tmp, 1, iter);
  dump_out(out_dir.c_str(), "heldout", &fval, 1, iter);
}


// Reads the "yall" and "x0" files for iteration `iter` from out_dir into the
// yall and x0 buffers. These are the file names dump_params writes.
void GradArgs::load_params(const int iter) {
  const char* const dir = out_dir.c_str();
  load_out(dir, "yall", yall.data(), yall.size(), iter);
  load_out(dir, "x0", x0.data(), x0.size(), iter);
}


// Saves yall_ada and x0_ada into out_dir, always under index 0 rather than an
// iteration number. The on-disk name for yall_ada is "yada_ada". load_ada uses
// the same names, so the two must change together.
void GradArgs::dump_ada() {
  const char* const dir = out_dir.c_str();
  dump_out(dir, "yada_ada", yall_ada.data(), yall_ada.size(), 0);
  dump_out(dir, "x0_ada", x0_ada.data(), x0_ada.size(), 0);
}


// Restores yall_ada and x0_ada from out_dir (index 0). The file names
// ("yada_ada", "x0_ada") must match the ones dump_ada writes.
void GradArgs::load_ada() {
  const char* const dir = out_dir.c_str();
  load_out(dir, "yada_ada", yall_ada.data(), yall_ada.size(), 0);
  load_out(dir, "x0_ada", x0_ada.data(), x0_ada.size(), 0);
}
