/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.svm.extended;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.Uniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.MatrixOfVecs;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/**
 * Binary classifier made of two banks of {@code K} linear scoring functions
 * (one bank per class), each trained with stochastic gradient descent.
 *
 * <p>The score of a bank for an input {@code x} is {@code max_k (W_k . x + b_k)}.
 * The "positive" bank ({@link #Wp}, {@link #bp}) treats category {@code 1} as
 * the target; the "negative" bank ({@link #Wn}, {@link #bn}) treats category
 * {@code 0} as the target.
 *
 * <p>NOTE(review): the class name and the entropy-driven assignment heuristic
 * suggest this is the Convex Polytope Machine of Kantchelian et al. — confirm
 * against the upstream project documentation.
 *
 * <p>Instances are not safe for concurrent training. Prediction methods require
 * a prior successful call to {@link #train(ClassificationDataSet, boolean)}.
 */
public class CPM implements BinaryScoreClassifier, Classifier, Parameterized {
    private static final long serialVersionUID = 3171068484917637037L;

    /** Number of full passes over the training data. Always {@code >= 1}. */
    private int epochs;
    /** Regularization strength used to derive the SGD step size. Always positive and finite. */
    private double lambda;
    /** Number of linear scoring functions per bank. Always {@code >= 1}. */
    private int K;
    /** User-facing entropy parameter from which {@link #h} is derived. */
    private double entropyThreshold;
    /** Entropy cut-off used by {@link #ASSIGN}; {@code max(0, log2(entropyThreshold * K / 10))}. */
    private double h;
    /** Weights of the bank that targets category 1; {@code null} until trained. */
    private Matrix Wp;
    /** Weights of the bank that targets category 0; {@code null} until trained. */
    private Matrix Wn;
    /** Biases paired with {@link #Wp}; {@code null} until trained. */
    private Vec bp;
    /** Biases paired with {@link #Wn}; {@code null} until trained. */
    private Vec bn;

    /** Creates a model with {@code lambda = 1}, {@code K = 16}, entropy threshold 3 and 50 epochs. */
    public CPM() {
        this(1.0);
    }

    /**
     * Creates a model with {@code lambda = 1} and the given number of scoring functions.
     *
     * @param K number of scoring functions per bank, must be positive
     */
    public CPM(int K) {
        this(1.0, K);
    }

    /**
     * Creates a model with {@code K = 16} and the given regularization.
     *
     * @param lambda regularization strength, must be positive and finite
     */
    public CPM(double lambda) {
        this(lambda, 16);
    }

    /**
     * Creates a model with entropy threshold 3 and 50 epochs.
     *
     * @param lambda regularization strength, must be positive and finite
     * @param K number of scoring functions per bank, must be positive
     */
    public CPM(double lambda, int K) {
        this(lambda, K, 3.0);
    }

    /**
     * Creates a model that trains for 50 epochs.
     *
     * @param lambda regularization strength, must be positive and finite
     * @param K number of scoring functions per bank, must be positive
     * @param entropyThreshold non-negative, finite entropy parameter
     */
    public CPM(double lambda, int K, double entropyThreshold) {
        this(lambda, K, entropyThreshold, 50);
    }

    /**
     * Creates a fully specified, untrained model.
     *
     * @param lambda regularization strength, must be positive and finite
     * @param K number of scoring functions per bank, must be positive
     * @param entropyThreshold non-negative, finite entropy parameter
     * @param epochs number of passes over the training data, must be positive
     * @throws IllegalArgumentException if any argument is out of range
     */
    public CPM(double lambda, int K, double entropyThreshold, int epochs) {
        this.setEpochs(epochs);
        this.setLambda(lambda);
        this.setK(K);
        this.setEntropyThreshold(entropyThreshold);
    }

    /**
     * Copy constructor. Learned weights and biases, when present, are cloned so
     * the copy does not share them with {@code toCopy}.
     *
     * @param toCopy the model to duplicate
     */
    public CPM(CPM toCopy) {
        this.epochs = toCopy.epochs;
        this.lambda = toCopy.lambda;
        this.K = toCopy.K;
        this.entropyThreshold = toCopy.entropyThreshold;
        this.h = toCopy.h;
        if (toCopy.Wp != null) {
            this.Wp = toCopy.Wp.clone();
        }
        if (toCopy.Wn != null) {
            this.Wn = toCopy.Wn.clone();
        }
        if (toCopy.bp != null) {
            this.bp = toCopy.bp.clone();
        }
        if (toCopy.bn != null) {
            this.bn = toCopy.bn.clone();
        }
    }

