/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import jsat.DataSet;
import jsat.linear.MatrixStatistics;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.utils.IndexTable;
import jsat.utils.IntSet;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Static helpers that choose initial cluster seeds from a {@link DataSet}.
 * Every entry point selects data-point indices; the {@code List<Vec>} overloads
 * additionally return copies of the selected points' numerical vectors.
 */
public final class SeedSelectionMethods {
    private SeedSelectionMethods() {
    }

    /**
     * Selects {@code k} seeds (sequentially, no acceleration cache) and returns
     * copies of their numerical vectors.
     *
     * @param d the data set to draw seeds from
     * @param k the number of seeds to select
     * @param dm the distance metric used by the distance-based strategies
     * @param rand the source of randomness
     * @param selectionMethod the seeding strategy to use
     * @return a new list holding a copy of each selected point's numerical vector
     */
    public static List<Vec> selectIntialPoints(DataSet d, int k, DistanceMetric dm, Random rand, SeedSelection selectionMethod) {
        return selectIntialPoints(d, k, dm, null, rand, selectionMethod, false);
    }

    /**
     * Selects {@code k} seeds sequentially and returns copies of their numerical vectors.
     *
     * @param d the data set to draw seeds from
     * @param k the number of seeds to select
     * @param dm the distance metric used by the distance-based strategies
     * @param accelCache acceleration cache forwarded to {@code dm}; may be {@code null}
     * @param rand the source of randomness
     * @param selectionMethod the seeding strategy to use
     * @return a new list holding a copy of each selected point's numerical vector
     */
    public static List<Vec> selectIntialPoints(DataSet d, int k, DistanceMetric dm, List<Double> accelCache, Random rand, SeedSelection selectionMethod) {
        return selectIntialPoints(d, k, dm, accelCache, rand, selectionMethod, false);
    }

    /**
     * Selects {@code k} seeds (no acceleration cache) and returns copies of their
     * numerical vectors.
     *
     * @param d the data set to draw seeds from
     * @param k the number of seeds to select
     * @param dm the distance metric used by the distance-based strategies
     * @param rand the source of randomness
     * @param selectionMethod the seeding strategy to use
     * @param parallel forwarded to {@code ParallelUtils.run} for the distance computations
     * @return a new list holding a copy of each selected point's numerical vector
     */
    public static List<Vec> selectIntialPoints(DataSet d, int k, DistanceMetric dm, Random rand, SeedSelection selectionMethod, boolean parallel) {
        return selectIntialPoints(d, k, dm, null, rand, selectionMethod, parallel);
    }

    /**
     * Selects {@code k} seeds and returns copies of their numerical vectors.
     *
     * @param d the data set to draw seeds from
     * @param k the number of seeds to select
     * @param dm the distance metric used by the distance-based strategies
     * @param accelCache acceleration cache forwarded to {@code dm}; may be {@code null}
     * @param rand the source of randomness
     * @param selectionMethod the seeding strategy to use
     * @param parallel forwarded to {@code ParallelUtils.run} for the distance computations
     * @return a new list holding a copy of each selected point's numerical vector
     * @throws IllegalArgumentException if {@code selectionMethod} is RANDOM or KPP and
     *         {@code k} exceeds the number of data points
     */
    public static List<Vec> selectIntialPoints(DataSet d, int k, DistanceMetric dm, List<Double> accelCache, Random rand, SeedSelection selectionMethod, boolean parallel) {
        int[] indices = new int[k];
        selectIntialPoints(d, indices, dm, accelCache, rand, selectionMethod, parallel);
        return copyVectors(d, indices);
    }

    /**
     * Fills {@code indices} with seed indices (sequentially, no acceleration cache).
     *
     * @param d the data set to draw seeds from
     * @param indices output array; its length is the number of seeds to select
     * @param dm the distance metric used by the distance-based strategies
     * @param rand the source of randomness
     * @param selectionMethod the seeding strategy to use
     */
    public static void selectIntialPoints(DataSet d, int[] indices, DistanceMetric dm, Random rand, SeedSelection selectionMethod) {
        selectIntialPoints(d, indices, dm, null, rand, selectionMethod, false);
    }

    /**
     * Fills {@code indices} with seed indices, computing distances sequentially.
     *
     * @param d the data set to draw seeds from
     * @param indices output array; its length is the number of seeds to select
     * @param dm the distance metric used by the distance-based strategies
     * @param accelCache acceleration cache forwarded to {@code dm}; may be {@code null}
     * @param rand the source of randomness
     * @param selectionMethod the seeding strategy to use
     */
    public static void selectIntialPoints(DataSet d, int[] indices, DistanceMetric dm, List<Double> accelCache, Random rand, SeedSelection selectionMethod) {
        selectIntialPoints(d, indices, dm, accelCache, rand, selectionMethod, false);
    }

