/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization.stochastic;

import java.util.Arrays;
import jsat.linear.IndexValue;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.math.optimization.stochastic.GradientUpdater;

/**
 * Normalized AdaGrad ("NAdaGrad") gradient updater.
 * <p>
 * Uses AdaGrad-style per-coordinate step sizes ({@code 1/sqrt(G[i])}) and
 * additionally normalizes by per-feature scale: {@code S[i]} records the
 * largest absolute value seen for feature {@code i}, and whenever a larger
 * value arrives the current weight is rescaled by {@code oldScale/newScale}.
 * A global factor {@code sqrt(t / N)} further adjusts the step size, where
 * {@code N} accumulates the squared normalized feature values.
 * <p>
 * The normalization needs the raw feature values, so it is only performed when
 * the gradient is a {@link ScaledVector}; its base vector is treated as the
 * input example (presumably {@code grad = lossDerivative * x} — confirm with
 * callers). Any other gradient type gets a simplified step that does not update
 * the feature scales or the step counter.
 * <p>
 * {@link #setup(int)} must be called before the first update, otherwise the
 * per-coordinate arrays are {@code null}.
 */
public class NAdaGrad implements GradientUpdater {
    private static final long serialVersionUID = 5138675613579751777L;
    /** Per-coordinate sum of squared gradients. */
    private double[] G;
    /** Per-coordinate largest absolute feature value seen by the normalized path. */
    private double[] S;
    /** Running sum of squared normalized feature values ({@code x_i^2 / S_i^2}). */
    private double N;
    /** Accumulated squared bias gradients; {@link #setup(int)} starts it at 1. */
    private double biasG;
    /** Number of normalized ({@link ScaledVector}) updates performed. */
    private long t;

    /** Creates an updater with no state; call {@link #setup(int)} before use. */
    public NAdaGrad() {
    }

    /**
     * Copy constructor: deep-copies the per-coordinate arrays (when present)
     * and copies all scalar accumulators.
     *
     * @param toCopy the updater whose state is copied
     */
    public NAdaGrad(NAdaGrad toCopy) {
        if (toCopy.G != null) {
            this.G = Arrays.copyOf(toCopy.G, toCopy.G.length);
        }
        if (toCopy.S != null) {
            this.S = Arrays.copyOf(toCopy.S, toCopy.S.length);
        }
        this.biasG = toCopy.biasG;
        this.N = toCopy.N;
        this.t = toCopy.t;
    }

    /**
     * Updates {@code w} without a bias term. This delegates to the five-argument
     * form with a zero bias gradient, which leaves the bias accumulator unchanged.
     *
     * @param w    the weight vector, updated in place
     * @param grad the gradient
     * @param eta  the base learning rate
     */
    @Override
    public void update(Vec w, Vec grad, double eta) {
        this.update(w, grad, eta, 0.0, 0.0);
    }

    /**
     * Applies one NAdaGrad step to {@code w} and computes the bias step.
     *
     * @param w        the weight vector, updated in place
     * @param grad     the gradient; the full normalized update is only used
     *                 when this is a {@link ScaledVector}
     * @param eta      the base learning rate
     * @param bias     unused by this updater
     * @param biasGrad the gradient of the bias term
     * @return the bias step {@code eta * biasGrad / sqrt(biasG)}, computed with
     *         the squared bias gradients accumulated before this call. It is not
     *         negated, presumably so the caller subtracts it — confirm.
     */
    @Override
    public double update(Vec w, Vec grad, double eta, double bias, double biasGrad) {
        if (grad instanceof ScaledVector) {
            ++this.t;
            Vec x = ((ScaledVector) grad).getBase();
            for (IndexValue iv : x) {
                int indx = iv.getIndex();
                double absXi = Math.abs(iv.getValue());
                // A zero entry can neither raise S[indx] nor add to N; skipping
                // it avoids 0/0 = NaN in N while S[indx] is still 0.
                if (absXi == 0.0) {
                    continue;
                }
                if (absXi > this.S[indx]) {
                    // Larger scale seen for this feature: rescale the weight by
                    // oldScale/newScale (this zeroes it while S[indx] is 0).
                    w.set(indx, w.get(indx) * this.S[indx] / absXi);
                    this.S[indx] = absXi;
                }
                this.N += absXi * absXi / (this.S[indx] * this.S[indx]);
            }
            // The small epsilon keeps the ratio finite while N is still 0.
            double etaRoled = -eta * Math.sqrt((double) this.t / (this.N + 1.0E-6));
            for (IndexValue iv : grad) {
                int indx = iv.getIndex();
                double gradI = iv.getValue();
                // A zero gradient moves nothing; skipping it avoids 0/0 = NaN
                // when G[indx] has not accumulated anything yet.
                if (gradI == 0.0) {
                    continue;
                }
                this.G[indx] += gradI * gradI;
                w.increment(indx, etaRoled * gradI / (this.S[indx] * Math.sqrt(this.G[indx])));
            }
            return biasStep(eta, biasGrad);
        }
        // Fallback for gradients without access to raw feature values: t and S
        // are not updated. S is floored at 1 (so features never seen by the
        // normalized path are divided by 1), and N is floored at t+1 so the
        // global factor sqrt((t+1)/max(N, t+1)) never exceeds 1.
        double tNext = (double) (this.t + 1L);
        double etaRoled = -eta * Math.sqrt(tNext / Math.max(this.N, tNext));
        for (IndexValue iv : grad) {
            int indx = iv.getIndex();
            double gradI = iv.getValue();
            // Same 0/0 guard as in the normalized path.
            if (gradI == 0.0) {
                continue;
            }
            this.G[indx] += gradI * gradI;
            w.increment(indx, etaRoled * gradI / (Math.max(this.S[indx], 1.0) * Math.sqrt(this.G[indx])));
        }
        return biasStep(eta, biasGrad);
    }

    /**
     * Computes the AdaGrad-style bias step from the squared bias gradients seen
     * before this call, then folds the current bias gradient into the accumulator.
     *
     * @param eta      the base learning rate
     * @param biasGrad the gradient of the bias term
     * @return {@code eta * biasGrad / sqrt(biasG)} using the pre-update {@code biasG}
     */
    private double biasStep(double eta, double biasGrad) {
        double biasUpdate = eta * biasGrad / Math.sqrt(this.biasG);
        this.biasG += biasGrad * biasGrad;
        return biasUpdate;
    }

    /**
     * Returns an independent copy of this updater's state.
     *
     * @return a new updater built with the copy constructor
     */
    @Override
    public NAdaGrad clone() {
        return new NAdaGrad(this);
    }

    /**
     * Allocates state for {@code d} coordinates and resets every accumulator,
     * including {@code N}, so that a reused updater starts fresh.
     *
     * @param d the number of weight coordinates
     */
    @Override
    public void setup(int d) {
        this.G = new double[d];
        this.S = new double[d];
        // Must be reset together with t; otherwise sqrt(t / N) starts tiny on reuse.
        this.N = 0.0;
        // Starting at 1 keeps the first bias step finite (no division by zero).
        this.biasG = 1.0;
        this.t = 0L;
    }
}

