/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.KernelTrick;
import jsat.linear.Vec;
import jsat.lossfunctions.LogisticLoss;
import jsat.lossfunctions.LossC;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.random.RandomUtil;

/**
 * Online kernelized binary classifier that keeps a sparse set of support vectors.
 * <p>
 * For each training point {@code x} with label {@code y} (class 1 encoded as +1,
 * class 0 as -1), the derivative {@code loss'} of the configured loss at the current
 * score is computed. The point is stored as a new support vector only with
 * probability {@code min(1, |loss'| / G)}. When stored, its coefficient is
 * {@code -eta * sign(loss') * G}. Whenever the squared norm of the learned function
 * exceeds {@code R^2}, all coefficients are rescaled so that the norm equals
 * {@code R}.
 * <p>
 * If {@link #setUseAverageModel(boolean) averaging} is enabled (the default),
 * predictions made after the burn-in period use a running average of the
 * coefficients. Each coefficient snapshot is weighted by the number of calls to
 * {@code update} that elapsed while it was current.
 * <p>
 * NOTE(review): this presumably implements the algorithm of Zhang, Yi, Jin, Lin
 * &amp; He, "Online Kernel Learning with a Near Optimal Sparsity Bound" (ICML 2013).
 * Confirm against the original project documentation.
 */
