/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform;

import java.util.Arrays;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.discrete.UniformDiscrete;
import jsat.linear.DenseVector;
import jsat.linear.Vec;

/**
 * Converts every numerical feature of a data set into a categorical feature by equal-width binning.
 * <p>
 * {@link #fit(DataSet)} records, for each numerical feature, the minimum and maximum finite value
 * seen and splits that range into {@link #getNumberOfBins()} bins of equal width.
 * {@link #transform(DataPoint)} then maps each numerical value to the index of its bin. Values
 * outside the fitted range are clamped into the first or last bin, and NaN values become category
 * {@code -1}.
 * <p>
 * Transformed points have no numerical values. Their categorical values are the binned numerical
 * features followed by the point's original categorical values.
 */
public class NumericalToHistogram implements DataTransform {
    private static final long serialVersionUID = -2318706869393636074L;

    /** Number of bins each numerical feature is partitioned into; always positive. */
    private int n;

    /**
     * One row per numerical feature, filled by {@link #fit(DataSet)}: {@code [i][0]} is the lower
     * edge of the first bin and {@code [i][1]} is the bin width. {@code null} until fitted.
     */
    double[][] conversionArray;

    /**
     * Categorical descriptors of the transformed points. The first entries are one
     * {@code CategoricalData(n)} per numerical feature. They are followed by the input data set's
     * own categorical descriptors (copied by reference). {@code null} until fitted.
     */
    CategoricalData[] newDataArray;

    /** Creates an unfitted transform that uses 25 bins. */
    public NumericalToHistogram() {
        this(25);
    }

    /**
     * Creates a transform fitted to {@code dataSet}, using {@code ceil(sqrt(sampleSize))} bins.
     *
     * @param dataSet the data to fit; an empty data set yields 0 bins, which
     *                {@link #setNumberOfBins(int)} rejects
     */
    public NumericalToHistogram(DataSet dataSet) {
        this(dataSet, (int) Math.ceil(Math.sqrt(dataSet.getSampleSize())));
    }

    /**
     * Creates an unfitted transform.
     *
     * @param n the number of bins; must be positive
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public NumericalToHistogram(int n) {
        this.setNumberOfBins(n);
    }

    /**
     * Creates a transform with {@code n} bins and immediately fits it to {@code dataSet}.
     *
     * @param dataSet the data to fit
     * @param n       the number of bins; must be positive
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public NumericalToHistogram(DataSet dataSet, int n) {
        this(n);
        this.fit(dataSet);
    }

    /**
     * Learns the per-feature bin edges from {@code dataSet} and builds the output categorical
     * descriptors.
     * <p>
     * NaN and infinite values are ignored when computing a feature's range. A feature with no
     * finite values at all (or an empty data set) gets a degenerate zero-width range starting at 0.
     *
     * @param dataSet the data to learn the ranges from
     */
    @Override
    public void fit(DataSet dataSet) {
        final int numNumeric = dataSet.getNumNumericalVars();
        final int numCategorical = dataSet.getNumCategoricalVars();

        // Seed with +/-infinity so that the first finite observation always replaces them.
        // Double.MIN_VALUE would be wrong for the maxima: it is the smallest positive double, not
        // the most negative one, so all-negative features would get a bogus maximum.
        double[] mins = new double[numNumeric];
        double[] maxs = new double[numNumeric];
        Arrays.fill(mins, Double.POSITIVE_INFINITY);
        Arrays.fill(maxs, Double.NEGATIVE_INFINITY);

        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            Vec v = dataSet.getDataPoint(i).getNumericalValues();
            for (int j = 0; j < numNumeric; ++j) {
                double val = v.get(j);
                // Skip missing (NaN) and infinite values: an infinite edge would make the bin width
                // infinite/NaN and collapse every value into one bin.
                if (Double.isNaN(val) || Double.isInfinite(val)) {
                    continue;
                }
                mins[j] = Math.min(mins[j], val);
                maxs[j] = Math.max(maxs[j], val);
            }
        }

        double[][] conversion = new double[numNumeric][2];
        for (int j = 0; j < numNumeric; ++j) {
            if (mins[j] > maxs[j]) {
                // Nothing usable was observed for this feature: fall back to a zero-width range
                // instead of deriving edges from the infinite sentinels.
                conversion[j][0] = 0.0;
                conversion[j][1] = 0.0;
            } else {
                conversion[j][0] = mins[j];
                conversion[j][1] = (maxs[j] - mins[j]) / (double) this.n;
            }
        }
        this.conversionArray = conversion;

