/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import jsat.SingleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;

/**
 * Online linear binary classifier using the ROMMA (Relaxed Online Maximum
 * Margin Algorithm) update rule, in either its standard or "aggressive"
 * variant.
 * <p>
 * The model is a single dense weight vector plus an optional bias term. The
 * decision score is {@code w . x + bias}; a negative score maps to class 0 and
 * any other score maps to class 1. Only the numeric features of a data point
 * are used.
 */
public class ROMMA
extends BaseUpdateableClassifier
implements BinaryScoreClassifier,
SingleWeightVectorModel {
    private static final long serialVersionUID = 8163937542627337711L;
    /** Whether {@link #update} also adjusts the bias term. */
    private boolean useBias = true;
    /** Whether the aggressive reset step is enabled in {@link #update}. */
    private boolean aggressive;
    /** Weight vector; {@code null} until {@link #setUp} has been called. */
    private Vec w;
    /** Bias term added to every score. */
    private double bias;

    /**
     * Creates a new aggressive ROMMA learner.
     */
    public ROMMA() {
        this(true);
    }

    /**
     * Creates a new ROMMA learner.
     *
     * @param aggressive {@code true} to enable the aggressive variant
     */
    public ROMMA(boolean aggressive) {
        this.setAggressive(aggressive);
    }

    /**
     * Copy constructor.
     *
     * @param other the model to copy
     */
    protected ROMMA(ROMMA other) {
        this.aggressive = other.aggressive;
        if (other.w != null) {
            // Deep copy: update() mutates the weight vector in place, so a
            // shared reference would let training one model corrupt the other.
            this.w = other.w.clone();
        }
        this.bias = other.bias;
        this.useBias = other.useBias;
    }

    /**
     * Returns an independent copy of this model, including its learned state.
     *
     * @return a new {@code ROMMA} with its own weight vector
     */
    @Override
    public ROMMA clone() {
        return new ROMMA(this);
    }

    /**
     * Selects between the standard and the aggressive update.
     *
     * @param aggressive {@code true} to enable the aggressive variant
     */
    public void setAggressive(boolean aggressive) {
        this.aggressive = aggressive;
    }

    /**
     * @return {@code true} if the aggressive variant is enabled
     */
    public boolean isAggressive() {
        return this.aggressive;
    }

    /**
     * Controls whether updates also modify the bias term.
     *
     * @param useBias {@code true} to learn a bias term
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /**
     * @return {@code true} if updates modify the bias term
     */
    public boolean isUseBias() {
        return this.useBias;
    }

    /**
     * Returns the internal weight vector itself (not a copy).
     *
     * @return the weight vector, or {@code null} if {@link #setUp} has not run
     */
    public Vec getWeightVec() {
        return this.w;
    }

    /**
     * Returns the internal weight vector itself (not a copy).
     *
     * @return the weight vector, or {@code null} if {@link #setUp} has not run
     */
    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    /**
     * @return the current bias term
     */
    @Override
    public double getBias() {
        return this.bias;
    }

    /**
     * Returns the single weight vector of this model.
     *
     * @param index any value below 1 selects the only weight vector
     * @return the weight vector
     * @throws IndexOutOfBoundsException if {@code index} is 1 or greater
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index < 1) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the single bias term of this model.
     *
     * @param index any value below 1 selects the only bias term
     * @return the bias term
     * @throws IndexOutOfBoundsException if {@code index} is 1 or greater
     */
    @Override
    public double getBias(int index) {
        if (index < 1) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @return always 1
     */
    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Resets the model to an all-zero weight vector and zero bias, sized for
     * the given number of numeric features.
     *
     * @param categoricalAttributes not used by this method
     * @param numericAttributes number of numeric features; must be positive
     * @param predicting the target attribute; must have exactly two categories
     * @throws FailedToFitException if there are no numeric features or the
     *         target is not binary
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (numericAttributes <= 0) {
            throw new FailedToFitException("ROMMA requires numerical features");
        }
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("ROMMA only supports binary classification");
        }
        this.w = new DenseVector(numericAttributes);
        this.bias = 0.0;
    }

    /**
     * Performs one online update with the given example.
     * <p>
     * No change is made unless the signed margin {@code y * (w . x + bias)} is
     * below 1. Otherwise the weights are either reset to
     * {@code y * x / ||x||^2} or rescaled and shifted along {@code x} using the
     * closed-form coefficients {@code c} and {@code d}.
     *
     * @param dataPoint the example; only its numeric values are read
     * @param targetClass the class label, 0 or 1 (mapped to -1 / +1)
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        double y = targetClass * 2 - 1;
        Vec x = dataPoint.getNumericalValues();
        double wx = this.w.dot(x) + this.bias;
        double pred = y * wx;
        // Written as a negated "<" so a NaN score leaves the model untouched.
        if (!(pred < 1.0)) {
            return;
        }
        double ww = this.w.dot(this.w);
        double xx = x.dot(x);
        double wwxx = ww * xx;
        double denom = wwxx - wx * wx;

        boolean aggressiveReset = this.aggressive && pred >= wwxx;
        // The closed form below divides by denom, which is exactly zero when w
        // is still all zeros (always true for the first update after setUp) or
        // when w is collinear with x and the bias is zero. It would yield
        // NaN/Infinite weights, so fall back to the minimal-norm vector that
        // attains a margin of 1 on this example.
        if (aggressiveReset || denom == 0.0) {
            if (xx == 0.0) {
                // An all-zero feature vector cannot be scaled to reach the
                // margin; y / xx would be infinite.
                return;
            }
            double scale = y / xx;
            this.w.zeroOut();
            this.w.mutableAdd(scale, x);
            if (this.useBias) {
                this.bias = scale;
            }
            return;
        }

        double c = (wwxx - pred) / denom;
        double d = ww * (y - wx) / denom;
        this.w.mutableMultiply(c);
        this.w.mutableAdd(d, x);
        if (this.useBias) {
            this.bias = c * this.bias + d;
        }
    }

    /**
     * Classifies a data point with a hard 0/1 decision.
     *
     * @param data the data point to classify
     * @return results with probability 1.0 on class 0 when the score is
     *         negative, otherwise probability 1.0 on class 1
     * @throws UntrainedModelException if the weight vector has not been
     *         initialized
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        double wx = this.getScore(data);
        CategoricalResults cr = new CategoricalResults(2);
        if (wx < 0.0) {
            cr.setProb(0, 1.0);
        } else {
            cr.setProb(1, 1.0);
        }
        return cr;
    }

    /**
     * Computes the raw decision score {@code w . x + bias}.
     *
     * @param dp the data point; only its numeric values are read
     * @return the signed score
     */
    @Override
    public double getScore(DataPoint dp) {
        return this.w.dot(dp.getNumericalValues()) + this.bias;
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }
}

