/*
 * Decompiled with CFR 0.152.
 */
package jsat.distributions.multivariate;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.DataPoint;
import jsat.distributions.empirical.kernelfunc.EpanechnikovKF;
import jsat.distributions.empirical.kernelfunc.KernelFunction;
import jsat.distributions.multivariate.MultivariateKDE;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollection;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionUtils;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameterized;

/**
 * A multivariate kernel density estimator that can use an arbitrary
 * {@link DistanceMetric}. The training vectors are stored in a
 * {@link VectorCollection} so that only points within the kernel's support
 * radius ({@code bandwidth * kf.cutOff()}) are visited when the density is
 * evaluated.
 * <p>
 * The bandwidth is either given explicitly or estimated from the data. The
 * estimate is the mean plus a number of standard deviations of the statistics
 * returned for each training point's k-th nearest neighbor.
 */
public class MetricKDE
extends MultivariateKDE
implements Parameterized {
    private static final long serialVersionUID = -2084039950938740815L;
    /** Kernel applied to the bandwidth-scaled distances. */
    private KernelFunction kf;
    /** Current bandwidth; 0 until one is set or estimated. */
    private double bandwidth;
    /** Metric used to build and search the index. */
    private DistanceMetric distanceMetric;
    /** Index of the training vectors, each paired with its position in the training list. */
    private VectorCollection<VecPaired<Vec, Integer>> vc;
    /** Neighbor count used by {@link #setUsingData(List)} to estimate the bandwidth. */
    private int defaultK;
    /** Number of standard deviations added when estimating the bandwidth. */
    private double defaultStndDev;
    public static final int DEFAULT_K = 3;
    public static final double DEFAULT_STND_DEV = 2.0;
    public static final KernelFunction DEFAULT_KF = EpanechnikovKF.getInstance();

    /**
     * Creates a KDE using the default kernel, the Euclidean distance and a
     * fresh {@link DefaultVectorCollection}.
     */
    public MetricKDE() {
        this(DEFAULT_KF, new EuclideanDistance());
    }

    /**
     * Creates a KDE using the default kernel, the given metric and a fresh
     * {@link DefaultVectorCollection}.
     *
     * @param distanceMetric the metric used to measure distances between points
     */
    public MetricKDE(DistanceMetric distanceMetric) {
        this(DEFAULT_KF, distanceMetric);
    }

    /**
     * Creates a KDE using the default kernel.
     *
     * @param distanceMetric the metric used to measure distances between points
     * @param vc the collection used to index the training data
     */
    public MetricKDE(DistanceMetric distanceMetric, VectorCollection<VecPaired<Vec, Integer>> vc) {
        this(DEFAULT_KF, distanceMetric, vc);
    }

    /**
     * Creates a KDE backed by a fresh {@link DefaultVectorCollection}.
     *
     * @param kf the kernel function
     * @param distanceMetric the metric used to measure distances between points
     */
    public MetricKDE(KernelFunction kf, DistanceMetric distanceMetric) {
        // Each instance gets its own index; a shared one would be overwritten
        // whenever any instance is trained.
        this(kf, distanceMetric, new DefaultVectorCollection<VecPaired<Vec, Integer>>());
    }

    /**
     * Creates a KDE with the default neighbor count and standard deviation
     * multiplier for bandwidth estimation.
     *
     * @param kf the kernel function
     * @param distanceMetric the metric used to measure distances between points
     * @param vc the collection used to index the training data
     */
    public MetricKDE(KernelFunction kf, DistanceMetric distanceMetric, VectorCollection<VecPaired<Vec, Integer>> vc) {
        this(kf, distanceMetric, vc, DEFAULT_K, DEFAULT_STND_DEV);
    }

    /**
     * Creates a fully specified KDE.
     *
     * @param kf the kernel function
     * @param distanceMetric the metric used to measure distances between points
     * @param vc the collection used to index the training data
     * @param defaultK neighbor count used when estimating the bandwidth
     * @param defaultStndDev standard deviations added when estimating the bandwidth
     * @throws ArithmeticException if {@code defaultK} or {@code defaultStndDev} is invalid
     */
    public MetricKDE(KernelFunction kf, DistanceMetric distanceMetric, VectorCollection<VecPaired<Vec, Integer>> vc, int defaultK, double defaultStndDev) {
        this.setKernelFunction(kf);
        this.distanceMetric = distanceMetric;
        this.vc = vc;
        this.setDefaultK(defaultK);
        this.setDefaultStndDev(defaultStndDev);
    }

    /**
     * Sets the bandwidth.
     *
     * @param bandwidth a positive, finite bandwidth
     * @throws ArithmeticException if the value is not positive and finite
     */
    public void setBandwith(double bandwidth) {
        if (bandwidth <= 0.0 || Double.isNaN(bandwidth) || Double.isInfinite(bandwidth)) {
            throw new ArithmeticException("Invalid bandwith given, bandwith must be a positive number, not " + bandwidth);
        }
        this.bandwidth = bandwidth;
    }

    /** @return the current bandwidth (0 if none has been set or estimated) */
    public double getBandwith() {
        return this.bandwidth;
    }

    /**
     * Sets the neighbor count used when the bandwidth is estimated from data.
     *
     * @param defaultK a positive neighbor count
     * @throws ArithmeticException if {@code defaultK <= 0}
     */
    public void setDefaultK(int defaultK) {
        if (defaultK <= 0) {
            throw new ArithmeticException("At least one neighbor must be taken into acount, " + defaultK + " is invalid");
        }
        this.defaultK = defaultK;
    }

    /** @return the neighbor count used when estimating the bandwidth */
    public int getDefaultK() {
        return this.defaultK;
    }

    /**
     * Sets how many standard deviations are added to the mean neighbor
     * statistic when the bandwidth is estimated from data.
     *
     * @param defaultStndDev a positive, finite multiplier
     * @throws ArithmeticException if the value is not positive and finite
     */
    public void setDefaultStndDev(double defaultStndDev) {
        if (Double.isInfinite(defaultStndDev) || Double.isNaN(defaultStndDev) || defaultStndDev <= 0.0) {
            throw new ArithmeticException("The number of standard deviations to remove must bea postive number, not " + defaultStndDev);
        }
        this.defaultStndDev = defaultStndDev;
    }

    /** @return the standard deviation multiplier used when estimating the bandwidth */
    public double getDefaultStndDev() {
        return this.defaultStndDev;
    }

    /** @return the metric used to build and search the index */
    public DistanceMetric getDistanceMetric() {
        return this.distanceMetric;
    }

    /**
     * Sets the metric used for subsequent training. An index that is already
     * built is not rebuilt by this call.
     *
     * @param distanceMetric the new metric
     */
    public void setDistanceMetric(DistanceMetric distanceMetric) {
        this.distanceMetric = distanceMetric;
    }

    /**
     * Returns a copy of this KDE. The metric and index are cloned; the kernel
     * function instance is shared.
     */
    @Override
    public MetricKDE clone() {
        // Clone the index exactly once and tolerate a null index / metric.
        VectorCollection<VecPaired<Vec, Integer>> vcCopy = this.vc == null ? null : this.vc.clone();
        DistanceMetric dmCopy = this.distanceMetric == null ? null : this.distanceMetric.clone();
        MetricKDE copy = new MetricKDE(this.kf, dmCopy, vcCopy, this.defaultK, this.defaultStndDev);
        copy.bandwidth = this.bandwidth;
        return copy;
    }

    /**
     * Returns the training points within the kernel's support around {@code x}.
     * Each pair value is the kernel weight {@code kf.k(distance / bandwidth)}.
     *
     * @param x the query point
     * @return the nearby training points with their kernel weights
     * @throws UntrainedModelException if there is no index
     */
    @Override
    public List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> getNearby(Vec x) {
        if (this.vc == null) {
            throw new UntrainedModelException("Model has not yet been created");
        }
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> nearBy = this.getNearbyRaw(x);
        for (VecPaired<VecPaired<Vec, Integer>, Double> vecPaired : nearBy) {
            vecPaired.setPair(this.kf.k(vecPaired.getPair()));
        }
        return nearBy;
    }

    /**
     * Returns the training points within {@code bandwidth * kf.cutOff()} of
     * {@code x}. Each pair value is the distance divided by the bandwidth.
     *
     * @param x the query point
     * @return the nearby training points with their scaled distances
     * @throws UntrainedModelException if there is no index
     */
    @Override
    public List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> getNearbyRaw(Vec x) {
        if (this.vc == null) {
            throw new UntrainedModelException("Model has not yet been created");
        }
        List<VecPaired<VecPaired<Vec, Integer>, Double>> nearBy = this.vc.search(x, this.bandwidth * this.kf.cutOff());
        for (VecPaired<VecPaired<Vec, Integer>, Double> result : nearBy) {
            result.setPair(result.getPair() / this.bandwidth);
        }
        return nearBy;
    }

    /**
     * Evaluates the density estimate at {@code x}. The sum of kernel weights is
     * divided by {@code n * h^d}, where n is the index size, h the bandwidth and
     * d the length of the returned vectors.
     *
     * @param x the query point
     * @return the estimated density; 0 when no training point is within the kernel's support
     */
    @Override
    public double pdf(Vec x) {
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> nearBy = this.getNearby(x);
        if (nearBy.isEmpty()) {
            return 0.0;
        }
        double PDF = 0.0;
        for (VecPaired<VecPaired<Vec, Integer>, Double> vecPaired : nearBy) {
            PDF += vecPaired.getPair().doubleValue();
        }
        return PDF / ((double)this.vc.size() * Math.pow(this.bandwidth, nearBy.get(0).length()));
    }

    /**
     * Trains on {@code dataSet} with a fixed bandwidth.
     *
     * @param dataSet the training vectors
     * @param bandwith a positive, finite bandwidth
     * @return {@code true}
     */
    public <V extends Vec> boolean setUsingData(List<V> dataSet, double bandwith) {
        return this.setUsingData(dataSet, bandwith, null);
    }

    /**
     * Trains on {@code dataSet} with a fixed bandwidth.
     *
     * @param dataSet the training vectors
     * @param bandwith a positive, finite bandwidth
     * @param threadpool if non-null, used to train the metric, and the index is built in parallel mode
     * @return {@code true}
     * @throws ArithmeticException if the bandwidth is invalid (checked before any work is done)
     */
    public <V extends Vec> boolean setUsingData(List<V> dataSet, double bandwith, ExecutorService threadpool) {
        this.setBandwith(bandwith);
        this.buildIndex(dataSet, threadpool);
        return true;
    }

    /**
     * Trains on {@code dataSet}, estimating the bandwidth from the k-th
     * neighbor with the default standard deviation multiplier.
     */
    public <V extends Vec> boolean setUsingData(List<V> dataSet, int k) {
        return this.setUsingData(dataSet, k, this.defaultStndDev);
    }

    /**
     * Trains on {@code dataSet}, estimating the bandwidth from the k-th
     * neighbor with the default standard deviation multiplier.
     */
    public <V extends Vec> boolean setUsingData(List<V> dataSet, int k, ExecutorService threadpool) {
        return this.setUsingData(dataSet, k, this.defaultStndDev, threadpool);
    }

    /** Trains on {@code dataSet}, estimating the bandwidth from the k-th neighbor. */
    public <V extends Vec> boolean setUsingData(List<V> dataSet, int k, double stndDevs) {
        return this.setUsingData(dataSet, k, stndDevs, null);
    }

    /**
     * Trains on {@code dataSet} and estimates the bandwidth. The estimate is the
     * mean plus {@code stndDevs} standard deviations of the statistics returned
     * by {@code VectorCollectionUtils.getKthNeighborStats}.
     *
     * @param dataSet the training vectors
     * @param k the neighbor used for bandwidth estimation; must be positive
     * @param stndDevs the number of standard deviations to add to the mean
     * @param threadpool if non-null, used for metric training and neighbor
     *        statistics, and the index is built in parallel mode
     * @return {@code true}
     * @throws ArithmeticException if {@code k <= 0} or the resulting bandwidth is invalid
     */
    public <V extends Vec> boolean setUsingData(List<V> dataSet, int k, double stndDevs, ExecutorService threadpool) {
        if (k <= 0) {
            throw new ArithmeticException("At least one neighbor must be taken into acount, " + k + " is invalid");
        }
        this.buildIndex(dataSet, threadpool);
        // The queried points are the same ones stored in the index, so k + 1 is
        // requested (presumably so a point's match with itself is not counted).
        OnLineStatistics stats = threadpool == null
                ? VectorCollectionUtils.getKthNeighborStats(this.vc, dataSet, k + 1)
                : VectorCollectionUtils.getKthNeighborStats(this.vc, dataSet, k + 1, threadpool);
        this.setBandwith(stats.getMean() + stats.getStandardDeviation() * stndDevs);
        return true;
    }

    /** Trains on {@code dataSet} using {@link #getDefaultK()} for bandwidth estimation. */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> dataSet) {
        return this.setUsingData(dataSet, this.defaultK);
    }

    /** Trains on {@code dataSet} using {@link #getDefaultK()} for bandwidth estimation. */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> dataSet, ExecutorService threadpool) {
        return this.setUsingData(dataSet, this.defaultK, threadpool);
    }

    /** Trains on the numerical features of {@code dataPoints}. */
    @Override
    public boolean setUsingDataList(List<DataPoint> dataPoints) {
        return this.setUsingData(numericalValues(dataPoints));
    }

    /** Trains on the numerical features of {@code dataPoints}. */
    @Override
    public boolean setUsingDataList(List<DataPoint> dataPoints, ExecutorService threadpool) {
        return this.setUsingData(numericalValues(dataPoints), threadpool);
    }

    /** Sampling is not implemented. */
    @Override
    public List<Vec> sample(int count, Random rand) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    @Override
    public KernelFunction getKernelFunction() {
        return this.kf;
    }

    /** @param kf the kernel function to use */
    public void setKernelFunction(KernelFunction kf) {
        this.kf = kf;
    }

    /**
     * Multiplies the current bandwidth by {@code scale}.
     *
     * @param scale a positive, finite factor
     * @throws ArithmeticException if {@code scale} is not positive and finite,
     *         since that would leave a bandwidth the searches cannot use
     */
    @Override
    public void scaleBandwidth(double scale) {
        if (scale <= 0.0 || Double.isNaN(scale) || Double.isInfinite(scale)) {
            throw new ArithmeticException("Invalid scale given, scale must be a positive number, not " + scale);
        }
        this.bandwidth *= scale;
    }

    /** Trains the metric if it needs training, then builds the index over {@code dataSet}. */
    private <V extends Vec> void buildIndex(List<V> dataSet, ExecutorService threadpool) {
        TrainableDistanceMetric.trainIfNeeded(this.distanceMetric, dataSet, threadpool);
        this.vc.build(threadpool != null, indexedPairs(dataSet), this.distanceMetric);
    }

    /** Pairs every vector in {@code dataSet} with its position in that list. */
    private static <V extends Vec> List<VecPaired<Vec, Integer>> indexedPairs(List<V> dataSet) {
        List<VecPaired<Vec, Integer>> pairs = new ArrayList<VecPaired<Vec, Integer>>(dataSet.size());
        int index = 0;
        for (V v : dataSet) {
            pairs.add(new VecPaired<Vec, Integer>(v, index++));
        }
        return pairs;
    }

    /** Extracts the numerical feature vector of every data point. */
    private static List<Vec> numericalValues(List<DataPoint> dataPoints) {
        List<Vec> dataSet = new ArrayList<Vec>(dataPoints.size());
        for (DataPoint dp : dataPoints) {
            dataSet.add(dp.getNumericalValues());
        }
        return dataSet;
    }
}