        this.newDataArray = new CategoricalData[numNumeric + numCategorical];
        for (int j = 0; j < numNumeric; ++j) {
            this.newDataArray[j] = new CategoricalData(this.n);
        }
        System.arraycopy(dataSet.getCategories(), 0, this.newDataArray, numNumeric, numCategorical);
    }

    /**
     * Sets the number of bins used for every numerical feature.
     * <p>
     * NOTE(review): this only updates the bin count. Bin widths and categorical descriptors
     * computed by an earlier {@link #fit(DataSet)} are not recomputed, while
     * {@link #transform(DataPoint)} clamps with the new count. Refit after changing it.
     *
     * @param n the number of bins; must be positive
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public void setNumberOfBins(int n) {
        if (n <= 0) {
            throw new IllegalArgumentException("Must partition into a positive number of groups");
        }
        this.n = n;
    }

    /**
     * Returns the number of bins used for every numerical feature.
     *
     * @return the number of bins used for every numerical feature
     */
    public int getNumberOfBins() {
        return this.n;
    }

    /**
     * Suggests a distribution of candidate bin counts for {@code data}, based on its sample size.
     * <ul>
     *   <li>fewer than 20 samples: {@code UniformDiscrete(2, size - 1)}</li>
     *   <li>at least 1,000,000 samples: {@code LogUniform(50, 1000)}</li>
     *   <li>otherwise: {@code UniformDiscrete(max(sqrt/3, 2), min(3 * sqrt, size - 1))}, where
     *       {@code sqrt = (int) Math.sqrt(size)}</li>
     * </ul>
     * NOTE(review): with fewer than 3 samples the upper bound ({@code size - 1}) is below the lower
     * bound of 2. Confirm how {@code UniformDiscrete} handles that.
     *
     * @param data the data set whose sample size drives the suggestion
     * @return a distribution over plausible bin counts
     */
    public static Distribution guessNumberOfBins(DataSet data) {
        if (data.getSampleSize() < 20) {
            return new UniformDiscrete(2, data.getSampleSize() - 1);
        }
        if (data.getSampleSize() >= 1000000) {
            return new LogUniform(50.0, 1000.0);
        }
        int sqrt = (int) Math.sqrt(data.getSampleSize());
        return new UniformDiscrete(Math.max(sqrt / 3, 2), Math.min(sqrt * 3, data.getSampleSize() - 1));
    }

    /**
     * Copy constructor used by {@link #clone()}. It deep-copies the conversion table and clones
     * each categorical descriptor. Unfitted state ({@code null} arrays) is preserved.
     */
    private NumericalToHistogram(NumericalToHistogram other) {
        this.n = other.n;
        if (other.conversionArray != null) {
            this.conversionArray = new double[other.conversionArray.length][];
            for (int i = 0; i < other.conversionArray.length; ++i) {
                this.conversionArray[i] = Arrays.copyOf(other.conversionArray[i], other.conversionArray[i].length);
            }
        }
        if (other.newDataArray != null) {
            this.newDataArray = new CategoricalData[other.newDataArray.length];
            for (int i = 0; i < other.newDataArray.length; ++i) {
                this.newDataArray[i] = other.newDataArray[i].clone();
            }
        }
    }

    /**
     * Bins the numerical values of {@code dp} and returns a purely categorical data point.
     *
     * @param dp the point to transform; its numerical features must match those seen in
     *           {@link #fit(DataSet)}
     * @return a new point with no numerical values, whose categorical values are the bin indices
     *         (or {@code -1} for NaN inputs) followed by {@code dp}'s categorical values
     * @throws IllegalStateException if the transform has not been fitted
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        if (this.conversionArray == null || this.newDataArray == null) {
            throw new IllegalStateException("NumericalToHistogram has not been fit to a data set");
        }
        final int numNumeric = this.conversionArray.length;
        int[] newCatVals = new int[this.newDataArray.length];
        Vec v = dp.getNumericalValues();
        for (int i = 0; i < numNumeric; ++i) {
            double offset = v.get(i) - this.conversionArray[i][0];
            if (Double.isNaN(offset)) {
                // NaN (missing) input: emitted as -1, presumably the missing-category marker.
                newCatVals[i] = -1;
                continue;
            }
            // Equal-width binning. For a zero-width range the division yields +/-Infinity or NaN;
            // the int cast saturates infinities and maps NaN to 0, and the clamp below then puts
            // the value into the first or last bin.
            int catVal = (int) Math.floor(offset / this.conversionArray[i][1]);
            newCatVals[i] = Math.max(0, Math.min(catVal, this.n - 1));
        }
        // The original categorical values follow the binned numerical ones, matching newDataArray.
        System.arraycopy(dp.getCategoricalValues(), 0, newCatVals, numNumeric, dp.numCategoricalValues());
        return new DataPoint(new DenseVector(0), newCatVals, this.newDataArray);
    }

    /**
     * Returns an independent copy of this transform, including any fitted state.
     *
     * @return an independent copy of this transform, including any fitted state
     */
    @Override
    public NumericalToHistogram clone() {
        return new NumericalToHistogram(this);
    }
}