    /**
     * Sets the entropy parameter and recomputes the derived cut-off {@code h}.
     *
     * @param entropyThreshold non-negative, finite value
     * @throws IllegalArgumentException if the value is negative, NaN or infinite
     */
    public void setEntropyThreshold(double entropyThreshold) {
        if (entropyThreshold < 0.0 || Double.isNaN(entropyThreshold) || Double.isInfinite(entropyThreshold)) {
            throw new IllegalArgumentException("Entropy threshold must be non-negative, not " + entropyThreshold);
        }
        this.entropyThreshold = entropyThreshold;
        this.set_h_properly();
    }

    /** Recomputes {@code h = log2(entropyThreshold * K / 10)}, clamped so it is never negative. */
    private void set_h_properly() {
        this.h = Math.log(this.entropyThreshold * (double) this.K / 10.0) / Math.log(2.0);
        if (this.h <= 0.0) {
            this.h = 0.0;
        }
    }

    /** @return the entropy parameter last passed to {@link #setEntropyThreshold(double)} */
    public double getEntropyThreshold() {
        return this.entropyThreshold;
    }

    /**
     * Sets the regularization strength.
     *
     * @param lambda positive, finite value
     * @throws IllegalArgumentException if the value is non-positive, NaN or infinite;
     *         such values would make the SGD step size infinite, negative or NaN
     */
    public void setLambda(double lambda) {
        if (lambda <= 0.0 || Double.isNaN(lambda) || Double.isInfinite(lambda)) {
            throw new IllegalArgumentException("lambda must be a positive finite value, not " + lambda);
        }
        this.lambda = lambda;
    }

    /** @return the regularization strength */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets the number of scoring functions per bank and recomputes the derived cut-off {@code h}.
     *
     * @param K positive number of scoring functions
     * @throws IllegalArgumentException if {@code K < 1}
     */
    public void setK(int K) {
        if (K < 1) {
            throw new IllegalArgumentException("K must be a positive value, not " + K);
        }
        this.K = K;
        this.set_h_properly();
    }

    /** @return the number of scoring functions per bank */
    public int getK() {
        return this.K;
    }

    /**
     * Sets the number of passes over the training data.
     *
     * @param epochs positive number of passes
     * @throws IllegalArgumentException if {@code epochs < 1}
     */
    public void setEpochs(int epochs) {
        if (epochs < 1) {
            throw new IllegalArgumentException("epochs must be a positive value");
        }
        this.epochs = epochs;
    }

    /** @return the number of passes over the training data */
    public int getEpochs() {
        return this.epochs;
    }

    /** Returns {@code max_k (W_k . x + b_k)} for one bank of scoring functions. */
    private static double bankScore(Matrix W, Vec b, Vec x) {
        return W.multiply(x).add(b).max();
    }

    /**
     * Produces a hard (0/1 probability) prediction.
     *
     * <p>When exactly one bank has a strictly positive score that bank's class
     * wins; otherwise the larger score wins, with ties going to category 1.
     *
     * @param data the point to classify
     * @return a two-category result with probability 1 on the predicted category
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        Vec x = data.getNumericalValues();
        double pos_score = bankScore(this.Wp, this.bp, x);
        double neg_score = bankScore(this.Wn, this.bn, x);

        boolean negFires = neg_score > 0.0;
        boolean posFires = pos_score > 0.0;
        int predicted;
        if (negFires != posFires) {
            // Exactly one bank claims the point.
            predicted = negFires ? 0 : 1;
        } else {
            // Both or neither claim it: fall back to comparing raw scores.
            predicted = neg_score > pos_score ? 0 : 1;
        }

        CategoricalResults cr = new CategoricalResults(2);
        cr.setProb(predicted, 1.0);
        return cr;
    }

    /**
     * Returns the signed score of a point.
     *
     * @param dp the point to score
     * @return positive-bank score minus negative-bank score
     */
    @Override
    public double getScore(DataPoint dp) {
        Vec x = dp.getNumericalValues();
        return bankScore(this.Wp, this.bp, x) - bankScore(this.Wn, this.bn, x);
    }