    /**
     * Fills {@code indices} with seed indices (no acceleration cache).
     *
     * @param d the data set to draw seeds from
     * @param indices output array; its length is the number of seeds to select
     * @param dm the distance metric used by the distance-based strategies
     * @param rand the source of randomness
     * @param selectionMethod the seeding strategy to use
     * @param parallel forwarded to {@code ParallelUtils.run} for the distance computations
     */
    public static void selectIntialPoints(DataSet d, int[] indices, DistanceMetric dm, Random rand, SeedSelection selectionMethod, boolean parallel) {
        selectIntialPoints(d, indices, dm, null, rand, selectionMethod, parallel);
    }

    /**
     * Fills {@code indices} with the indices of {@code indices.length} seed points
     * chosen by {@code selectionMethod}. If {@code selectionMethod} is {@code null},
     * or {@code indices} is empty, the array is left unmodified.
     *
     * @param d the data set to draw seeds from
     * @param indices output array; its length is the number of seeds to select
     * @param dm the distance metric used by the distance-based strategies
     * @param accelCache acceleration cache forwarded to {@code dm}; may be {@code null}
     * @param rand the source of randomness
     * @param selectionMethod the seeding strategy to use
     * @param parallel forwarded to {@code ParallelUtils.run} for the distance computations
     * @throws IllegalArgumentException if {@code selectionMethod} is RANDOM or KPP and
     *         {@code indices.length} exceeds the number of data points (distinct seeds
     *         could never be found and the random draw would not terminate)
     */
    public static void selectIntialPoints(DataSet d, int[] indices, DistanceMetric dm, List<Double> accelCache, Random rand, SeedSelection selectionMethod, boolean parallel) {
        final int k = indices.length;
        if (selectionMethod == null || k == 0) {
            // KPP / FARTHEST_FIRST would otherwise fail writing indices[0] when k == 0.
            return;
        }
        final int n = d.getSampleSize();
        switch (selectionMethod) {
            case RANDOM:
                requireDistinctSeedsPossible(k, n, selectionMethod);
                fillRemainingRandomly(indices, 0, k, n, rand);
                break;
            case KPP:
                requireDistinctSeedsPossible(k, n, selectionMethod);
                kppSelection(indices, rand, d, k, dm, accelCache, parallel);
                break;
            case FARTHEST_FIRST:
                ffSelection(indices, rand, d, k, dm, accelCache, parallel);
                break;
            case MEAN_QUANTILES:
                mqSelection(indices, d, k, dm, accelCache, parallel);
                break;
        }
    }

    /** Copies the numerical vector of every data point named in {@code indices}, in order. */
    private static List<Vec> copyVectors(DataSet d, int[] indices) {
        List<Vec> vecs = new ArrayList<>(indices.length);
        for (int idx : indices) {
            vecs.add(d.getDataPoint(idx).getNumericalValues().clone());
        }
        return vecs;
    }

    /** Rejects requests for more distinct seeds than there are data points. */
    private static void requireDistinctSeedsPossible(int k, int n, SeedSelection method) {
        if (k > n) {
            throw new IllegalArgumentException(method + " seed selection needs k <= number of data points, but k=" + k + " and n=" + n);
        }
    }

    /**
     * Keeps {@code indices[0..chosen)} and fills the array up to {@code k} distinct
     * indices drawn uniformly from {@code [0, n)}. Positions may be reordered, since
     * the result is taken from an {@link IntSet}'s iteration order.
     * Requires {@code k <= n}, otherwise the loop cannot terminate.
     */
    private static void fillRemainingRandomly(int[] indices, int chosen, int k, int n, Random rand) {
        IntSet distinct = new IntSet(k);
        for (int i = 0; i < chosen; ++i) {
            distinct.add(indices[i]);
        }
        while (distinct.size() < k) {
            distinct.add(rand.nextInt(n));
        }
        int pos = 0;
        for (int idx : distinct) {
            indices[pos++] = idx;
        }
    }

