/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.util.ArrayList;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.SparseVector;
import jsat.linear.Vec;

/**
 * A {@link DataSet} in which every {@link DataPoint} is paired with a
 * {@code double} regression target.
 *
 * <p>The pairs are held in {@link #dataPoints}. Methods that add or replace a
 * data point clear the inherited {@code columnVecCache}, since the cached
 * columns would no longer reflect the stored points.
 */
public class RegressionDataSet extends DataSet<RegressionDataSet> {
    /** Backing list of (data point, target value) pairs, in insertion order. */
    protected List<DataPointPair<Double>> dataPoints;

    /** Shared zero-length array used when a point has no categorical values. */
    private static final int[] EMPTY_CATEGORIES = new int[0];

    /**
     * Creates an empty data set with the given shape.
     *
     * @param numerical  number of numerical features each data point must have
     * @param categories categorical feature descriptions; the array is stored
     *                   as given, not copied
     */
    public RegressionDataSet(int numerical, CategoricalData[] categories) {
        this.numNumerVals = numerical;
        this.categories = categories;
        this.dataPoints = new ArrayList<DataPointPair<Double>>();
        this.numericalVariableNames = new ArrayList<String>(this.getNumNumericalVars());
        this.setUpGenericNumericNames();
    }

    /**
     * Builds a regression data set from plain data points by pulling one
     * numerical feature out of every point and using it as the target.
     *
     * <p>The categorical descriptions are taken from the first element of
     * {@code data}, so the list must not be empty.
     *
     * @param data       the source data points
     * @param predicting index of the numerical feature to use as the target;
     *                   features after it are shifted down by one index
     */
    public RegressionDataSet(List<DataPoint> data, int predicting) {
        DataPoint first = data.get(0);
        this.categories = new CategoricalData[first.numCategoricalValues()];
        System.arraycopy(first.getCategoricalData(), 0, this.categories, 0, this.categories.length);
        // One numerical feature is removed to become the regression target.
        this.numNumerVals = first.numNumericalValues() - 1;
        this.dataPoints = new ArrayList<DataPointPair<Double>>(data.size());
        for (DataPoint dp : data) {
            Vec origV = dp.getNumericalValues();
            // Stays 0.0 when the iteration below never yields the target index
            // (presumably an implicit zero in a sparse vector - TODO confirm).
            double target = 0.0;
            Vec newVec = origV.isSparse()
                    ? new SparseVector(origV.length() - 1, origV.nnz())
                    : new DenseVector(origV.length() - 1);
            for (IndexValue iv : origV) {
                int index = iv.getIndex();
                if (index < predicting) {
                    newVec.set(index, iv.getValue());
                } else if (index == predicting) {
                    target = iv.getValue();
                } else {
                    // Past the removed feature: close the gap it left behind.
                    newVec.set(index - 1, iv.getValue());
                }
            }
            DataPoint newDp = new DataPoint(newVec, dp.getCategoricalValues(), this.categories, dp.getWeight());
            this.dataPoints.add(new DataPointPair<Double>(newDp, target));
        }
        this.numericalVariableNames = new ArrayList<String>(this.getNumNumericalVars());
        this.setUpGenericNumericNames();
    }

    /**
     * Builds a data set from existing pairs. Each data point is cloned, so
     * the new set does not share {@link DataPoint} instances with {@code list}.
     *
     * <p>The shape (numerical count and categorical descriptions) is taken
     * from the first element, so the list must not be empty.
     *
     * @param list the pairs to copy into this data set
     */
    public RegressionDataSet(List<DataPointPair<Double>> list) {
        DataPoint first = list.get(0).getDataPoint();
        this.numNumerVals = first.numNumericalValues();
        this.numericalVariableNames = new ArrayList<String>(this.getNumNumericalVars());
        this.setUpGenericNumericNames();
        this.categories = CategoricalData.copyOf(first.getCategoricalData());
        this.dataPoints = new ArrayList<DataPointPair<Double>>(list.size());
        for (DataPointPair<Double> dpp : list) {
            this.dataPoints.add(new DataPointPair<Double>(dpp.getDataPoint().clone(), dpp.getPair()));
        }
    }

    /**
     * Fills {@code numericalVariableNames} with "Numeric Input k" labels
     * (1-based). Nothing is added when there are more than 100 numerical
     * variables.
     */
    private void setUpGenericNumericNames() {
        int count = this.getNumNumericalVars();
        if (count > 100) {
            return;
        }
        for (int i = 0; i < count; ++i) {
            this.numericalVariableNames.add("Numeric Input " + (i + 1));
        }
    }

    /** Used by {@link #usingDPPList(List)}, which assigns the fields itself. */
    private RegressionDataSet() {
    }

