/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear.kernelized;

import java.util.Arrays;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;

/**
 * The Forgetron: a kernelized Perceptron that stores at most {@code budget}
 * examples. On every mistake (a non-positive margin) the new example is stored
 * with a weight equal to its label. Once the budget is full, the oldest stored
 * example is overwritten, and the stored weights are multiplied by a shrink
 * factor {@code phi} so that old examples lose influence before they are
 * removed.
 * <p>
 * Two rules for choosing {@code phi} are available (see {@link #isSelfTuned()}):
 * <ul>
 * <li><b>self-tuned</b> (default): shrinking only happens once the budget is
 * full. {@code phi} (at most 1) is obtained from a quadratic so that the
 * accumulated value {@code Q} of {@link #psi(double, double)} stays within
 * 15/32 of the mistake count.</li>
 * <li><b>fixed</b>: {@code phi = min((B+1)^(-1/(2(B+1))), U / ||f'||)} with
 * {@code U = sqrt((B+1)/ln(B+1))/4}, applied on every mistake.</li>
 * </ul>
 * Reference: O. Dekel, S. Shalev-Shwartz and Y. Singer, "The Forgetron: A
 * Kernel-Based Perceptron on a Budget", SIAM J. Comput. 37(5), 2008.
 * <p>
 * Only binary problems with numeric attributes are supported, and weighted
 * data is not.
 */
