/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.util.ArrayList;
import java.util.List;
import jsat.classifiers.DataPoint;
import jsat.distributions.multivariate.MultivariateKDE;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;

/**
 * Kernel-weighted regressor: a prediction is the weighted average of the
 * target values attached to the training points that the backing
 * {@link MultivariateKDE} reports as being near the query.
 */
public class NadarayaWatson implements Regressor, Parameterized {
    private static final long serialVersionUID = 8632599345930394763L;

    /** Density estimator used to find nearby training points and their weights. */
    @Parameter.ParameterHolder
    private MultivariateKDE kde;

    /**
     * Creates a regressor backed by the given density estimator.
     *
     * @param kde the estimator queried for neighbours during {@link #regress(DataPoint)};
     *            it is stored by reference, not copied
     */
    public NadarayaWatson(MultivariateKDE kde) {
        this.kde = kde;
    }

    /**
     * Predicts a target value for {@code data} as the weighted mean of the
     * targets of the neighbours returned by the estimator.
     *
     * @param data the query point; only its numerical values are used
     * @return the weighted average of the neighbouring targets, or {@code 0.0}
     *         when there are no neighbours or their weights sum to zero
     */
    @Override
    public double regress(DataPoint data) {
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> nearBy =
                this.kde.getNearby(data.getNumericalValues());
        if (nearBy.isEmpty()) {
            return 0.0;
        }

        double weightSum = 0.0;
        double weightedTargetSum = 0.0;
        for (VecPaired<VecPaired<Vec, Integer>, Double> neighbour : nearBy) {
            // The outer paired value is used as this neighbour's weight.
            double weight = neighbour.getPair();
            // Assumes the innermost vector is the VecPaired<Vec, Double> handed
            // to the estimator by train(), whose paired value is the target.
            @SuppressWarnings("unchecked")
            VecPaired<Vec, Double> trainingVector =
                    (VecPaired<Vec, Double>) neighbour.getVector().getVector();
            double regressionValue = trainingVector.getPair();

            weightSum += weight;
            weightedTargetSum += weight * regressionValue;
        }

        // Avoid 0.0 / 0.0 (NaN) when every neighbour carries zero weight;
        // fall back to the same value as the "no neighbours" case.
        if (weightSum == 0.0) {
            return 0.0;
        }
        return weightedTargetSum / weightSum;
    }

    /**
     * Fits the backing estimator on the data set's numerical features, with
     * each feature vector paired with its target value.
     *
     * @param dataSet  the training data
     * @param parallel forwarded unchanged to the estimator's {@code setUsingData}
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        List<VecPaired<Vec, Double>> vectors = this.collectVectors(dataSet);
        this.kde.setUsingData(vectors, parallel);
    }

    /**
     * Builds one {@code VecPaired} per sample, pairing the sample's numerical
     * values with its target value.
     *
     * @param dataSet the data set to read samples from
     * @return a new list holding one paired vector per sample, in index order
     */
    private List<VecPaired<Vec, Double>> collectVectors(RegressionDataSet dataSet) {
        int sampleCount = dataSet.getSampleSize();
        List<VecPaired<Vec, Double>> vectors = new ArrayList<VecPaired<Vec, Double>>(sampleCount);
        for (int i = 0; i < sampleCount; ++i) {
            Vec features = dataSet.getDataPoint(i).getNumericalValues();
            vectors.add(new VecPaired<Vec, Double>(features, dataSet.getTargetValue(i)));
        }
        return vectors;
    }

    /**
     * Reports that weighted data is supported.
     *
     * <p>NOTE(review): {@link #train(RegressionDataSet, boolean)} only forwards
     * feature vectors and target values to the estimator; no per-sample weight
     * is read in this class. Confirm that returning {@code true} is intended.
     *
     * @return {@code true}
     */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Creates a copy of this regressor backed by a clone of its estimator.
     *
     * @return a new {@code NadarayaWatson} wrapping {@code kde.clone()}
     */
    @Override
    public NadarayaWatson clone() {
        return new NadarayaWatson(this.kde.clone());
    }
}