    /**
     * Combines every data set in {@code list} except the one at index
     * {@code exception} into a single new data set. The shape of the result is
     * taken from the excluded data set. The pairs are added by reference; they
     * are not copied.
     *
     * <p>The method name is misspelled ("comine"); it is kept unchanged so
     * that existing callers continue to compile.
     *
     * @param list      the data sets to combine
     * @param exception index of the data set to leave out
     * @return a new data set holding the pairs of all the other data sets
     */
    public static RegressionDataSet comineAllBut(List<RegressionDataSet> list, int exception) {
        RegressionDataSet excluded = list.get(exception);
        RegressionDataSet rds = new RegressionDataSet(excluded.getNumNumericalVars(), excluded.getCategories());
        for (int i = 0; i < list.size(); ++i) {
            if (i != exception) {
                rds.dataPoints.addAll(list.get(i).dataPoints);
            }
        }
        return rds;
    }

    /**
     * Adds a data point that has numerical values only.
     *
     * @param numerical the numerical feature values
     * @param val       the regression target
     * @throws RuntimeException    if the point does not match this data set's shape
     * @throws ArithmeticException if {@code val} is infinite or NaN
     */
    public void addDataPoint(Vec numerical, double val) {
        this.addDataPoint(numerical, EMPTY_CATEGORIES, val);
    }

    /**
     * Adds a data point built from raw numerical and categorical values.
     *
     * @param numerical  the numerical feature values
     * @param categories the categorical feature values; a negative entry is
     *                   accepted without being checked against its category
     * @param val        the regression target
     * @throws RuntimeException    if the number of numerical or categorical
     *                             values does not match this data set, or a
     *                             non-negative categorical value is invalid
     * @throws ArithmeticException if {@code val} is infinite or NaN
     */
    public void addDataPoint(Vec numerical, int[] categories, double val) {
        if (numerical.length() != this.numNumerVals) {
            throw new RuntimeException("Data point does not contain enough numerical data points");
        }
        // Compare against the data set's categories: the parameter shadows the
        // field, and comparing the parameter with itself could never fail.
        if (categories.length != this.categories.length) {
            throw new RuntimeException("Data point does not contain enough categorical data points");
        }
        for (int i = 0; i < categories.length; ++i) {
            boolean acceptable = this.categories[i].isValidCategory(categories[i]) || categories[i] < 0;
            if (!acceptable) {
                throw new RuntimeException("Categoriy value given is invalid");
            }
        }
        DataPoint dp = new DataPoint(numerical, categories, this.categories);
        this.addDataPoint(dp, val);
    }

    /**
     * Adds an existing data point with its target value and clears the column
     * cache.
     *
     * @param dp  the data point to add; stored by reference
     * @param val the regression target
     * @throws RuntimeException    if the point's numerical or categorical count
     *                             differs from this data set's
     * @throws ArithmeticException if {@code val} is infinite or NaN
     */
    public void addDataPoint(DataPoint dp, double val) {
        boolean shapeMatches = dp.numNumericalValues() == this.getNumNumericalVars()
                && dp.numCategoricalValues() == this.getNumCategoricalVars();
        if (!shapeMatches) {
            throw new RuntimeException("The added data point does not match the number of values and categories for the data set");
        }
        if (Double.isInfinite(val) || Double.isNaN(val)) {
            throw new ArithmeticException("Unregressiable value " + val + " given for regression");
        }
        this.dataPoints.add(new DataPointPair<Double>(dp, val));
        this.columnVecCache.clear();
    }

    /**
     * Adds the given pair directly and clears the column cache. Unlike the
     * {@code addDataPoint} methods, no shape or target validation is performed.
     *
     * @param pair the pair to add; stored by reference
     */
    public void addDataPointPair(DataPointPair<Double> pair) {
        this.dataPoints.add(pair);
        this.columnVecCache.clear();
    }

    /**
     * Returns the data point stored at the given position.
     *
     * @param i index of the data point
     * @return the data point, without its target value
     */
    @Override
    public DataPoint getDataPoint(int i) {
        return this.dataPoints.get(i).getDataPoint();
    }

    /**
     * Returns the stored pair at the given position.
     *
     * @param i index of the pair
     * @return the pair instance held by this data set (not a copy)
     */
    public DataPointPair<Double> getDataPointPair(int i) {
        return this.dataPoints.get(i);
    }

    /**
     * Returns a new list of new pairs whose data points are clones of the
     * stored ones.
     *
     * @return an independent list of cloned pairs
     */
    public List<DataPointPair<Double>> getAsDPPList() {
        List<DataPointPair<Double>> copies = new ArrayList<DataPointPair<Double>>(this.dataPoints.size());
        for (DataPointPair<Double> dpp : this.dataPoints) {
            copies.add(new DataPointPair<Double>(dpp.getDataPoint().clone(), dpp.getPair()));
        }
        return copies;
    }

