/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization.stochastic;

import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.math.optimization.stochastic.GradientUpdater;

/**
 * RMSProp gradient updater.
 *
 * <p>Keeps an exponentially decaying average of squared gradients per coordinate,
 * {@code E[g^2] <- rho * E[g^2] + (1 - rho) * g^2}, and moves each coordinate by
 * {@code -eta * g / max(sqrt(E[g^2]), |g|)}. Taking the max with {@code |g|} bounds the
 * magnitude of every step by {@code eta}. It also keeps the denominator non-zero whenever
 * {@code g} is non-zero. A separate scalar average is kept for the bias term.
 *
 * <p>{@link #setup(int)} must be called before any {@code update} call.
 */
public class RMSProp implements GradientUpdater {
    private static final long serialVersionUID = 3512851084092042727L;

    /** Decay rate of the running squared-gradient averages; always in (0, 1). */
    private double rho;
    /**
     * Per-coordinate running average of squared gradients; {@code null} until
     * {@link #setup(int)} is called. The field name is kept as-is, presumably so that
     * previously serialized instances remain readable.
     */
    private Vec daigG;
    /** Running average of the squared bias gradient. */
    private double biasG;

    /** Creates an updater with the default decay rate {@code rho = 0.9}. */
    public RMSProp() {
        this(0.9);
    }

    /**
     * Creates an updater with the given decay rate.
     *
     * @param rho decay rate of the squared-gradient average, in (0, 1)
     * @throws IllegalArgumentException if {@code rho} is not in (0, 1) or is NaN
     */
    public RMSProp(double rho) {
        // Validate via a private static helper rather than the overridable setRho.
        this.rho = checkRho(rho);
    }

    /**
     * Sets the decay rate of the squared-gradient average.
     *
     * @param rho decay rate in (0, 1)
     * @throws IllegalArgumentException if {@code rho} is not in (0, 1) or is NaN
     */
    public void setRho(double rho) {
        this.rho = checkRho(rho);
    }

    /**
     * Returns the decay rate of the squared-gradient average.
     *
     * @return the decay rate, in (0, 1)
     */
    public double getRho() {
        return this.rho;
    }

    /**
     * Copy constructor. The per-coordinate history is cloned, so the copy and the
     * original evolve independently.
     *
     * @param toCopy the updater to copy
     */
    public RMSProp(RMSProp toCopy) {
        if (toCopy.daigG != null) {
            this.daigG = toCopy.daigG.clone();
        }
        this.rho = toCopy.rho;
        this.biasG = toCopy.biasG;
    }

    /**
     * Updates {@code x} in place using {@code grad} with no bias term. This delegates
     * with {@code bias = 0} and {@code biasGrad = 0}, so the bias history is still decayed.
     */
    @Override
    public void update(Vec x, Vec grad, double eta) {
        this.update(x, grad, eta, 0.0, 0.0);
    }

    /**
     * Applies one RMSProp step to {@code x} in place and computes the bias step.
     *
     * @param x the weights to update in place
     * @param grad the gradient; only the entries produced by its iterator are applied
     *             (presumably the non-zero entries — confirm against {@code Vec})
     * @param eta the learning rate
     * @param bias the current bias (not read by this implementation)
     * @param biasGrad the gradient of the bias
     * @return the amount the caller should subtract from the bias; 0 when both
     *         {@code biasGrad} and the accumulated bias history are zero
     * @throws IllegalStateException if {@link #setup(int)} has not been called
     */
    @Override
    public double update(Vec x, Vec grad, double eta, double bias, double biasGrad) {
        if (this.daigG == null) {
            throw new IllegalStateException("setup(int) must be called before update");
        }
        // Decay the whole history once. The vector is a ScaledVector (see setup),
        // presumably so this is a cheap scale change rather than a dense pass.
        this.daigG.mutableMultiply(this.rho);
        for (IndexValue iv : grad) {
            int indx = iv.getIndex();
            double grad_i = iv.getValue();
            this.daigG.increment(indx, (1.0 - this.rho) * grad_i * grad_i);
            double step = scaledStep(eta, grad_i, this.daigG.get(indx));
            if (step != 0.0) {
                x.increment(indx, -step);
            }
        }
        this.biasG *= this.rho;
        this.biasG += (1.0 - this.rho) * biasGrad * biasGrad;
        return scaledStep(eta, biasGrad, this.biasG);
    }

    /**
     * Computes {@code eta * g / max(sqrt(avgSq), |g|)}. Returns 0 instead of
     * {@code 0/0 = NaN} when both {@code g} and {@code avgSq} are zero.
     */
    private static double scaledStep(double eta, double g, double avgSq) {
        double root = Math.max(Math.sqrt(avgSq), Math.abs(g));
        if (root == 0.0) {
            // No gradient and no history: nothing to move. NaN inputs skip this
            // branch, so genuine NaNs still propagate.
            return 0.0;
        }
        return eta * g / root;
    }

    /** Rejects decay rates outside (0, 1), including NaN; returns {@code rho} unchanged. */
    private static double checkRho(double rho) {
        // Every comparison with NaN is false, so NaN needs its own check.
        if (rho <= 0.0 || rho >= 1.0 || Double.isNaN(rho)) {
            throw new IllegalArgumentException("Rho should be a value in (0, 1) not " + rho);
        }
        return rho;
    }

    /** Returns an independent copy of this updater, including its accumulated history. */
    @Override
    public RMSProp clone() {
        return new RMSProp(this);
    }

    /**
     * Resets the accumulated history for a problem with {@code d} weights.
     *
     * @param d the number of weight coordinates
     */
    @Override
    public void setup(int d) {
        this.daigG = new ScaledVector(new DenseVector(d));
        this.biasG = 0.0;
    }
}

