/*
 * Decompiled with CFR 0.152.
 */
package jsat.distributions.kernels;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.DistanceMetricBasedKernel;
import jsat.distributions.kernels.KernelTrick;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.utils.DoubleList;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/**
 * A Radial Basis Function (RBF) kernel that can use an arbitrary
 * {@link DistanceMetric} instead of only the Euclidean distance. For inputs
 * <i>a</i> and <i>b</i> it evaluates
 * <code>K(a, b) = exp(-d(a, b)<sup>2</sup> / (2&sigma;<sup>2</sup>))</code>,
 * where <i>d</i> is the metric supplied at construction time.
 */
public class GeneralRBFKernel
extends DistanceMetricBasedKernel {
    private static final long serialVersionUID = 1368225926995372017L;

    /**
     * Number of data points {@link #guessSigma(DataSet, DistanceMetric)} uses
     * in full; larger data sets are sub-sampled to bound the quadratic cost of
     * computing pairwise distances.
     */
    private static final int GUESS_SAMPLE_LIMIT = 5000;

    /** The kernel width; always positive and finite (enforced by {@link #setSigma(double)}). */
    private double sigma;
    /** Cached {@code 1 / (2 * sigma^2)}, kept in sync with {@link #sigma} by {@link #setSigma(double)}. */
    private double sigmaSqrd2Inv;

    /**
     * Creates a new RBF kernel.
     *
     * @param d the distance metric used to compare inputs
     * @param sigma the kernel width; must be positive and finite
     * @throws IllegalArgumentException if {@code sigma} is not positive and finite
     */
    public GeneralRBFKernel(DistanceMetric d, double sigma) {
        super(d);
        this.setSigma(sigma);
    }

    /**
     * Sets the kernel width.
     *
     * @param sigma the new width; must be positive and finite
     * @throws IllegalArgumentException if {@code sigma} is zero, negative, NaN or infinite
     */
    public void setSigma(double sigma) {
        // The explicit NaN test matters: NaN <= 0.0 is false, so NaN would otherwise pass.
        if (sigma <= 0.0 || Double.isNaN(sigma) || Double.isInfinite(sigma)) {
            throw new IllegalArgumentException("Sigma must be a positive constant, not " + sigma);
        }
        this.sigma = sigma;
        // Precomputed so every kernel evaluation is a multiplication rather than a division.
        this.sigmaSqrd2Inv = 0.5 / (sigma * sigma);
    }

    /**
     * Returns the kernel width.
     *
     * @return the current value of sigma
     */
    public double getSigma() {
        return this.sigma;
    }

    /**
     * Returns an independent copy of this kernel, with a cloned distance
     * metric and the same sigma.
     */
    @Override
    public KernelTrick clone() {
        return new GeneralRBFKernel(this.d.clone(), this.sigma);
    }

    /**
     * Evaluates the kernel on two vectors.
     *
     * @param a the first vector
     * @param b the second vector
     * @return {@code exp(-d(a, b)^2 / (2 sigma^2))}
     */
    @Override
    public double eval(Vec a, Vec b) {
        return this.rbfOfDistance(this.d.dist(a, b));
    }

    /**
     * Evaluates the kernel between the vector at index {@code a} of
     * {@code vecs} and the vector {@code b}, delegating the distance
     * computation to the metric's cache-accelerated overload.
     * NOTE(review): {@code qi} is presumably the metric's query information
     * for {@code b} — confirm against {@link DistanceMetric}.
     *
     * @param a index into {@code vecs}
     * @param b the other vector
     * @param qi query information passed through to the metric
     * @param vecs the indexed vectors
     * @param cache the acceleration cache for {@code vecs}
     * @return {@code exp(-d(vecs[a], b)^2 / (2 sigma^2))}
     */
    @Override
    public double eval(int a, Vec b, List<Double> qi, List<? extends Vec> vecs, List<Double> cache) {
        return this.rbfOfDistance(this.d.dist(a, b, qi, vecs, cache));
    }

    /**
     * Evaluates the kernel between two indexed vectors of {@code vecs},
     * delegating the distance computation to the metric's cache-accelerated
     * overload.
     *
     * @param a index of the first vector in {@code vecs}
     * @param b index of the second vector in {@code vecs}
     * @param vecs the indexed vectors
     * @param cache the acceleration cache for {@code vecs}
     * @return {@code exp(-d(vecs[a], vecs[b])^2 / (2 sigma^2))}
     */
    @Override
    public double eval(int a, int b, List<? extends Vec> vecs, List<Double> cache) {
        return this.rbfOfDistance(this.d.dist(a, b, vecs, cache));
    }

    /** Maps a distance to the RBF kernel value using the cached width term. */
    private double rbfOfDistance(double dist) {
        return Math.exp(-dist * dist * this.sigmaSqrd2Inv);
    }

    /**
     * Guesses a distribution of reasonable sigma values for the given data
     * using this kernel's distance metric.
     *
     * @param d the data set to inspect
     * @return see {@link #guessSigma(DataSet, DistanceMetric)}
     * @throws IllegalArgumentException if no pair of points could be sampled
     */
    public Distribution guessSigma(DataSet d) {
        return GeneralRBFKernel.guessSigma(d, this.d);
    }

    /**
     * Guesses a distribution of reasonable sigma values from the median of
     * sampled pairwise distances.
     * <p>
     * If {@code d} is a {@link ClassificationDataSet} with exactly two
     * categories, only distances between points of <i>different</i> classes
     * are sampled, using up to {@code toSample / 2} random points per class.
     * Otherwise, or if that produces no pairs (e.g. one class is empty),
     * distances between all pairs of a random subset of the data are used.
     * With {@code n > 5000} points the sample size is
     * {@code 5000 + floor(sqrt(n - 5000))}; otherwise it is {@code n}.
     *
     * @param d the data set to inspect
     * @param dist the distance metric used to measure pairs
     * @return a {@link LogUniform} distribution over
     *         {@code [median * e^-4, median * e^4]}, where {@code median} is
     *         the middle element (upper middle for an even count) of the
     *         sorted sampled distances
     * @throws IllegalArgumentException if no pair of points could be sampled
     *         (fewer than two data points)
     */
    public static Distribution guessSigma(DataSet d, DistanceMetric dist) {
        int n = d.getSampleSize();
        int toSample = n;
        if (n > GUESS_SAMPLE_LIMIT) {
            toSample = GUESS_SAMPLE_LIMIT + (int)Math.floor(Math.sqrt(n - GUESS_SAMPLE_LIMIT));
        }

        DoubleList vals = null;
        if (d instanceof ClassificationDataSet && ((ClassificationDataSet)d).getPredicting().getNumOfCategories() == 2) {
            vals = sampleBetweenClassDistances((ClassificationDataSet)d, dist, toSample);
        }
        if (vals == null || vals.isEmpty()) {
            vals = sampleAllPairDistances(d.getDataVectors(), dist, toSample);
        }
        if (vals.isEmpty()) {
            throw new IllegalArgumentException("At least two data points are needed to guess sigma, got " + n);
        }

        Collections.sort(vals);
        double median = vals.get(vals.size() / 2);
        // NOTE(review): if every sampled distance is 0 both bounds collapse to 0;
        // how LogUniform handles that is not visible from this file.
        return new LogUniform(Math.exp(Math.log(median) - 4.0), Math.exp(Math.log(median) + 4.0));
    }

    /**
     * Samples distances between randomly chosen class-0 and class-1 points,
     * taking at most {@code toSample / 2} points from each class.
     *
     * @return every class-0 to class-1 distance among the chosen points
     *         (empty if either class contributed no points)
     */
    private static DoubleList sampleBetweenClassDistances(ClassificationDataSet cdata, DistanceMetric dist, int toSample) {
        int perClass = toSample / 2;
        int n = cdata.getSampleSize();
        List<Vec> class0 = new ArrayList<Vec>(perClass);
        List<Vec> class1 = new ArrayList<Vec>(perClass);

        IntList randOrder = new IntList(n);
        ListUtils.addRange(randOrder, 0, n, 1);
        Collections.shuffle(randOrder);
        for (int i = 0; i < randOrder.size(); ++i) {
            if (class0.size() >= perClass && class1.size() >= perClass) {
                break; // both quotas met; the remaining points would be ignored anyway
            }
            int indx = randOrder.getI(i);
            int category = cdata.getDataPointCategory(indx);
            // Each class is capped by its OWN quota.
            if (category == 0 && class0.size() < perClass) {
                class0.add(cdata.getDataPoint(indx).getNumericalValues());
            } else if (category == 1 && class1.size() < perClass) {
                class1.add(cdata.getDataPoint(indx).getNumericalValues());
            }
        }

        // Indices [0, class0Size) hold class 0 and [class0Size, combined.size()) hold class 1.
        int class0Size = class0.size();
        List<Vec> combined = new ArrayList<Vec>(class0Size + class1.size());
        combined.addAll(class0);
        combined.addAll(class1);
        List<Double> cache = dist.getAccelerationCache(combined);

        DoubleList vals = new DoubleList(toCapacity((long)class0Size * class1.size()));
        for (int i = 0; i < class0Size; ++i) {
            for (int j = class0Size; j < combined.size(); ++j) {
                // The indices and the cache both refer to 'combined', so measure against it.
                vals.add(dist.dist(i, j, combined, cache));
            }
        }
        return vals;
    }

    /**
     * Samples all pairwise distances among a random subset of at most
     * {@code toSample} of the given vectors.
     *
     * @return the distance of every unordered pair in the subset
     */
    private static DoubleList sampleAllPairDistances(List<Vec> dataVectors, DistanceMetric dist, int toSample) {
        // Shuffle a private copy so the caller's list is never reordered in place.
        List<Vec> vecs = new ArrayList<Vec>(dataVectors);
        Collections.shuffle(vecs);
        if (vecs.size() > toSample) {
            vecs = vecs.subList(0, toSample);
        }
        int m = vecs.size();
        List<Double> cache = dist.getAccelerationCache(vecs);

        DoubleList vals = new DoubleList(toCapacity((long)m * (m - 1) / 2));
        for (int i = 0; i < m; ++i) {
            for (int j = i + 1; j < m; ++j) {
                vals.add(dist.dist(i, j, vecs, cache));
            }
        }
        return vals;
    }

    /** Converts an exact element count into a safe initial list capacity (at least 1, at most Integer.MAX_VALUE). */
    private static int toCapacity(long count) {
        return (int)Math.max(1L, Math.min(count, (long)Integer.MAX_VALUE));
    }

    /**
     * Reports whether this kernel is normalized.
     * NOTE(review): this presumes {@code d(x, x) = 0} for the metric in use,
     * which gives {@code K(x, x) = exp(0) = 1}; confirm for custom metrics.
     *
     * @return {@code true}
     */
    @Override
    public boolean normalized() {
        return true;
    }
}