public class Forgetron
extends BaseUpdateableClassifier
implements BinaryScoreClassifier,
Parameterized {
    private static final long serialVersionUID = -2631315082407427077L;
    /** Self-tuned rule: Q is kept at or below this fraction of the mistake count M (15/32). */
    private static final double DAMAGE_RATIO = 15.0 / 32.0;
    /** Self-tuned rule: quadratic coefficients with magnitude at or below this are treated as zero. */
    private static final double QUADRATIC_EPS = 1.0E-13;

    @Parameter.ParameterHolder
    private KernelTrick K;
    /** Stored examples, used as a circular buffer whose capacity is fixed by {@link #setUp}. */
    private Vec[] I;
    /** Signed weight of each stored example: the sign is its label, the magnitude its (shrunk) coefficient. */
    private double[] s;
    /** Number of occupied slots in {@link #I}. */
    private int size;
    /** Next slot to write. Once the buffer is full, this is also the oldest stored example. */
    private int curPos;
    /** Maximum number of stored examples. */
    private int budget;
    /** Fixed rule: bound on the hypothesis norm, sqrt((B+1)/ln(B+1))/4. */
    private double U;
    /** Fixed rule: per-mistake cap on the shrink factor, (B+1)^(-1/(2(B+1))), always below 1. */
    private double Bconst;
    /** Self-tuned rule: accumulated psi values over all removals since {@link #setUp}. */
    private double Q;
    /** Number of mistakes since {@link #setUp}. */
    private double M;
    /** Selects the self-tuned ({@code true}) or fixed ({@code false}) shrink rule. */
    private boolean selfTuned = true;

    /**
     * Creates a new Forgetron.
     *
     * @param kernel the kernel used to compare examples
     * @param budget the maximum number of stored examples; must be at least 1
     * @throws IllegalArgumentException if {@code budget < 1}
     */
    public Forgetron(KernelTrick kernel, int budget) {
        this.K = kernel;
        this.setBudget(budget);
    }

    /**
     * Selects the shrink rule. The method name is misspelled and is kept for
     * backward compatibility; it is equivalent to {@link #setSelfTuned(boolean)}.
     *
     * @param selfTurned {@code true} for the self-tuned rule, {@code false} for the fixed rule
     */
    public void setSelfTurned(boolean selfTurned) {
        this.selfTuned = selfTurned;
    }

    /**
     * Selects the shrink rule.
     *
     * @param selfTuned {@code true} for the self-tuned rule, {@code false} for the fixed rule
     */
    public void setSelfTuned(boolean selfTuned) {
        this.selfTuned = selfTuned;
    }

    /**
     * @return {@code true} if the self-tuned shrink rule is used, {@code false} for the fixed rule
     */
    public boolean isSelfTuned() {
        return this.selfTuned;
    }

    /**
     * Copy constructor. Performs a deep copy of the kernel and the stored examples.
     *
     * @param toClone the object to copy
     */
    protected Forgetron(Forgetron toClone) {
        super(toClone);
        this.K = toClone.K.clone();
        this.budget = toClone.budget;
        this.U = toClone.U;
        this.Bconst = toClone.Bconst;
        this.Q = toClone.Q;
        this.M = toClone.M;
        this.curPos = toClone.curPos;
        this.size = toClone.size;
        // Previously omitted, which made clones silently revert to the self-tuned rule.
        this.selfTuned = toClone.selfTuned;
        if (toClone.I != null) {
            this.I = new Vec[toClone.I.length];
            for (int i = 0; i < toClone.I.length; ++i) {
                if (toClone.I[i] != null) {
                    this.I[i] = toClone.I[i].clone();
                }
            }
        }
        if (toClone.s != null) {
            this.s = Arrays.copyOf(toClone.s, toClone.s.length);
        }
    }

    /**
     * Sets the maximum number of stored examples and recomputes the constants
     * of the fixed shrink rule. The buffer capacity itself is only reallocated
     * by the next call to {@link #setUp}.
     *
     * @param budget the maximum number of stored examples; must be at least 1
     * @throws IllegalArgumentException if {@code budget < 1}
     */
    public void setBudget(int budget) {
        if (budget < 1) {
            throw new IllegalArgumentException("Budget must be at least 1, not " + budget);
        }
        this.budget = budget;
        final double B = budget;
        this.U = Math.sqrt((B + 1.0) / Math.log(B + 1.0)) / 4.0;
        // Negative exponent keeps the cap below 1, so the fixed rule always shrinks.
        // A positive exponent would let phi exceed 1 and grow the stored weights.
        this.Bconst = Math.pow(B + 1.0, -1.0 / (2.0 * B + 2.0));
    }

    /**
     * @return the maximum number of stored examples
     */
    public int getBudget() {
        return this.budget;
    }

    /**
     * @param K the kernel used to compare examples
     */
    public void setKernelTrick(KernelTrick K) {
        this.K = K;
    }

    /**
     * @return the kernel used to compare examples
     */
    public KernelTrick getKernelTrick() {
        return this.K;
    }

    /**
     * Predicts a hard class label: class 1 when the score is strictly
     * positive, otherwise class 0 (including a score of exactly zero).
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        cr.setProb(this.getScore(data) > 0.0 ? 1 : 0, 1.0);
        return cr;
    }

    /**
     * Returns the kernel expansion {@code sum_i s[i] * K(I[i], x)} for the
     * numeric values of {@code dp}.
     */
    @Override
    public double getScore(DataPoint dp) {
        return this.decisionValue(dp.getNumericalValues());
    }

    /**
     * Evaluates the current hypothesis {@code sum_i s[i] * K(I[i], x)} over
     * the occupied slots.
     *
     * @param x the input vector
     * @return the signed score of {@code x}
     */
    private double decisionValue(Vec x) {
        double r = 0.0;
        for (int i = 0; i < this.size; ++i) {
            r += this.s[i] * this.K.eval(this.I[i], x);
        }
        return r;
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public Forgetron clone() {
        return new Forgetron(this);
    }

    /**
     * Allocates an empty buffer of {@code budget} slots and resets the mistake
     * statistics.
     *
     * @throws FailedToFitException if the target is not binary or there are no numeric attributes
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("Forgetron only supports binary classification");
        }
        if (numericAttributes == 0) {
            throw new FailedToFitException("Forgetron requires numeric attributes");
        }
        this.I = new Vec[this.budget];
        this.s = new double[this.budget];
        this.M = 0.0;
        this.Q = 0.0;
        this.size = 0;
        this.curPos = 0;
    }

    /**
     * Computes {@code lambda^2 + 2*lambda - 2*lambda*mu}, the quantity the
     * self-tuned rule accumulates into {@link #Q}.
     */
    private double psi(double lambda, double mu) {
        return lambda * lambda + 2.0 * lambda - 2.0 * lambda * mu;
    }

    /**
     * Performs one online step. Nothing changes unless the example is
     * misclassified, which includes a margin of exactly zero.
     *
     * @param dataPoint the example; only its numeric values are used
     * @param targetClass the true class, 0 or 1
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        final double y_t = targetClass * 2 - 1;
        final Vec x = dataPoint.getNumericalValues();
        final double f_t = this.decisionValue(x);
        if (y_t * f_t > 0.0) {
            return;
        }
        this.M += 1.0;
        if (this.selfTuned) {
            this.selfTunedUpdate(x, y_t);
        } else {
            this.fixedShrinkUpdate(x, y_t);
        }
        this.curPos = (this.curPos + 1) % this.I.length;
    }

    /**
     * Self-tuned rule: stores {@code x}. Once the buffer is full, it replaces
     * the oldest example and shrinks all weights by the factor from
     * {@link #selfTunedShrink}.
     */
    private void selfTunedUpdate(Vec x, double y_t) {
        // Capacity is taken from the allocated buffer, so a later setBudget cannot overrun it.
        if (this.size < this.I.length) {
            ++this.size;
            this.I[this.curPos] = x;
            this.s[this.curPos] = y_t;
            return;
        }
        final int r = this.curPos;
        // f'_t(x_r): the hypothesis with the new example added, evaluated at the oldest example.
        final double fp_t = this.decisionValue(this.I[r]) + y_t * this.K.eval(x, this.I[r]);
        final double s_r = Math.abs(this.s[r]);
        final double y_r = Math.signum(this.s[r]);
        final double phi_t = this.selfTunedShrink(s_r, y_r, fp_t);
        this.Q += this.psi(phi_t * s_r, y_r * (phi_t * fp_t));
        this.I[r] = x;
        this.s[r] = y_t;
        if (phi_t != 1.0) {
            for (int i = 0; i < this.s.length; ++i) {
                this.s[i] *= phi_t;
            }
        }
    }

    /**
     * Chooses the self-tuned shrink factor (at most 1). It solves
     * {@code a*phi^2 + b*phi + c = 0}, where {@code a*phi^2 + b*phi} equals
     * {@code psi(phi*s_r, y_r*phi*fp_t)} and {@code c = Q - (15/32) M}.
     */
    private double selfTunedShrink(double s_r, double y_r, double fp_t) {
        final double a = s_r * s_r - 2.0 * y_r * s_r * fp_t;
        final double b = 2.0 * s_r;
        final double c = this.Q - DAMAGE_RATIO * this.M;
        final double d = b * b - 4.0 * a * c;
        if (a > 0.0 || (a < 0.0 && d > 0.0 && (-b - Math.sqrt(d)) / (2.0 * a) > 1.0)) {
            return Math.min(1.0, (-b + Math.sqrt(d)) / (2.0 * a));
        } else if (Math.abs(a) <= QUADRATIC_EPS) {
            // Degenerate (linear) case: b*phi + c = 0.
            return Math.min(1.0, -c / b);
        } else {
            return 1.0;
        }
    }

    /**
     * Fixed rule: writes {@code x} into the current slot (overwriting the
     * oldest example once the buffer is full) and shrinks all stored weights
     * by {@code min(Bconst, U / ||f'||)}.
     */
    private void fixedShrinkUpdate(Vec x, double y_t) {
        // NOTE(review): ||f'_t||^2 is approximated by its diagonal terms only.
        // The 1.0 stands in for K(x_t, x_t), i.e. it assumes K(x, x) = 1.
        // Cross terms s_i s_j K(x_i, x_j) and 2 y_t f_t(x_t) are ignored. Confirm this is intended.
        double ff = 1.0;
        for (int i = 0; i < this.size; ++i) {
            ff += this.s[i] * this.s[i] * this.K.eval(this.I[i], this.I[i]);
        }
        final double phi = Math.min(this.Bconst, this.U / Math.sqrt(ff));
        this.I[this.curPos] = x;
        this.s[this.curPos] = y_t;
        if (this.size < this.I.length) {
            ++this.size;
        }
        for (int i = 0; i < this.size; ++i) {
            this.s[i] *= phi;
        }
    }
}

