/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers;

import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.regression.LogisticRegression;
import jsat.regression.RegressionDataSet;

/**
 * Multi-class classifier assembled from binary {@link LogisticRegression} models.
 * <p>
 * Class {@code 0} acts as the reference class. For every other class {@code k}, a
 * logistic regression model is fit to predict {@code 1.0} for points of class
 * {@code k} and {@code 0.0} for all other points. At classification time each
 * model's linear score is exponentiated and normalized together with an implicit
 * score of {@code 0} for the reference class (a softmax).
 * <p>
 * NOTE(review): because each binary model is fit against <i>all</i> other classes
 * rather than only against the reference class, the combined probabilities
 * approximate, rather than exactly reproduce, a reference-class multinomial fit.
 */
public class MultinomialLogisticRegression
implements Classifier {
    private static final long serialVersionUID = -9168502043850569017L;
    /**
     * Coefficients of the binary model for class {@code i + 1}. Element {@code 0} of
     * each vector is the intercept; element {@code j >= 1} multiplies numerical
     * feature {@code j - 1}. {@code null} until {@code train} has completed.
     */
    private Vec[] classCoefficents;

    /**
     * Computes the probability of every class for a data point.
     *
     * @param data the point to classify; only its numerical values are read
     * @return probabilities for all {@code classCoefficents.length + 1} classes,
     *         with index {@code 0} being the reference class
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.classCoefficents == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        Vec b = data.getNumericalValues();
        int others = this.classCoefficents.length;
        double[] scores = new double[others];
        // The reference class has an implicit score of 0, so the running max starts there.
        double max = 0.0;
        for (int i = 0; i < others; ++i) {
            scores[i] = linearScore(this.classCoefficents[i], b);
            if (scores[i] > max) {
                max = scores[i];
            }
        }
        // Shift every score by the maximum before exponentiating (log-sum-exp trick).
        // Without the shift, Math.exp overflows to Infinity for large scores and the
        // normalization below produces NaN. The normalized result is mathematically
        // unchanged, and the largest shifted term is exp(0) = 1, so sum >= 1.
        CategoricalResults results = new CategoricalResults(others + 1);
        double sum = Math.exp(-max);
        results.setProb(0, sum);
        for (int i = 0; i < others; ++i) {
            double e = Math.exp(scores[i] - max);
            sum += e;
            results.setProb(i + 1, e);
        }
        results.divideConst(sum);
        return results;
    }

    /**
     * Returns the intercept {@code coefs[0]} plus the dot product of
     * {@code coefs[1..]} with {@code x[0..]}.
     *
     * @param coefs intercept followed by one weight per feature
     * @param x the numerical feature values
     * @return the linear score
     */
    private static double linearScore(Vec coefs, Vec x) {
        double score = coefs.get(0);
        for (int j = 1; j < coefs.length(); ++j) {
            score += x.get(j - 1) * coefs.get(j);
        }
        return score;
    }

    /**
     * Fits one binary logistic regression per non-reference class. Each model
     * sees target {@code 1.0} for its own class and {@code 0.0} for every other point.
     *
     * @param dataSet the labelled training data
     * @param parallel forwarded unchanged to each {@code LogisticRegression.train} call
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        int classes = dataSet.getClassSize();
        Vec[] coefficients = new Vec[classes - 1];
        LogisticRegression logit = new LogisticRegression();
        for (int k = 1; k < classes; ++k) {
            RegressionDataSet rds = new RegressionDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories());
            for (int i = 0; i < dataSet.getSampleSize(); ++i) {
                rds.addDataPoint(dataSet.getDataPoint(i), dataSet.getDataPointCategory(i) == k ? 1.0 : 0.0);
            }
            logit.train(rds, parallel);
            // The same LogisticRegression instance is retrained for each class, so keep a
            // private copy of its coefficients. That way a later training pass cannot alias
            // or overwrite this class's vector.
            // NOTE(review): whether getCoefficents() already returns a copy is not visible here.
            coefficients[k - 1] = logit.getCoefficents().clone();
        }
        // Publish the model only once every class has been fit. A failure part-way
        // through then leaves any previously trained model intact, not half-populated.
        this.classCoefficents = coefficients;
    }

    /**
     * Reports whether this classifier makes use of data point weights.
     *
     * @return always {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Creates an independent copy of this classifier. Each trained coefficient
     * vector is deep-copied.
     *
     * @return a new instance with the same (possibly absent) trained state
     */
    @Override
    public MultinomialLogisticRegression clone() {
        MultinomialLogisticRegression clone = new MultinomialLogisticRegression();
        if (this.classCoefficents != null) {
            clone.classCoefficents = new Vec[this.classCoefficents.length];
            for (int i = 0; i < this.classCoefficents.length; ++i) {
                clone.classCoefficents[i] = this.classCoefficents[i].clone();
            }
        }
        return clone;
    }
}

