/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.Objects;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.lossfunctions.LogisticLoss;
import jsat.parameters.Parameterized;

/**
 * Bayesian Binary Regression (BBR): binary logistic regression fit by cyclic
 * coordinate descent with a per-coordinate trust region, under either a
 * Laplace (L1) or a Gaussian (L2) prior on the coefficients. The update rules
 * follow the BBR algorithm of Genkin, Lewis and Madigan.
 * <p>
 * The category {@code c} of every training point is mapped to the label
 * {@code y = 2c - 1}, so the model expects a data set whose categories are
 * {@code 0} and {@code 1}. Predictions are {@link LogisticLoss#classify(double)}
 * applied to {@code w.x + bias}.
 */
public class BBR implements Classifier, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = 8297213093357011082L;

    /** Magnitudes below this are snapped to exactly zero after each update. */
    private static final double ZERO_SNAP = 1.0E-15;

    /** Lower bound on an automatically chosen regularization value. */
    private static final double MIN_AUTO_LAMBDA = 1.0E-15;

    /** Learned coefficients; {@code null} until {@link #train} has run. */
    private Vec w;
    /** Maximum number of full passes over the coordinates. */
    private int maxIterations;
    /**
     * Laplace prior: penalty weight. Gaussian prior: prior variance
     * (larger means weaker regularization).
     */
    private double regularization;
    /** When true, {@link #regularization} is ignored and a data-driven value is used. */
    private boolean autoSetRegularization = true;
    /** Learned intercept; stays 0 when {@link #useBias} is false. */
    private double bias;
    /** Whether an intercept term is fit. */
    private boolean useBias = true;
    /** Relative change in the margins below which training stops early. */
    private double tolerance = 5.0E-4;
    /** Prior placed on the coefficients (and on the bias). */
    private Prior prior;

    /**
     * Creates a Laplace-prior model with a fixed regularization value.
     *
     * @param regularization the Laplace penalty weight; must be positive and finite
     * @param maxIterations  the maximum number of passes over the features; must be positive
     */
    public BBR(double regularization, int maxIterations) {
        this(regularization, maxIterations, Prior.LAPLACE);
    }

    /**
     * Creates a model with a fixed regularization value and the given prior.
     *
     * @param regularization the Laplace penalty weight, or the Gaussian prior
     *                       variance; must be positive and finite
     * @param maxIterations  the maximum number of passes over the features; must be positive
     * @param prior          the prior to use; must not be {@code null}
     */
    public BBR(double regularization, int maxIterations, Prior prior) {
        this.setMaxIterations(maxIterations);
        this.setRegularization(regularization);
        this.setAutoSetRegularization(false);
        this.setPrior(prior);
    }

    /**
     * Creates a Laplace-prior model with a fixed regularization value of
     * {@code 0.001} (automatic regularization disabled).
     *
     * @param maxIterations the maximum number of passes over the features; must be positive
     */
    public BBR(int maxIterations) {
        this(0.001, maxIterations, Prior.LAPLACE);
    }

    /**
     * Creates a model whose regularization is chosen automatically from the
     * training data. The stored regularization value ({@code 0.01}) is only
     * used if automatic selection is later disabled.
     *
     * @param maxIterations the maximum number of passes over the features; must be positive
     * @param prior         the prior to use; must not be {@code null}
     */
    public BBR(int maxIterations, Prior prior) {
        this.setMaxIterations(maxIterations);
        this.setRegularization(0.01);
        this.setAutoSetRegularization(true);
        this.setPrior(prior);
    }

    /**
     * Copy constructor; deep-copies the weight vector.
     *
     * @param toCopy the model to copy
     */
    protected BBR(BBR toCopy) {
        if (toCopy.w != null) {
            this.w = toCopy.w.clone();
        }
        this.maxIterations = toCopy.maxIterations;
        this.regularization = toCopy.regularization;
        this.autoSetRegularization = toCopy.autoSetRegularization;
        this.bias = toCopy.bias;
        this.useBias = toCopy.useBias;
        this.tolerance = toCopy.tolerance;
        this.prior = toCopy.prior;
    }

    /**
     * Sets the regularization value used when automatic selection is off.
     *
     * @param regularization the Laplace penalty weight, or the Gaussian prior variance
     * @throws IllegalArgumentException if the value is NaN, infinite, or not positive
     */
    public void setRegularization(double regularization) {
        // previously checked isNaN twice, letting +Infinity through
        if (Double.isNaN(regularization) || Double.isInfinite(regularization) || regularization <= 0.0) {
            throw new IllegalArgumentException("Regularization must be positive, not " + regularization);
        }
        this.regularization = regularization;
    }

    /** @return the regularization value used when automatic selection is off */
    public double getRegularization() {
        return this.regularization;
    }

    /** @param autoSetRegularization whether to pick the regularization value from the training data */
    public void setAutoSetRegularization(boolean autoSetRegularization) {
        this.autoSetRegularization = autoSetRegularization;
    }

    /** @return whether the regularization value is picked from the training data */
    public boolean isAutoSetRegularization() {
        return this.autoSetRegularization;
    }

    /**
     * @param maxIterations the maximum number of passes over the features
     * @throws IllegalArgumentException if {@code maxIterations} is not positive
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("Max iterations must be positive, not " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /** @return the maximum number of passes over the features */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * @param tolerance the convergence tolerance on the relative margin change
     * @throws IllegalArgumentException if the value is NaN, infinite, or not positive
     */
    public void setTolerance(double tolerance) {
        if (Double.isNaN(tolerance) || Double.isInfinite(tolerance) || tolerance <= 0.0) {
            throw new IllegalArgumentException("Tolerance must be positive, not " + tolerance);
        }
        this.tolerance = tolerance;
    }

    /** @return the convergence tolerance on the relative margin change */
    public double getTolerance() {
        return this.tolerance;
    }

    /** @param useBias whether to fit an intercept term */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /** @return whether an intercept term is fit */
    public boolean isUseBias() {
        return this.useBias;
    }

    /**
     * @param prior the prior on the coefficients
     * @throws NullPointerException if {@code prior} is {@code null}
     */
    public void setPrior(Prior prior) {
        this.prior = Objects.requireNonNull(prior, "prior");
    }

    /** @return the prior on the coefficients */
    public Prior getPrior() {
        return this.prior;
    }

    /** @return the learned weight vector (not a copy), or {@code null} before training */
    public Vec getWeightVec() {
        return this.w;
    }

    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    @Override
    public double getBias() {
        return this.bias;
    }

    @Override
    public Vec getRawWeight(int index) {
        if (index != 0) {
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
        }
        return this.getRawWeight();
    }

    @Override
    public double getBias(int index) {
        if (index != 0) {
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
        }
        return this.getBias();
    }

    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Classifies a point using the logistic link on {@code w.x + bias}.
     *
     * @throws IllegalStateException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.w == null) {
            throw new IllegalStateException("Model has not been trained");
        }
        return LogisticLoss.classify(this.w.dot(data.getNumericalValues()) + this.bias);
    }

    /** Training is single-threaded; {@code parallel} is ignored. */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        this.train(dataSet);
    }

    /**
     * Fits the model by cyclic coordinate descent. Each pass updates every
     * weight and then (optionally) the bias. Training stops after
     * {@link #getMaxIterations()} passes, or earlier once
     * {@code sum|delta r_i| / (1 + sum|r_i|) <= tolerance}, where
     * {@code r_i = y_i (w.x_i + bias)}.
     *
     * @throws FailedToFitException if the data set has no numeric features or no points
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        final int D = dataSet.getNumNumericalVars();
        if (D <= 0) {
            throw new FailedToFitException("Data set contains no numeric features");
        }
        final int N = dataSet.getSampleSize();
        if (N <= 0) {
            throw new FailedToFitException("Data set contains no data points");
        }
        Vec[] columnMajor = dataSet.getNumericColumns();
        double[] y = new double[N];
        for (int i = 0; i < N; ++i) {
            y[i] = dataSet.getDataPointCategory(i) * 2 - 1;
        }
        double lambda = this.autoSetRegularization ? this.autoLambda(dataSet, D, N) : this.regularization;

        // r is initialised to 0, which is only consistent with w = 0 AND bias = 0;
        // a bias left over from an earlier train() call must be cleared as well.
        this.w = new DenseVector(D);
        this.bias = 0.0;
        double[] r = new double[N];
        double[] rChange = new double[N];
        // trust-region half-widths, one per weight plus one for the bias
        double[] delta = new double[this.useBias ? D + 1 : D];
        Arrays.fill(delta, 1.0);

        for (int iter = 0; iter < this.maxIterations; ++iter) {
            for (int j = 0; j < D; ++j) {
                double wj = this.w.get(j);
                double step = clampStep(this.proposeStep(columnMajor, j, wj, y, r, lambda, delta), delta[j]);
                for (IndexValue iv : columnMajor[j]) {
                    int i = iv.getIndex();
                    double change = step * iv.getValue() * y[i];
                    r[i] += change;
                    rChange[i] += change;
                }
                this.w.set(j, snapToZero(wj + step));
                delta[j] = Math.max(2.0 * Math.abs(step), delta[j] / 2.0);
            }
            if (this.useBias) {
                // the bias behaves like a feature whose value is 1 for every point
                double step = clampStep(this.proposeStep(null, D, this.bias, y, r, lambda, delta), delta[D]);
                for (int i = 0; i < N; ++i) {
                    double change = step * y[i];
                    r[i] += change;
                    rChange[i] += change;
                }
                this.bias = snapToZero(this.bias + step);
                delta[D] = Math.max(2.0 * Math.abs(step), delta[D] / 2.0);
            }
            if (this.hasConverged(r, rChange)) {
                break;
            }
            Arrays.fill(rChange, 0.0);
        }
    }

    /**
     * Chooses a regularization value from the mean squared norm of the
     * training points.
     */
    private double autoLambda(ClassificationDataSet dataSet, int D, int N) {
        double normSqrdSum = 0.0;
        for (int i = 0; i < N; ++i) {
            double norm = dataSet.getDataPoint(i).getNumericalValues().pNorm(2.0);
            normSqrdSum += norm * norm;
        }
        // widen before multiplying: D * N in int overflows on large sparse problems
        double sigma = (double) D * N / normSqrdSum;
        // NOTE(review): sigma = D / mean(||x||^2) is what the BBR paper uses as the
        // prior *variance*; here it is treated as a standard deviation
        // (Gaussian variance = sigma^2, Laplace lambda = sqrt(2)/sigma). Confirm intent.
        return this.prior == Prior.LAPLACE
                ? Math.max(Math.sqrt(2.0) / sigma, MIN_AUTO_LAMBDA)
                : Math.max(sigma * sigma, MIN_AUTO_LAMBDA);
    }

    /**
     * Proposes an unclamped step for coordinate {@code j} (the bias when
     * {@code columnMajor} is null) under the configured prior. For the
     * Laplace prior, which is non-differentiable at 0, a coefficient at 0
     * moves only in a direction where the one-sided step agrees in sign.
     * A non-zero coefficient is never allowed to cross 0 in a single step.
     */
    private double proposeStep(Vec[] columnMajor, int j, double current, double[] y, double[] r, double lambda, double[] delta) {
        if (this.prior != Prior.LAPLACE) {
            return this.tentativeUpdate(columnMajor, j, current, y, r, lambda, 0.0, delta);
        }
        if (current == 0.0) {
            double up = this.tentativeUpdate(columnMajor, j, current, y, r, lambda, 1.0, delta);
            if (up > 0.0) {
                return up;
            }
            double down = this.tentativeUpdate(columnMajor, j, current, y, r, lambda, -1.0, delta);
            return down < 0.0 ? down : 0.0;
        }
        double sign = Math.signum(current);
        double step = this.tentativeUpdate(columnMajor, j, current, y, r, lambda, sign, delta);
        // stop exactly at 0 rather than jumping across it
        return sign * (current + step) < 0.0 ? -current : step;
    }

    /**
     * Computes the tentative step for coordinate {@code j}: the data-term
     * gradient divided by the trust-region curvature bound, with the prior's
     * contribution added once (not once per data point).
     *
     * @param columnMajor feature columns, or {@code null} for the bias (all x = 1)
     * @param s           sign of the Laplace sub-gradient; unused for the Gaussian prior
     * @return the proposed step, or 0 if it is undefined (0/0)
     */
    private double tentativeUpdate(Vec[] columnMajor, int j, double w_j, double[] y, double[] r, double lambda, double s, double[] delta) {
        double numer = 0.0;
        double denom = 0.0;
        if (columnMajor != null) {
            Vec col_j = columnMajor[j];
            if (col_j.nnz() == 0) {
                return 0.0;
            }
            for (IndexValue iv : col_j) {
                double x_ij = iv.getValue();
                int i = iv.getIndex();
                numer += x_ij * y[i] / (1.0 + Math.exp(r[i]));
                denom += x_ij * x_ij * F(r[i], delta[j] * Math.abs(x_ij));
            }
        } else {
            for (int i = 0; i < y.length; ++i) {
                numer += y[i] / (1.0 + Math.exp(r[i]));
                denom += F(r[i], delta[j]);
            }
        }
        if (this.prior == Prior.LAPLACE) {
            numer -= lambda * s;
        } else {
            numer -= w_j / lambda;
            denom += 1.0 / lambda;
        }
        double step = numer / denom;
        return Double.isNaN(step) ? 0.0 : step;
    }

    /**
     * Upper bound on the logistic-loss curvature for a margin {@code r}
     * anywhere within {@code +/- delta}.
     */
    private static double F(double r, double delta) {
        double absR = Math.abs(r);
        if (absR <= delta) {
            return 0.25;
        }
        return 1.0 / (2.0 + Math.exp(absR - delta) + Math.exp(delta - absR));
    }

    /** Limits {@code step} to the interval {@code [-bound, bound]}. */
    private static double clampStep(double step, double bound) {
        return Math.min(Math.max(step, -bound), bound);
    }

    /** Replaces values of negligible magnitude with exactly 0. */
    private static double snapToZero(double value) {
        return Math.abs(value) < ZERO_SNAP ? 0.0 : value;
    }

    /** Relative-change convergence test: {@code sum|rChange| / (1 + sum|r|) <= tolerance}. */
    private boolean hasConverged(double[] r, double[] rChange) {
        double changeSum = 0.0;
        double rSum = 0.0;
        for (int i = 0; i < r.length; ++i) {
            changeSum += Math.abs(rChange[i]);
            rSum += Math.abs(r[i]);
        }
        return changeSum / (1.0 + rSum) <= this.tolerance;
    }

    /** This model ignores data point weights. */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public BBR clone() {
        return new BBR(this);
    }

    /** Prior distribution placed on the model coefficients. */
    public static enum Prior {
        /** Laplace (double-exponential) prior, i.e. L1 regularization. */
        LAPLACE,
        /** Gaussian prior, i.e. L2 regularization; the regularization value is the prior variance. */
        GAUSSIAN;

    }
}