    /** @return {@code false}; instance weights are not used during training */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /** One term {@code -(count/total) * log2(count/total)} of a Shannon entropy sum. */
    private static double entropyTerm(double count, double total) {
        return -count * (Math.log(count) - Math.log(total)) / (Math.log(2.0) * total);
    }

    /**
     * Chooses which scoring function should absorb the update for a positive
     * training point whose best score is still below the margin.
     *
     * @param dots current scores of every scoring function for the point
     * @param indx index of the training point in the data set
     * @param k_true_max index of the highest-scoring function for the point
     * @param owned per-function count of points currently assigned to it
     * @param assignments per-point current owner, or {@code -1} if unassigned
     * @param assigned_positive_instances number of points that have an owner
     * @return index of the scoring function to update
     */
    private int ASSIGN(Vec dots, int indx, int k_true_max, int[] owned, int[] assignments, int assigned_positive_instances) {
        final int old_owner = assignments[indx];
        double cur_entropy = 0.0;
        // Stays infinite (so k_true_max is returned) until enough points are
        // assigned to estimate an entropy at all.
        double new_entropy = Double.POSITIVE_INFINITY;
        int max_owned = 0;

        if (assigned_positive_instances > this.K * 10) {
            new_entropy = 0.0;
            for (int i = 0; i < this.K; ++i) {
                max_owned = Math.max(max_owned, owned[i]);
                double numer = owned[i];
                double denom = assigned_positive_instances;
                if (numer > 0.0) {
                    cur_entropy += entropyTerm(numer, denom);
                }

                if (old_owner < 0) {
                    // Point is new: total grows by one and k_true_max gains it.
                    denom += 1.0;
                    if (i == k_true_max) {
                        numer += 1.0;
                    }
                    if (numer > 0.0) {
                        new_entropy += entropyTerm(numer, denom);
                    }
                } else if (old_owner == k_true_max) {
                    // Ownership would not change.
                    new_entropy = cur_entropy;
                } else {
                    // Point moves from old_owner to k_true_max.
                    if (i == k_true_max) {
                        numer += 1.0;
                    } else if (i == old_owner) {
                        numer -= 1.0;
                    }
                    if (numer > 0.0) {
                        new_entropy += entropyTerm(numer, denom);
                    }
                }
            }
            // NOTE(review): the current entropy is added on top of the
            // hypothetical one before comparing with h — confirm this is intended.
            new_entropy += cur_entropy;
        }

        if (new_entropy >= this.h) {
            return k_true_max;
        }

        // NOTE(review): both scans below start at index 1, so function 0 is
        // never tested against the ownership constraint — confirm intended.
        int k_inc_max = 0;
        if (old_owner >= 0) {
            for (int i = 1; i < dots.length(); ++i) {
                if (owned[old_owner] > owned[i] && dots.get(i) > dots.get(k_inc_max)) {
                    k_inc_max = i;
                }
            }
        } else {
            double best_score = Double.NEGATIVE_INFINITY;
            for (int i = 1; i < dots.length(); ++i) {
                if (max_owned > owned[i] && dots.get(i) > best_score) {
                    k_inc_max = i;
                    best_score = dots.get(i);
                }
            }
            if (Double.isInfinite(best_score)) {
                // No candidate satisfied the constraint.
                return k_true_max;
            }
        }
        return k_inc_max;
    }

