/*
 * Copyright (c) 2014 Daniel Kang
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

#ifndef CLASSES_GRADARGS_H_
#define CLASSES_GRADARGS_H_

#include <cstdint>
#include <string>

#include <boost/program_options.hpp>

#include "../consts.h"
#include "./ArrayF.h"
#include "./Loader.h"

using boost::program_options::variables_map;


/**
 * Optimizer state and driver layered on top of Loader.
 *
 * Most of the optimizer state (the yall / yk / x0 parameter storage, their
 * "_ada" and history counterparts, ygrad, eta, eps, penalty, ...) is held in
 * static data members. It is therefore shared by every GradArgs object in
 * the process, including the separate held-out instance passed to worker().
 * These statics are only declared here and must be defined out of line.
 *
 * Instances are neither copyable nor movable. The copy operations are
 * explicitly deleted, and because they are user-declared the compiler does
 * not generate implicit move operations either.
 */
class GradArgs : public Loader {
  // NOTE(review): presumably the output directory used by the dump_*()
  // methods; confirm in the implementation.
  std::string out_dir = "/mnt/";
  // NOTE(review): -1 looks like an "unset" sentinel, presumably filled in
  // by set_eps_offset(); confirm.
  int64_t eps_offset = -1;

 protected:
  // Process-wide optimizer state (static; see class comment).
  static ArrayF<YALL_SIZE, 0> yall;
  // NOTE(review): sized YALL_SIZE / KALLOC, i.e. one slot per KALLOC
  // entries of yall; presumably adaptive-step accumulators. Confirm.
  static ArrayF<YALL_SIZE / KALLOC, 0> yall_ada;
  static ArrayF<YALL_SIZE, 0> yall_history;  // stores accumulated gradients.
  // NOTE(review): pointer-to-pointer views; presumably row pointers into
  // yall / yall_ada / yall_history. Confirm where they are initialized.
  static ftype **yk;
  static ftype **yk_ada;
  static ftype **yk_history;

  static ArrayF<YGRAD_SIZE, 0> ygrad;

  // x0 parameters with their "_ada" and history counterparts.
  static ArrayF<2, 0> x0;
  static ArrayF<2, 0> x0_ada;
  static ArrayF<2, 0> x0_hist;

  static ftype epoch;
  static double eta;  // see update_eta() / find_eta().
  static double eps;  // see update_eps() / find_eps().

  static double penalty;
  static double fv_epoch;

 public:
  // Worker constructor.
  GradArgs(const int64_t N, const variables_map& vm);
  // Held-out constructor. `other` is another GradArgs instance.
  // NOTE(review): presumably the worker instance; confirm what is taken
  // from it.
  GradArgs(const GradArgs* other, const variables_map& vm);
  // Forwards (s, e) to Loader and does nothing else.
  GradArgs(const int64_t s, const int64_t e) : Loader(s, e) {}
  // Only loads the genome: forwards (N, offset, genome_fname) to Loader.
  GradArgs(const int64_t N, const int64_t offset, const std::string& genome_fname) :
      Loader(N, offset, genome_fname) {}

  virtual ~GradArgs() = default;

  GradArgs(const GradArgs&) = delete;
  GradArgs& operator=(const GradArgs&) = delete;

  // Setup of the shared state / data.
  void init_x0_yall();
  void scatter_data();

  // Step size (eta), epsilon and gradient-history management.
  double find_eta(const int ksize);

  void update_eta(double e);
  void update_eps(double e);
  void update_history();
  void clear_history();

  // Parameter / accumulator persistence.
  // NOTE(review): presumably reads/writes files under out_dir; confirm.
  void dump_params(const ftype fval, const int iter);
  void load_params(const int iter);

  void dump_ada();
  void load_ada();

  // Objective / gradient evaluation over the index range given by (s, e).
  // NOTE(review): confirm whether `e` is inclusive or exclusive.
  double feval_explx0(ftype *buf, const int64_t s, const int64_t e, const bool pen);

  void sum_explx0_grad_worker(ftype *explx0, const int64_t s, const int64_t e);
  void sum_explx0_grad(ftype *explx0, const int64_t s, const int64_t e);
  void sum_explx0(ftype *buf, const int64_t s, const int64_t e);

  void set_eps_offset();
  double find_eps();
  double minibatch_feval(int64_t s, ftype eps_eff);

  bool params_valid();

  // Training steps.
  void slave_explx0(ftype *buf, const int64_t s, const int64_t e);
  void compute_gradient(ftype *buf, const ftype eps_eff, const int64_t s, const int64_t e);
  void minibatch(int64_t s, int iter);
  // `heldout` is a separate instance; it shares this object's static state.
  // NOTE(review): presumably returns an updated held-out objective that the
  // caller feeds back as `prev_heldout`. Confirm.
  double worker(GradArgs *heldout, const int iter, const double prev_heldout);

  // Held-out evaluation.
  double heldout_feval(const bool swap_hist, const int64_t EPOCH_SIZE);
  double heldout_feval();
};

#endif  // CLASSES_GRADARGS_H_
