/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import jsat.SingleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;

/**
 * Online linear binary classifier implementing ALMA<sub>2</sub> (the
 * Approximate Large Margin Algorithm with {@code p = 2}).
 * <p>
 * The model is a single weight vector plus an optional scalar bias. Each call
 * to {@link #update(DataPoint, int)} checks whether the signed margin of the
 * example falls below a threshold that shrinks with the number of updates
 * performed so far; if it does, the weights (and bias, when enabled) are moved
 * towards the example and the weight vector is rescaled when the computed
 * norm exceeds one.
 * <p>
 * Only numeric features are used and only two-class problems are supported.
 */
public class ALMA2
        extends BaseUpdateableClassifier
        implements BinaryScoreClassifier, SingleWeightVectorModel {
    private static final long serialVersionUID = -4347891273721908507L;

    /** Norm order of the algorithm; fixed at 2 for this implementation. */
    private static final double p = 2.0;

    /** Weight vector; {@code null} until {@link #setUp} has been called. */
    private Vec w;
    /** Margin approximation parameter in (0, 1]. */
    private double alpha;
    /** Scale of the margin threshold; defaults to {@code 1 / alpha}. */
    private double B;
    /** Scale of the step size. */
    private double C = Math.sqrt(2.0);
    /** Update counter; starts at 1 and grows by one per corrective update. */
    private int k;
    /** Whether a bias term is learned alongside the weights. */
    private boolean useBias = true;
    /** Learned bias term. */
    private double bias;

    /**
     * Creates a new learner with {@code alpha = 0.8}.
     */
    public ALMA2() {
        this(0.8);
    }

    /**
     * Creates a new learner with the given {@code alpha}.
     *
     * @param alpha the margin approximation parameter, in (0, 1]
     * @throws ArithmeticException if {@code alpha} is outside (0, 1] or NaN
     */
    public ALMA2(double alpha) {
        this.setAlpha(alpha);
    }

    /**
     * Copy constructor. Copies every piece of configuration and learned state,
     * including the bias setting and the learned bias, so that a clone scores
     * data exactly like the source model.
     *
     * @param other the model to copy
     */
    protected ALMA2(ALMA2 other) {
        if (other.w != null) {
            this.w = other.w.clone();
        }
        this.alpha = other.alpha;
        this.B = other.B;
        this.C = other.C;
        this.k = other.k;
        // These two were previously dropped, silently resetting clones to
        // useBias = true / bias = 0.
        this.useBias = other.useBias;
        this.bias = other.bias;
    }

    /**
     * Returns the weight vector backing this model (not a copy).
     *
     * @return the current weight vector, or {@code null} if untrained
     */
    public Vec getWeightVec() {
        return this.w;
    }

    /**
     * Sets {@code alpha} and, as a side effect, resets {@code B} to
     * {@code 1 / alpha}.
     *
     * @param alpha the margin approximation parameter, in (0, 1]
     * @throws ArithmeticException if {@code alpha} is outside (0, 1] or NaN
     */
    public void setAlpha(double alpha) {
        if (alpha <= 0.0 || alpha > 1.0 || Double.isNaN(alpha)) {
            throw new ArithmeticException("alpha must be in (0, 1], not " + alpha);
        }
        this.alpha = alpha;
        this.setB(1.0 / alpha);
    }

    /**
     * @return the current {@code alpha}
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Sets the scale {@code B} of the margin threshold. Note that a later call
     * to {@link #setAlpha(double)} overwrites this value.
     *
     * @param B the new margin threshold scale (not validated)
     */
    public void setB(double B) {
        this.B = B;
    }

    /**
     * @return the current margin threshold scale {@code B}
     */
    public double getB() {
        return this.B;
    }

    /**
     * Sets the step size scale {@code C}.
     *
     * @param C a positive, finite constant
     * @throws ArithmeticException if {@code C} is non-positive, infinite or NaN
     */
    public void setC(double C) {
        if (C <= 0.0 || Double.isInfinite(C) || Double.isNaN(C)) {
            throw new ArithmeticException("C must be a posative cosntant");
        }
        this.C = C;
    }

    /**
     * @return the current step size scale {@code C}
     */
    public double getC() {
        return this.C;
    }

    /**
     * Enables or disables learning of a bias term during updates.
     *
     * @param useBias {@code true} to learn a bias term
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /**
     * @return {@code true} if a bias term is learned during updates
     */
    public boolean isUseBias() {
        return this.useBias;
    }

    @Override
    public ALMA2 clone() {
        return new ALMA2(this);
    }

    /**
     * Prepares the model for a fresh round of online training: allocates a new
     * zero weight vector, clears the bias and resets the update counter.
     *
     * @param categoricalAttributes ignored by this learner
     * @param numericAttributes     the number of numeric features; must be positive
     * @param predicting            the target attribute; must have exactly two categories
     * @throws FailedToFitException if there are no numeric features or the
     *                              problem is not binary
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (numericAttributes <= 0) {
            throw new FailedToFitException("ALMA2 requires numeric features");
        }
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("ALMA2 works only for binary classification");
        }
        this.w = new DenseVector(numericAttributes);
        // Without this a re-used instance would start from the previous run's bias.
        this.bias = 0.0;
        this.k = 1;
    }

    /**
     * Performs one online update with a single labelled example.
     *
     * @param dataPoint   the example; only its numeric values are used
     * @param targetClass the true class, 0 or 1
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        Vec x_t = dataPoint.getNumericalValues();
        // Map class index {0, 1} to label {-1, +1}.
        double y_t = targetClass * 2 - 1;

        double gamma = this.B * Math.sqrt(p - 1.0) / (double) this.k;
        double wx = this.w.dot(x_t) + this.bias;
        if (y_t * wx > (1.0 - this.alpha) * gamma) {
            return; // margin is large enough, nothing to correct
        }

        // The step size uses the counter value from before this update.
        double eta = this.C / Math.sqrt(p - 1.0) / Math.sqrt(this.k);
        this.k++;

        this.w.mutableAdd(eta * y_t, x_t);
        if (this.useBias) {
            this.bias += eta * y_t;
        }

        // NOTE(review): the bias is added to the weight norm un-squared and is
        // not itself rescaled. Kept as-is to preserve existing training
        // behaviour -- confirm against the reference algorithm.
        double norm = this.w.pNorm(2.0) + this.bias;
        if (norm > 1.0) {
            this.w.mutableDivide(norm);
        }
    }

    /**
     * Classifies a data point by the sign of its score: a negative score puts
     * all probability on class 0, otherwise on class 1.
     *
     * @param data the data point to classify
     * @return hard (0/1) class probabilities
     * @throws UntrainedModelException if the model has not been set up
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.w == null) {
            throw new UntrainedModelException("The model has not yet been trained");
        }
        double wx = this.getScore(data);
        CategoricalResults cr = new CategoricalResults(2);
        if (wx < 0.0) {
            cr.setProb(0, 1.0);
        } else {
            cr.setProb(1, 1.0);
        }
        return cr;
    }

    /**
     * Computes the raw decision value {@code w . x + bias}, i.e. the same
     * quantity {@link #update(DataPoint, int)} uses for its margin test.
     *
     * @param dp the data point to score
     * @return the signed score; negative favours class 0
     */
    @Override
    public double getScore(DataPoint dp) {
        return this.w.dot(dp.getNumericalValues()) + this.bias;
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    @Override
    public double getBias() {
        return this.bias;
    }

    /**
     * @param index the weight vector index; any value below 1 selects the
     *              single weight vector
     * @return the weight vector
     * @throws IndexOutOfBoundsException if {@code index >= 1}
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index < 1) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @param index the weight vector index; any value below 1 selects the
     *              single bias term
     * @return the bias term
     * @throws IndexOutOfBoundsException if {@code index >= 1}
     */
    @Override
    public double getBias(int index) {
        if (index < 1) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public int numWeightsVecs() {
        return 1;
    }
}