public class OSKL
extends BaseUpdateableClassifier
implements BinaryScoreClassifier,
Parameterized {
    private static final long serialVersionUID = 4207594016856230134L;
    /** Kernel used to compare inputs against the stored support vectors. */
    @Parameter.ParameterHolder
    private KernelTrick k;
    /** Learning rate; each new coefficient has magnitude {@code eta * G}. */
    private double eta;
    /** Maximum allowed norm of the (non-averaged) learned function. */
    private double R;
    /** Sampling scale: a point is kept with probability {@code min(1, |loss'|/G)}. */
    private double G;
    /** Running squared norm of the non-averaged learned function. */
    private double curSqrdNorm;
    /** Loss whose derivative drives updates and which maps scores to class results. */
    private LossC lossC;
    /** Whether predictions use the averaged coefficients once burn-in has passed. */
    private boolean useAverageModel = true;
    /** Number of calls to {@link #update} since {@link #setUp}. */
    private int t;
    /** Value of {@link #t} at the most recent averaging step. */
    private int last_t;
    /** Number of updates that must occur before coefficient averaging starts. */
    private int burnIn;
    /** Averaged coefficients, parallel to {@link #alphas}. */
    private DoubleList alphaAveraged;
    /** Stored support vectors. */
    private List<Vec> vecs;
    /** Current (non-averaged) coefficient of each support vector. */
    private DoubleList alphas;
    /** Kernel values for the latest input x: index 0 is k(x,x), index i+1 is k(vecs[i], x). */
    private DoubleList inputKEvals;
    /** Kernel acceleration cache for {@link #vecs}, or null if the kernel has no acceleration. */
    private List<Double> accelCache;
    /** Randomness used to decide whether a point becomes a support vector. */
    private Random rand;

    /**
     * Creates a new classifier with {@code eta = 0.9}, {@code G = 1} and the logistic loss.
     *
     * @param k the kernel to use
     * @param R the maximum allowed norm of the learned function
     */
    public OSKL(KernelTrick k, double R) {
        this(k, 0.9, 1.0, R);
    }

    /**
     * Creates a new classifier that uses the logistic loss.
     *
     * @param k   the kernel to use
     * @param eta the learning rate
     * @param G   the sampling scale, in [1, Infinity)
     * @param R   the maximum allowed norm of the learned function
     */
    public OSKL(KernelTrick k, double eta, double G, double R) {
        this(k, eta, G, R, new LogisticLoss());
    }

    /**
     * Creates a new classifier.
     *
     * @param k     the kernel to use
     * @param eta   the learning rate
     * @param G     the sampling scale, in [1, Infinity)
     * @param R     the maximum allowed norm of the learned function
     * @param lossC the loss used for updates and for converting scores to results
     * @throws NullPointerException     if {@code k} or {@code lossC} is null
     * @throws IllegalArgumentException if {@code eta}, {@code G} or {@code R} is out of range
     */
    public OSKL(KernelTrick k, double eta, double G, double R, LossC lossC) {
        this.setKernel(k);
        this.setEta(eta);
        this.setR(R);
        this.setG(G);
        // Fail fast: a null loss would otherwise surface much later as an NPE in
        // update(), classify() or the copy constructor.
        if (lossC == null) {
            throw new NullPointerException("lossC must not be null");
        }
        this.lossC = lossC;
    }

    /**
     * Copy constructor. Creates a deep copy of the given classifier.
     * <p>
     * The kernel, loss, support vectors and coefficient lists are cloned. A fresh
     * random source is created rather than copied.
     *
     * @param toCopy the classifier to copy
     */
    public OSKL(OSKL toCopy) {
        this.k = toCopy.k.clone();
        this.eta = toCopy.eta;
        this.R = toCopy.R;
        this.G = toCopy.G;
        this.curSqrdNorm = toCopy.curSqrdNorm;
        this.lossC = toCopy.lossC.clone();
        this.t = toCopy.t;
        this.last_t = toCopy.last_t;
        this.useAverageModel = toCopy.useAverageModel;
        this.burnIn = toCopy.burnIn;
        // Model state exists only once setUp has been called on the source.
        if (toCopy.vecs != null) {
            this.vecs = new ArrayList<Vec>();
            for (Vec v : toCopy.vecs) {
                this.vecs.add(v.clone());
            }
            this.alphas = new DoubleList(toCopy.alphas);
            this.alphaAveraged = new DoubleList(toCopy.alphaAveraged);
            this.inputKEvals = new DoubleList(toCopy.inputKEvals);
        }
        if (toCopy.accelCache != null) {
            this.accelCache = new DoubleList(toCopy.accelCache);
        }
        this.rand = RandomUtil.getRandom();
    }

    /**
     * Sets the kernel to use.
     *
     * @param k the kernel to use
     * @throws NullPointerException if {@code k} is null
     */
    public void setKernel(KernelTrick k) {
        if (k == null) {
            throw new NullPointerException("Kernel must not be null");
        }
        this.k = k;
    }

    /**
     * Returns the kernel in use.
     *
     * @return the kernel in use
     */
    public KernelTrick getKernel() {
        return this.k;
    }

    /**
     * Sets the learning rate.
     *
     * @param eta the learning rate; must be positive and finite
     * @throws IllegalArgumentException if {@code eta} is not positive and finite
     */
    public void setEta(double eta) {
        if (eta <= 0.0 || Double.isNaN(eta) || Double.isInfinite(eta)) {
            throw new IllegalArgumentException("Eta must be positive, not " + eta);
        }
        this.eta = eta;
    }

    /**
     * Returns the learning rate.
     *
     * @return the learning rate
     */
    public double getEta() {
        return this.eta;
    }

    /**
     * Sets the sampling scale {@code G}.
     * <p>
     * A point is added with probability {@code min(1, |loss'|/G)}. Its coefficient
     * has magnitude {@code eta * G}.
     *
     * @param G the sampling scale, in [1, Infinity)
     * @throws IllegalArgumentException if {@code G} is below 1, infinite or NaN
     */
    public void setG(double G) {
        if (G < 1.0 || Double.isInfinite(G) || Double.isNaN(G)) {
            throw new IllegalArgumentException("G must be in [1, Infinity), not " + G);
        }
        this.G = G;
    }

    /**
     * Returns the sampling scale {@code G}.
     *
     * @return the sampling scale {@code G}
     */
    public double getG() {
        return this.G;
    }

    /**
     * Returns a distribution over candidate values of {@code R} for parameter search.
     * The data set is currently ignored.
     *
     * @param d the data set (unused)
     * @return a log-uniform distribution over [1, 100000]
     */
    public static Distribution guessR(DataSet d) {
        return new LogUniform(1.0, 100000.0);
    }

    /**
     * Sets the maximum allowed norm of the learned function.
     *
     * @param R the maximum norm; must be positive and finite
     * @throws IllegalArgumentException if {@code R} is not positive and finite
     */
    public void setR(double R) {
        if (R <= 0.0 || Double.isNaN(R) || Double.isInfinite(R)) {
            throw new IllegalArgumentException("R must be positive, not " + R);
        }
        this.R = R;
    }

    /**
     * Returns the maximum allowed norm of the learned function.
     *
     * @return the maximum allowed norm of the learned function
     */
    public double getR() {
        return this.R;
    }

    /**
     * Sets whether predictions use the averaged coefficients after burn-in.
     *
     * @param useAverageModel {@code true} to predict with averaged coefficients
     */
    public void setUseAverageModel(boolean useAverageModel) {
        this.useAverageModel = useAverageModel;
    }

    /**
     * Returns whether predictions use the averaged coefficients after burn-in.
     *
     * @return {@code true} if predictions use the averaged coefficients
     */
    public boolean isUseAverageModel() {
        return this.useAverageModel;
    }

    /**
     * Sets the number of updates to perform before coefficient averaging starts.
     *
     * @param burnIn the burn-in length; must be non-negative
     * @throws IllegalArgumentException if {@code burnIn} is negative
     */
    public void setBurnIn(int burnIn) {
        if (burnIn < 0) {
            throw new IllegalArgumentException("Burn in must be non negative, not " + burnIn);
        }
        this.burnIn = burnIn;
    }

    /**
     * Returns the number of updates performed before coefficient averaging starts.
     *
     * @return the burn-in length
     */
    public int getBurnIn() {
        return this.burnIn;
    }

    /**
     * Resets the model to an empty state, ready for a new round of updates.
     *
     * @throws IllegalArgumentException if {@code predicting} does not have exactly two
     *                                  categories, since this classifier is binary only
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        // update() encodes labels as targetClass * 2 - 1, which only yields {-1, +1}
        // for two classes. Reject multi-class targets instead of silently mis-training.
        if (predicting != null && predicting.getNumOfCategories() != 2) {
            throw new IllegalArgumentException("OSKL supports only binary classification, not "
                    + predicting.getNumOfCategories() + " classes");
        }
        this.rand = RandomUtil.getRandom();
        this.vecs = new ArrayList<Vec>();
        this.alphas = new DoubleList();
        this.alphaAveraged = new DoubleList();
        this.t = 0;
        this.last_t = 0;
        this.inputKEvals = new DoubleList();
        this.accelCache = this.k.supportsAcceleration() ? new DoubleList() : null;
        this.curSqrdNorm = 0.0;
    }

    /**
     * Returns the number of stored support vectors.
     *
     * @return the number of stored support vectors, or 0 if the model has not been set up
     */
    public int getSupportVectorCount() {
        if (this.vecs == null) {
            return 0;
        }
        return this.vecs.size();
    }

    /**
     * Performs one online update with the given labeled point.
     * <p>
     * The point is stored as a support vector with probability
     * {@code min(1, |loss'|/G)}. If the resulting function norm exceeds {@code R},
     * all current coefficients are rescaled. The data point's weight is ignored.
     *
     * @param dataPoint   the training point
     * @param targetClass its class, 0 or 1
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        Vec x_t = dataPoint.getNumericalValues();
        List<Double> qi = this.k.getQueryInfo(x_t);
        // Also fills inputKEvals with k(x,x) and k(v_i, x) for reuse below.
        double score = this.scoreSaveEval(x_t, qi);
        double y_t = targetClass * 2 - 1;
        double lossD = this.lossC.getDeriv(score, y_t);
        ++this.t;
        // Sparse sampling: keep the point only with probability |loss'|/G.
        if (this.rand.nextDouble() > Math.abs(lossD) / this.G) {
            return;
        }
        double alpha_t = -this.eta * Math.signum(lossD) * this.G;
        // ||f + a*k(x,.)||^2 = ||f||^2 + a^2 k(x,x) + 2a * sum_i alpha_i k(v_i, x)
        this.curSqrdNorm += alpha_t * alpha_t * this.inputKEvals.getD(0);
        for (int i = 0; i < this.alphas.size(); ++i) {
            this.curSqrdNorm += 2.0 * alpha_t * this.alphas.getD(i) * this.inputKEvals.getD(i + 1);
        }
        this.alphas.add(alpha_t);
        this.vecs.add(x_t);
        if (this.accelCache != null) {
            this.accelCache.addAll(qi);
        }
        this.alphaAveraged.add(0.0);
        this.updateAverage();
        // Project back onto the ball of radius R by uniform rescaling.
        if (this.curSqrdNorm > this.R * this.R) {
            double coeff = this.R / Math.sqrt(this.curSqrdNorm);
            this.alphas.getVecView().mutableMultiply(coeff);
            this.curSqrdNorm *= coeff * coeff;
        }
    }

    /**
     * Computes the decision value for {@code x}.
     * <p>
     * When averaging is enabled and burn-in has passed, this brings the average up to
     * date and uses the averaged coefficients. Otherwise it uses the current ones.
     * This method mutates averaging state.
     */
    private double score(Vec x, List<Double> qi) {
        DoubleList alphToUse;
        if (this.useAverageModel && this.t > this.burnIn) {
            this.updateAverage();
            alphToUse = this.alphaAveraged;
        } else {
            alphToUse = this.alphas;
        }
        return this.k.evalSum(this.vecs, this.accelCache, alphToUse.getBackingArray(), x, qi, 0, alphToUse.size());
    }

    /**
     * Computes the non-averaged decision value for {@code x}.
     * <p>
     * Records k(x,x) at index 0 of {@link #inputKEvals} and each k(v_i, x) at index
     * {@code i + 1}, so {@link #update} can reuse them for the norm computation.
     */
    private double scoreSaveEval(Vec x, List<Double> qi) {
        this.inputKEvals.clear();
        this.inputKEvals.add(this.k.eval(0, 0, Arrays.asList(x), qi));
        double sum = 0.0;
        for (int i = 0; i < this.alphas.size(); ++i) {
            double k_ix = this.k.eval(i, x, qi, this.vecs, this.accelCache);
            this.inputKEvals.add(k_ix);
            sum += this.alphas.getD(i) * k_ix;
        }
        return sum;
    }

    /**
     * Classifies a point by converting its score with the configured loss.
     *
     * @param data the point to classify
     * @return the loss's classification of the point's score
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        Vec x = data.getNumericalValues();
        return this.lossC.getClassification(this.score(x, this.k.getQueryInfo(x)));
    }

    /**
     * Returns {@code false}: data point weights are ignored during updates.
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns the raw decision value for a point.
     * Training encodes class 1 as +1 and class 0 as -1.
     *
     * @param dp the point to score
     * @return the decision value
     */
    @Override
    public double getScore(DataPoint dp) {
        Vec x = dp.getNumericalValues();
        return this.score(x, this.k.getQueryInfo(x));
    }

    /**
     * Returns a deep copy of this classifier (see the copy constructor).
     */
    @Override
    public OSKL clone() {
        return new OSKL(this);
    }

    /**
     * Lazily brings {@link #alphaAveraged} up to date with the current coefficients.
     * <p>
     * Does nothing before burn-in completes or if no update happened since the last
     * call. On the first call after burn-in, the average is seeded with the current
     * coefficients. Afterwards the current coefficients receive weight
     * {@code (t - last_t) / t} in the running average.
     */
    private void updateAverage() {
        if (this.t == this.last_t || this.t < this.burnIn) {
            return;
        }
        if (this.last_t < this.burnIn) {
            // First averaging step after burn-in: start from the current model.
            for (int i = 0; i < this.alphaAveraged.size(); ++i) {
                this.alphaAveraged.set(i, this.alphas.getD(i));
            }
        }
        // Number of updates for which the current coefficients have been in effect.
        double w = this.t - this.last_t;
        for (int i = 0; i < this.alphaAveraged.size(); ++i) {
            double delta = this.alphas.getD(i) - this.alphaAveraged.getD(i);
            this.alphaAveraged.set(i, this.alphaAveraged.getD(i) + delta * w / (double)this.t);
        }
        this.last_t = this.t;
    }
}

