/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.util.Collections;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.DataPoint;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.decayrates.DecayRate;
import jsat.math.decayrates.NoDecay;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/**
 * Linear regression with an L2 (ridge) penalty on the weights, trained by mini-batch
 * stochastic gradient descent.
 *
 * <p>For every mini-batch the residual {@code w·x + b - y} of each point is computed against
 * the current model. The weights are then shrunk by {@code (1 - alpha * lambda)} and
 * {@code w} and {@code b} are moved against the residual-scaled inputs. Here {@code alpha} is
 * the (possibly decayed) learning rate divided by the batch size. The bias term is not
 * regularized.
 *
 * <p>When more than a quarter of the training points have sparse numerical vectors, the weight
 * shrinkage is applied lazily, only to coordinates touched by a point. Untouched coordinates
 * are brought up to date at the end of training, and also at the end of every epoch when the
 * learning decay is not {@link NoDecay}.
 */
public class StochasticRidgeRegression implements Regressor, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = -3462783438115627128L;
    private double lambda;
    private int epochs;
    private int batchSize;
    private double learningRate;
    private DecayRate learningDecay;
    /** Learned weight vector; {@code null} until {@link #train(RegressionDataSet)} is called. */
    private Vec w;
    /** Learned (unregularized) intercept. */
    private double bias;

    /**
     * Creates a new learner that uses a constant learning rate ({@link NoDecay}).
     *
     * @param lambda the positive, finite L2 regularization strength
     * @param epochs the number of passes over the data, at least 1
     * @param batchSize the mini-batch size, at least 1
     * @param learningRate the initial learning rate
     */
    public StochasticRidgeRegression(double lambda, int epochs, int batchSize, double learningRate) {
        this(lambda, epochs, batchSize, learningRate, new NoDecay());
    }

    /**
     * Creates a new learner.
     *
     * @param lambda the positive, finite L2 regularization strength
     * @param epochs the number of passes over the data, at least 1
     * @param batchSize the mini-batch size, at least 1
     * @param learningRate the initial learning rate
     * @param learningDecay how the learning rate changes from epoch to epoch
     * @throws IllegalArgumentException if {@code lambda}, {@code epochs} or {@code batchSize}
     *     is out of range
     */
    public StochasticRidgeRegression(double lambda, int epochs, int batchSize, double learningRate, DecayRate learningDecay) {
        this.setLambda(lambda);
        this.setEpochs(epochs);
        this.setBatchSize(batchSize);
        this.setLearningRate(learningRate);
        this.setLearningDecay(learningDecay);
    }

    /**
     * Sets the L2 regularization strength.
     *
     * @param lambda a positive, finite value
     * @throws IllegalArgumentException if {@code lambda} is NaN, infinite, or not positive
     */
    public void setLambda(double lambda) {
        if (Double.isNaN(lambda) || Double.isInfinite(lambda) || lambda <= 0.0) {
            throw new IllegalArgumentException("lambda must be a positive constant, not " + lambda);
        }
        this.lambda = lambda;
    }

    /** @return the L2 regularization strength */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets the initial learning rate. The value is not validated here.
     *
     * @param learningRate the initial learning rate passed to the {@link DecayRate}
     */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** @return the initial learning rate */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the schedule used to derive each epoch's learning rate from the initial one.
     *
     * @param learningDecay the decay schedule; queried once per epoch during training
     */
    public void setLearningDecay(DecayRate learningDecay) {
        this.learningDecay = learningDecay;
    }

    /** @return the learning-rate decay schedule */
    public DecayRate getLearningDecay() {
        return this.learningDecay;
    }

    /**
     * Sets the mini-batch size. During training it is capped at the number of data points.
     *
     * @param batchSize a positive batch size
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public void setBatchSize(int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("Batch size must be a positive constant, not " + batchSize);
        }
        this.batchSize = batchSize;
    }

    /** @return the mini-batch size */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Sets the number of passes over the training data.
     *
     * @param epochs at least 1
     * @throws IllegalArgumentException if {@code epochs <= 0}
     */
    public void setEpochs(int epochs) {
        if (epochs <= 0) {
            throw new IllegalArgumentException("At least one epoch must be performed, can not use " + epochs);
        }
        this.epochs = epochs;
    }

    /** @return the number of training epochs */
    public int getEpochs() {
        return this.epochs;
    }

    /** @return the learned weight vector (not a copy), or {@code null} before training */
    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    /** @return the learned intercept */
    @Override
    public double getBias() {
        return this.bias;
    }

    /**
     * Returns the single weight vector of this model.
     *
     * @param index must be below 1
     * @throws IndexOutOfBoundsException if {@code index >= 1}
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index < 1) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the single bias term of this model.
     *
     * @param index must be below 1
     * @throws IndexOutOfBoundsException if {@code index >= 1}
     */
    @Override
    public double getBias(int index) {
        if (index < 1) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /** @return always 1 */
    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Predicts {@code w·x + b} for the numerical features of {@code data}. The model must have
     * been trained (or cloned from a trained model) first, because {@code w} is {@code null}
     * until then.
     */
    @Override
    public double regress(DataPoint data) {
        return this.regress(data.getNumericalValues());
    }

    private double regress(Vec data) {
        return this.w.dot(data) + this.bias;
    }

    /** Trains sequentially; the {@code parallel} flag is ignored. */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Fits the model from scratch by mini-batch SGD. Any previously learned weights and bias
     * are discarded.
     *
     * <p>Only complete mini-batches are used. The trailing {@code N mod batch} points of each
     * shuffled epoch are skipped.
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        final int n = dataSet.getSampleSize();
        final int batch = Math.min(this.batchSize, n);
        this.w = new DenseVector(dataSet.getNumNumericalVars());
        // Start from a zero intercept as well, so retraining does not depend on a prior model.
        this.bias = 0.0;

        IntList sample = new IntList(n);
        ListUtils.addRange(sample, 0, n, 1);

        // Number of mini-batch steps taken so far; drives the lazy (sparse) weight decay.
        int time = 0;
        double[] errors = new double[batch];

        final boolean sparseUpdates = useSparseUpdates(dataSet);
        // lastTime[j] is the step at which coordinate j last had its L2 shrinkage applied.
        int[] lastTime = sparseUpdates ? new int[this.w.length()] : null;

        for (int epoch = 0; epoch < this.epochs; ++epoch) {
            Collections.shuffle(sample);
            double alpha = this.learningDecay.rate(epoch, this.epochs, this.learningRate) / (double) batch;
            double alphaReg = alpha * this.lambda;

            for (int start = 0; start < n; start += batch) {
                if (start + batch > n) {
                    break; // not enough points left for a full mini-batch
                }
                ++time;
                // Residuals for the whole batch are computed against the same model,
                // before any update in this step.
                for (int b = start; b < start + batch; ++b) {
                    int row = sample.get(b);
                    errors[b - start] = this.regress(dataSet.getDataPoint(row)) - dataSet.getTargetValue(row);
                }
                for (int b = start; b < start + batch; ++b) {
                    double alphaError = alpha * errors[b - start];
                    this.bias -= alphaError;
                    Vec x = dataSet.getDataPoint(sample.get(b)).getNumericalValues();
                    if (sparseUpdates) {
                        for (IndexValue iv : x) {
                            int j = iv.getIndex();
                            this.catchUpDecay(j, alphaReg, time, lastTime);
                            this.w.increment(j, -alphaError * iv.getValue());
                        }
                    } else {
                        if (b == start) {
                            // Dense case: shrink every weight once per mini-batch step.
                            this.w.mutableMultiply(1.0 - alphaReg);
                        }
                        this.w.mutableSubtract(alphaError, x);
                    }
                }
            }

            // Bring all lazily-decayed coordinates up to date. This must happen every epoch
            // when alphaReg changes between epochs, and at least once at the very end.
            if (sparseUpdates && (!(this.learningDecay instanceof NoDecay) || epoch == this.epochs - 1)) {
                for (int j = 0; j < this.w.length(); ++j) {
                    this.catchUpDecay(j, alphaReg, time, lastTime);
                }
            }
        }
    }

    /** Decides whether enough data points are sparse (more than a quarter) to use lazy updates. */
    private static boolean useSparseUpdates(RegressionDataSet dataSet) {
        int sparse = 0;
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            if (dataSet.getDataPoint(i).getNumericalValues().isSparse()) {
                ++sparse;
            }
        }
        return sparse > dataSet.getSampleSize() / 4;
    }

    /** Applies the shrinkage steps that coordinate {@code idx} has missed since {@code lastTime[idx]}. */
    private void catchUpDecay(int idx, double alphaReg, int time, int[] lastTime) {
        if (lastTime[idx] != time) {
            double thetaIdx = this.w.get(idx);
            this.w.set(idx, thetaIdx * Math.pow(1.0 - alphaReg, time - lastTime[idx]));
            lastTime[idx] = time;
        }
    }

    /** @return {@code false}; per-point weights are not used by {@link #train(RegressionDataSet)} */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Copies the hyper-parameters and, if trained, the learned weights and bias. The
     * {@link DecayRate} instance is shared with the copy, not cloned.
     */
    @Override
    public StochasticRidgeRegression clone() {
        StochasticRidgeRegression clone = new StochasticRidgeRegression(this.lambda, this.epochs, this.batchSize, this.learningRate, this.learningDecay);
        if (this.w != null) {
            clone.w = this.w.clone();
        }
        clone.bias = this.bias;
        return clone;
    }
}