    /**
     * k-means++ seeding. The first seed is drawn uniformly at random. Each further seed
     * is drawn with probability proportional to its squared distance to the nearest seed
     * chosen so far. If that total squared distance falls to 1e-6 or below, the remaining
     * seeds are filled with distinct random indices instead.
     */
    private static void kppSelection(int[] indices, Random rand, DataSet d, int k, DistanceMetric dm, List<Double> accelCache, boolean parallel) {
        final int n = d.getSampleSize();
        indices[0] = rand.nextInt(n);
        // closestDist[i] = squared distance from point i to its nearest chosen seed.
        final double[] closestDist = new double[n];
        final List<Vec> vecs = d.getDataVectors();
        for (int j = 1; j < k; ++j) {
            // Only the most recently added seed can lower any point's nearest distance.
            final int newMeanIndx = indices[j - 1];
            // closestDist starts at 0, so the first pass must overwrite unconditionally.
            final boolean forceCompute = j == 1;
            // The reduction yields the TOTAL of the updated squared distances (the sampling
            // normaliser), not just this iteration's change to it.
            double sqrdDistSum = ParallelUtils.run(parallel, n, (start, end) -> {
                double partialSum = 0.0;
                for (int i = start; i < end; ++i) {
                    double newDist = dm.dist(newMeanIndx, i, vecs, accelCache);
                    newDist *= newDist;
                    if (forceCompute || newDist < closestDist[i]) {
                        closestDist[i] = newDist;
                    }
                    partialSum += closestDist[i];
                }
                return partialSum;
            }, (t, u) -> t + u);
            if (sqrdDistSum <= 1.0E-6) {
                // Every point sits (almost) on a seed; D^2 sampling is meaningless now.
                fillRemainingRandomly(indices, j, k, n, rand);
                return;
            }
            // Inverse-CDF sampling over the cumulative squared distances.
            final double rndX = rand.nextDouble() * sqrdDistSum;
            int pick = 0;
            double searchSum = closestDist[0];
            // '<=' stops on the first point whose cumulative weight exceeds rndX, so a
            // zero-weight point (e.g. an existing seed) is not picked when rndX == 0.
            while (searchSum <= rndX && pick < n - 1) {
                searchSum += closestDist[++pick];
            }
            indices[j] = pick;
        }
    }

    /**
     * Farthest-first traversal. The first seed is drawn uniformly at random. Each further
     * seed is the point whose distance to its nearest already-chosen seed is largest.
     * Ties are broken toward the smallest index, so the result does not depend on how
     * work is split across threads.
     */
    private static void ffSelection(int[] indices, Random rand, DataSet d, int k, DistanceMetric dm, List<Double> accelCache, boolean parallel) {
        final int n = d.getSampleSize();
        indices[0] = rand.nextInt(n);
        // closestDist[i] = distance from point i to its nearest chosen seed so far.
        final double[] closestDist = new double[n];
        Arrays.fill(closestDist, Double.POSITIVE_INFINITY);
        final List<Vec> vecs = d.getDataVectors();
        for (int j = 1; j < k; ++j) {
            final int newMeanIndx = indices[j - 1];
            // Best (distance, index) across all chunks, guarded by locking bestDist. Chunks
            // compare their own local maxima instead of reading closestDist slots that
            // another thread may still be updating.
            final double[] bestDist = {Double.NEGATIVE_INFINITY};
            final int[] bestIndx = {0};
            ParallelUtils.run(parallel, n, (start, end) -> {
                double localMax = Double.NEGATIVE_INFINITY;
                int localIndx = -1;
                for (int i = start; i < end; ++i) {
                    double newDist = dm.dist(newMeanIndx, i, vecs, accelCache);
                    closestDist[i] = Math.min(newDist, closestDist[i]);
                    if (closestDist[i] > localMax) {
                        localMax = closestDist[i];
                        localIndx = i;
                    }
                }
                if (localIndx >= 0) {
                    synchronized (bestDist) {
                        if (localMax > bestDist[0] || (localMax == bestDist[0] && localIndx < bestIndx[0])) {
                            bestDist[0] = localMax;
                            bestIndx[0] = localIndx;
                        }
                    }
                }
            });
            // Read under the same lock the writers used, for guaranteed visibility.
            synchronized (bestDist) {
                indices[j] = bestIndx[0];
            }
        }
    }

    /**
     * Mean-quantile seeding. Computes each point's distance to the data set's mean
     * vector, ranks the points with an {@link IndexTable}, and takes the points at
     * ranks {@code floor(l * n / k)} for {@code l = 0..k-1}. This presumably means
     * evenly spaced quantiles of distance-to-mean, assuming IndexTable orders by
     * ascending value (TODO confirm).
     */
    private static void mqSelection(int[] indices, DataSet d, int k, DistanceMetric dm, List<Double> accelCache, boolean parallel) {
        final int n = d.getSampleSize();
        final double[] meanDist = new double[n];
        final Vec mean = MatrixStatistics.meanVector(d);
        final List<Double> meanQI = dm.getQueryInfo(mean);
        final List<Vec> vecs = d.getDataVectors();
        ParallelUtils.run(parallel, n, (start, end) -> {
            for (int i = start; i < end; ++i) {
                meanDist[i] = dm.dist(i, mean, meanQI, vecs, accelCache);
            }
        });
        IndexTable indxTbl = new IndexTable(meanDist);
        for (int l = 0; l < k; ++l) {
            // long arithmetic: l * n overflows int for large data sets.
            indices[l] = indxTbl.index((int) ((long) l * n / k));
        }
    }

    /** The available seed-selection strategies. */
    public enum SeedSelection {
        /** {@code k} distinct indices drawn uniformly at random. */
        RANDOM,
        /** k-means++: seeds drawn with probability proportional to squared distance to the nearest existing seed. */
        KPP,
        /** Farthest-first: each new seed is the point farthest from its nearest existing seed. */
        FARTHEST_FIRST,
        /** Points at evenly spaced ranks of distance to the data set's mean vector. */
        MEAN_QUANTILES;

    }
}