    /**
     * Returns a new list containing the same pair instances this data set
     * holds. Changes to the returned list do not affect this data set, but
     * changes to the pairs themselves do.
     *
     * @return a new list sharing this data set's pair objects
     */
    public List<DataPointPair<Double>> getDPPList() {
        return new ArrayList<DataPointPair<Double>>(this.dataPoints);
    }

    /**
     * Replaces the data point at the given position, keeping its target value,
     * and clears the column cache.
     *
     * @param i  index of the pair to modify
     * @param dp the replacement data point
     */
    @Override
    public void setDataPoint(int i, DataPoint dp) {
        this.dataPoints.get(i).setDataPoint(dp);
        this.columnVecCache.clear();
    }

    /**
     * Replaces the target value at the given position.
     *
     * @param i   index of the pair to modify
     * @param val the new target value
     * @throws ArithmeticException if {@code val} is infinite or NaN
     */
    public void setTargetValue(int i, double val) {
        if (Double.isInfinite(val) || Double.isNaN(val)) {
            throw new ArithmeticException("Can not predict a " + val + " value");
        }
        this.dataPoints.get(i).setPair(val);
    }

    /**
     * Creates a new data set containing the points at the given indices, in
     * the given order. The data points are added by reference.
     *
     * @param indicies indices of the points to include
     * @return the new data set
     */
    @Override
    protected RegressionDataSet getSubset(List<Integer> indicies) {
        RegressionDataSet subset = new RegressionDataSet(this.numNumerVals, this.categories);
        for (int index : indicies) {
            subset.addDataPoint(this.getDataPoint(index), this.getTargetValue(index));
        }
        return subset;
    }

    /**
     * Returns the number of stored pairs.
     *
     * @return the number of data points in this data set
     */
    @Override
    public int getSampleSize() {
        return this.dataPoints.size();
    }

    /**
     * Collects all target values into a new dense vector, in data point order.
     *
     * @return a new vector whose i-th entry is the target of the i-th point
     */
    public Vec getTargetValues() {
        int size = this.getSampleSize();
        DenseVector targets = new DenseVector(size);
        for (int i = 0; i < size; ++i) {
            targets.set(i, this.dataPoints.get(i).getPair());
        }
        return targets;
    }

    /**
     * Returns the target value at the given position.
     *
     * @param i index of the pair
     * @return the target value of the i-th data point
     */
    public double getTargetValue(int i) {
        return this.dataPoints.get(i).getPair();
    }

    /**
     * Creates a data set that uses {@code list} itself as its backing list, so
     * later changes to the list are visible through the data set and vice
     * versa. The shape is taken from the first element, so the list must not
     * be empty.
     *
     * <p>Unlike the constructors, generic variable names are generated here
     * regardless of how many numerical variables there are.
     *
     * @param list the list to wrap
     * @return a data set backed by {@code list}
     */
    public static RegressionDataSet usingDPPList(List<DataPointPair<Double>> list) {
        RegressionDataSet rds = new RegressionDataSet();
        DataPoint first = list.get(0).getDataPoint();
        rds.dataPoints = list;
        rds.numNumerVals = first.numNumericalValues();
        rds.numericalVariableNames = new ArrayList<String>(rds.getNumNumericalVars());
        for (int i = 0; i < rds.getNumNumericalVars(); ++i) {
            rds.numericalVariableNames.add("Numeric Input " + (i + 1));
        }
        rds.categories = CategoricalData.copyOf(first.getCategoricalData());
        return rds;
    }

    /**
     * Creates a copy with new pair objects that reference the same
     * {@link DataPoint} instances as this data set. The entries of the column
     * cache are copied into the clone's cache as well.
     *
     * @return the shallow clone
     */
    public RegressionDataSet shallowClone() {
        RegressionDataSet clone = new RegressionDataSet(this.numNumerVals, this.categories);
        for (DataPointPair<Double> dpp : this.dataPoints) {
            clone.dataPoints.add(new DataPointPair<Double>(dpp.getDataPoint(), dpp.getPair()));
        }
        clone.columnVecCache.putAll(this.columnVecCache);
        return clone;
    }

    /**
     * Delegates to the superclass implementation and narrows the result type.
     *
     * @return the result of {@code super.getTwiceShallowClone()}
     */
    @Override
    public RegressionDataSet getTwiceShallowClone() {
        return (RegressionDataSet)super.getTwiceShallowClone();
    }
}

