/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.util.ArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.SubVector;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.optimization.IterativelyReweightedLeastSquares;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;

public class LogisticRegression
implements Classifier,
Regressor,
SingleWeightVectorModel {
    private static final long serialVersionUID = -5115807516729861730L;
    private Vec coefficents;
    private double shift;
    private double scale;
    private final Function logitFun = new Function(){

        @Override
        public double f(double ... x) {
            return LogisticRegression.this.logitReg(DenseVector.toDenseVec(x));
        }

        @Override
        public double f(Vec x) {
            return LogisticRegression.this.logitReg(x);
        }
    };
    private final Function logitFunD = new Function(){

        @Override
        public double f(double ... x) {
            return LogisticRegression.this.logitReg(DenseVector.toDenseVec(x));
        }

        @Override
        public double f(Vec x) {
            double y = LogisticRegression.this.logitReg(x);
            return y * (1.0 - y);
        }
    };

    private static double logit(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    private double logitReg(Vec input) {
        double z = this.coefficents.get(0);
        for (int i = 1; i < this.coefficents.length(); ++i) {
            z += input.get(i - 1) * this.coefficents.get(i);
        }
        return LogisticRegression.logit(z);
    }

    public Vec getCoefficents() {
        return this.coefficents;
    }

    @Override
    public double regress(DataPoint data) {
        if (this.coefficents == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        return this.logitReg(data.getNumericalValues()) * this.scale + this.shift;
    }

    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        ArrayList<Vec> inputs = new ArrayList<Vec>(dataSet.getSampleSize());
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            inputs.add(dataSet.getDataPoint(i).getNumericalValues());
        }
        this.coefficents = new DenseVector(dataSet.getNumNumericalVars() + 1);
        Vec targetValues = dataSet.getTargetValues();
        double minTarget = targetValues.min();
        double maxTarget = targetValues.max();
        this.shift = minTarget;
        this.scale = maxTarget - minTarget;
        targetValues.subtract(this.shift);
        targetValues.mutableDivide(this.scale);
        IterativelyReweightedLeastSquares optimizer = new IterativelyReweightedLeastSquares();
        ExecutorService threadPool = parallel ? Executors.newFixedThreadPool(SystemInfo.LogicalCores) : new FakeExecutor();
        this.coefficents = optimizer.optimize(1.0E-5, 100, this.logitFun, this.logitFunD, this.coefficents, inputs, targetValues, threadPool);
        threadPool.shutdownNow();
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public Vec getRawWeight() {
        return new SubVector(1, this.coefficents.length() - 1, this.coefficents);
    }

    @Override
    public double getBias() {
        return this.coefficents.get(0);
    }

    @Override
    public Vec getRawWeight(int index) {
        if (index < 1) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public double getBias(int index) {
        if (index < 1) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public int numWeightsVecs() {
        return 1;
    }

    @Override
    public LogisticRegression clone() {
        LogisticRegression clone = new LogisticRegression();
        if (this.coefficents != null) {
            clone.coefficents = this.coefficents.clone();
        }
        clone.scale = this.scale;
        clone.shift = this.shift;
        return clone;
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.coefficents == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        if (this.shift != 0.0 || this.scale != 1.0) {
            throw new UntrainedModelException("Model was trained for regression, not classifiaction");
        }
        CategoricalResults results = new CategoricalResults(2);
        results.setProb(1, this.regress(data));
        results.setProb(0, 1.0 - results.getProb(1));
        return results;
    }

    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("Logistic Regression works only in the case of two classes, and can not handle " + dataSet.getClassSize() + " classes");
        }
        RegressionDataSet rds = new RegressionDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories());
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            rds.addDataPoint(dataSet.getDataPoint(i), (double)dataSet.getDataPointCategory(i));
        }
        this.train(rds, parallel);
    }

    @Override
    public void train(ClassificationDataSet dataSet) {
        this.train(dataSet, false);
    }
}