    /**
     * Trains one bank of scoring functions in place with SGD.
     *
     * @param D training data
     * @param W weight rows to update, one per scoring function
     * @param b biases to update, one per scoring function
     * @param sign_mul {@code 1} to treat category 1 as the target, {@code -1} to treat category 0 as the target
     * @param parallel currently unused by this method
     */
    private void sgdTrain(ClassificationDataSet D, MatrixOfVecs W, Vec b, int sign_mul, boolean parallel) {
        final int sampleCount = D.getSampleSize();
        IntList order = new IntList(sampleCount);
        ListUtils.addRange(order, 0, sampleCount, 1);

        // Multiply in double: sampleCount * epochs can exceed Integer.MAX_VALUE.
        final double lambda_adj = this.lambda / ((double) sampleCount * (double) this.epochs);

        int[] owned = new int[this.K];
        int assigned_positive_instances = 0;
        int[] assignments = new int[sampleCount];
        Arrays.fill(assignments, -1);

        Vec dots = new DenseVector(W.rows());
        long t = 0L;
        for (int epoch = 0; epoch < this.epochs; ++epoch) {
            Collections.shuffle(order);
            for (int i : order) {
                ++t;
                final double eta = 1.0 / (lambda_adj * (double) t);
                Vec x_i = D.getDataPoint(i).getNumericalValues();
                // Map category {0, 1} to {-1, +1}, flipped when training the other bank.
                int y_i = (D.getDataPointCategory(i) * 2 - 1) * sign_mul;

                // dots = b + W * x_i
                b.copyTo(dots);
                W.multiply(x_i, 1.0, dots);

                if (y_i == -1) {
                    // Push down every function that scores the point above the -1 margin.
                    for (int k = 0; k < this.K; ++k) {
                        if (dots.get(k) > -1.0) {
                            W.getRowView(k).mutableSubtract(eta, x_i);
                            b.increment(k, -eta);
                        }
                    }
                } else {
                    int k_true_max = 0;
                    for (int k = 1; k < dots.length(); ++k) {
                        if (dots.get(k) > dots.get(k_true_max)) {
                            k_true_max = k;
                        }
                    }
                    if (dots.get(k_true_max) < 1.0) {
                        int z = this.ASSIGN(dots, i, k_true_max, owned, assignments, assigned_positive_instances);
                        W.getRowView(z).mutableAdd(eta, x_i);
                        b.increment(z, eta);
                        // Move ownership of point i to function z.
                        if (assignments[i] < 0) {
                            ++assigned_positive_instances;
                        } else {
                            --owned[assignments[i]];
                        }
                        ++owned[z];
                        assignments[i] = z;
                    }
                }

                // Shrink step; the factor is 0 on the very first update (t == 1).
                final double shrink = 1.0 - 1.0 / (double) t;
                W.mutableMultiply(shrink);
                b.mutableMultiply(shrink);
            }
        }
    }

    /**
     * Trains both banks on the given data set, replacing any previous model.
     *
     * @param dataSet binary classification data
     * @param parallel forwarded to the internal trainer, which does not use it
     * @throws FailedToFitException if the data set has more than two categories
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        if (dataSet.getPredicting().getNumOfCategories() > 2) {
            throw new FailedToFitException("CPM is a binary classifier, it can not be trained on a dataset with " + dataSet.getPredicting().getNumOfCategories() + " classes");
        }
        int d = dataSet.getNumNumericalVars();
        ArrayList<Vec> Wv_p = new ArrayList<Vec>(this.K);
        ArrayList<Vec> Wv_n = new ArrayList<Vec>(this.K);
        this.bp = new DenseVector(this.K);
        this.bn = new DenseVector(this.K);
        for (int i = 0; i < this.K; ++i) {
            // Rows are wrapped in ScaledVector for training and copied to a DenseMatrix afterwards.
            Wv_p.add(new ScaledVector(new DenseVector(d)));
            Wv_n.add(new ScaledVector(new DenseVector(d)));
        }
        MatrixOfVecs W_p = new MatrixOfVecs(Wv_p);
        MatrixOfVecs W_n = new MatrixOfVecs(Wv_n);
        this.sgdTrain(dataSet, W_p, this.bp, 1, parallel);
        this.sgdTrain(dataSet, W_n, this.bn, -1, parallel);
        this.Wp = new DenseMatrix(W_p);
        this.Wn = new DenseMatrix(W_n);
    }

    /** @return an independent copy created with {@link #CPM(CPM)} */
    @Override
    public CPM clone() {
        return new CPM(this);
    }

    /**
     * Suggests a search distribution for {@code lambda}.
     *
     * @param d the data set; not consulted by the current implementation
     * @return a log-uniform distribution over {@code [0.1, 10000]}
     */
    public static Distribution guessLambda(DataSet d) {
        return new LogUniform(0.1, 10000.0);
    }

    /**
     * Suggests a search distribution for the entropy threshold.
     *
     * @param d the data set; not consulted by the current implementation
     * @return a uniform distribution over {@code [0.1, 10]}
     */
    public static Distribution guessEntropyThreshold(DataSet d) {
        return new Uniform(0.1, 10.0);
    }
}

