/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.ClustererBase;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollection;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionUtils;
import jsat.math.OnLineStatistics;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Density-based clustering (DBSCAN).
 * <p>
 * A point is a <i>core</i> point when its eps-neighborhood, as reported by
 * {@link VectorCollectionUtils#allEpsNeighbors}, holds at least {@code minPts} entries.
 * Clusters are grown outward from core points. A non-core point reached from a core point
 * joins that core point's cluster. A point reached by no core point keeps the designation
 * {@code -2} (noise). Cluster ids in the returned designation arrays are assigned
 * consecutively starting at 0.
 * <p>
 * The overloads that take no {@code eps} estimate it from the data. The estimate is the mean
 * plus {@code stndDevs} (initialized to 2.0) standard deviations of each point's distance to
 * the last of its {@code minPts + 1} nearest neighbors.
 */
public class DBSCAN
extends ClustererBase {
    private static final long serialVersionUID = 1627963360642560455L;
    /** Designation of a point that has not been visited yet. */
    private static final int UNCLASSIFIED = -1;
    /** Designation of a point that belongs to no cluster. */
    private static final int NOISE = -2;
    /** {@code minPts} used by {@link #cluster(DataSet, boolean, int[])}. */
    private static final int DEFAULT_MIN_PTS = 3;
    /** Neighbor index; each entry pairs a data vector with its row index in the data set. */
    private VectorCollection<VecPaired<Vec, Integer>> vc;
    /** Metric used to build {@link #vc}; handed to {@code TrainableDistanceMetric.trainIfNeeded} first. */
    private DistanceMetric dm;
    /** Number of standard deviations above the mean used when estimating eps. */
    private double stndDevs = 2.0;

    /**
     * Creates a DBSCAN clusterer.
     *
     * @param dm the distance metric used for all neighbor queries
     * @param vc the collection used to index the data; it is rebuilt by every clustering call
     */
    public DBSCAN(DistanceMetric dm, VectorCollection<VecPaired<Vec, Integer>> vc) {
        this.dm = dm;
        this.vc = vc;
    }

    /** Creates a DBSCAN clusterer using the Euclidean distance and a default vector collection. */
    public DBSCAN() {
        this(new EuclideanDistance());
    }

    /**
     * Creates a DBSCAN clusterer using a default vector collection.
     *
     * @param dm the distance metric used for all neighbor queries
     */
    public DBSCAN(DistanceMetric dm) {
        this(dm, new DefaultVectorCollection<VecPaired<Vec, Integer>>());
    }

    /**
     * Copy constructor; the vector collection and the metric are cloned.
     *
     * @param toCopy the clusterer to copy
     */
    public DBSCAN(DBSCAN toCopy) {
        this.vc = toCopy.vc.clone();
        this.dm = toCopy.dm.clone();
        this.stndDevs = toCopy.stndDevs;
    }

    /**
     * Clusters the data set serially, estimating eps from the data.
     *
     * @param dataSet the data to cluster
     * @param minPts  minimum neighborhood size for a point to be a core point
     * @return the clusters built from the designations by {@code createClusterListFromAssignmentArray}
     */
    public List<List<DataPoint>> cluster(DataSet dataSet, int minPts) {
        return DBSCAN.createClusterListFromAssignmentArray(this.cluster(dataSet, minPts, (int[])null), dataSet);
    }

    /**
     * Clusters the data set serially, estimating eps from the data.
     *
     * @param dataSet      the data to cluster
     * @param minPts       minimum neighborhood size for a point to be a core point
     * @param designations array to fill, or {@code null} to allocate one
     * @return per-point cluster ids ({@code >= 0}) or {@code -2} for noise
     */
    public int[] cluster(DataSet dataSet, int minPts, int[] designations) {
        return this.cluster(dataSet, minPts, false, designations);
    }

    /**
     * Clusters the data set with {@code minPts = 3}, estimating eps from the data.
     */
    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        return this.cluster(dataSet, DEFAULT_MIN_PTS, parallel, designations);
    }

    @Override
    public DBSCAN clone() {
        return new DBSCAN(this);
    }

    /**
     * Clusters the data set, estimating eps from the data.
     *
     * @param dataSet  the data to cluster
     * @param minPts   minimum neighborhood size for a point to be a core point
     * @param parallel whether to use multiple threads
     * @return the clusters built from the designations by {@code createClusterListFromAssignmentArray}
     */
    public List<List<DataPoint>> cluster(DataSet dataSet, int minPts, boolean parallel) {
        return DBSCAN.createClusterListFromAssignmentArray(this.cluster(dataSet, minPts, parallel, null), dataSet);
    }

    /**
     * Clusters the data set, estimating eps from the data.
     *
     * @param dataSet      the data to cluster
     * @param minPts       minimum neighborhood size for a point to be a core point
     * @param parallel     whether to use multiple threads
     * @param designations array to fill, or {@code null} to allocate one; if non-null its
     *                     length must equal the number of points
     * @return per-point cluster ids ({@code >= 0}) or {@code -2} for noise
     * @throws IllegalArgumentException if {@code designations} has the wrong length
     */
    public int[] cluster(DataSet dataSet, int minPts, boolean parallel, int[] designations) {
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, parallel);
        this.vc.build(parallel, this.getVecIndexPairs(dataSet), this.dm);
        // minPts + 1 neighbors are requested and the last one is used (presumably because the
        // query point itself comes back first at distance 0 — TODO confirm). A k-NN search
        // cannot return more neighbors than there are points, so k is clamped. Without the
        // clamp, data sets with <= minPts points failed with an IndexOutOfBoundsException.
        final int k = Math.min(minPts + 1, dataSet.getSampleSize());
        OnLineStatistics stats = ParallelUtils.run(parallel, dataSet.getSampleSize(), (start, end) -> {
            OnLineStatistics s = new OnLineStatistics();
            for (int i = start; i < end; ++i) {
                Vec query = dataSet.getDataPoint(i).getNumericalValues();
                s.add(this.vc.search(query, k).get(k - 1).getPair());
            }
            return s;
        }, (t, u) -> t.apply(t, u));
        double eps = stats.getMean() + stats.getStandardDeviation() * this.stndDevs;
        // The index was built just above with the same data and metric, so it is not rebuilt.
        return this.clusterWithBuiltIndex(dataSet, eps, minPts, parallel, designations);
    }

    /**
     * Pairs every data vector with its row index, so neighbor query results can be mapped
     * back to rows of the data set.
     */
    private List<VecPaired<Vec, Integer>> getVecIndexPairs(DataSet dataSet) {
        ArrayList<VecPaired<Vec, Integer>> vecs = new ArrayList<VecPaired<Vec, Integer>>(dataSet.getSampleSize());
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            vecs.add(new VecPaired<Vec, Integer>(dataSet.getDataPoint(i).getNumericalValues(), i));
        }
        return vecs;
    }

    /**
     * Clusters the data set serially with a fixed eps.
     *
     * @param dataSet the data to cluster
     * @param eps     neighborhood radius
     * @param minPts  minimum neighborhood size for a point to be a core point
     * @return the clusters built from the designations by {@code createClusterListFromAssignmentArray}
     */
    public List<List<DataPoint>> cluster(DataSet dataSet, double eps, int minPts) {
        return DBSCAN.createClusterListFromAssignmentArray(this.cluster(dataSet, eps, minPts, (int[])null), dataSet);
    }

    /**
     * Clusters the data set serially with a fixed eps.
     *
     * @param dataSet      the data to cluster
     * @param eps          neighborhood radius
     * @param minPts       minimum neighborhood size for a point to be a core point
     * @param designations array to fill, or {@code null} to allocate one
     * @return per-point cluster ids ({@code >= 0}) or {@code -2} for noise
     * @throws IllegalArgumentException if {@code designations} has the wrong length
     */
    public int[] cluster(DataSet dataSet, double eps, int minPts, int[] designations) {
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet);
        this.vc.build(false, this.getVecIndexPairs(dataSet), this.dm);
        return this.clusterWithBuiltIndex(dataSet, eps, minPts, false, designations);
    }

    /**
     * Clusters the data set with a fixed eps.
     *
     * @param dataSet  the data to cluster
     * @param eps      neighborhood radius
     * @param minPts   minimum neighborhood size for a point to be a core point
     * @param parallel whether to use multiple threads
     * @return the clusters built from the designations by {@code createClusterListFromAssignmentArray}
     */
    public List<List<DataPoint>> cluster(DataSet dataSet, double eps, int minPts, boolean parallel) {
        return DBSCAN.createClusterListFromAssignmentArray(this.cluster(dataSet, eps, minPts, parallel, null), dataSet);
    }

    /**
     * Clusters the data set with a fixed eps.
     *
     * @param dataSet      the data to cluster
     * @param eps          neighborhood radius
     * @param minPts       minimum neighborhood size for a point to be a core point
     * @param parallel     whether to use multiple threads
     * @param designations array to fill, or {@code null} to allocate one
     * @return per-point cluster ids ({@code >= 0}) or {@code -2} for noise
     * @throws IllegalArgumentException if {@code designations} has the wrong length
     */
    public int[] cluster(DataSet dataSet, double eps, int minPts, boolean parallel, int[] designations) {
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, parallel);
        this.vc.build(parallel, this.getVecIndexPairs(dataSet), this.dm);
        return this.clusterWithBuiltIndex(dataSet, eps, minPts, parallel, designations);
    }

    /**
     * Runs DBSCAN against {@link #vc}, which the caller must already have built from
     * {@code dataSet}.
     *
     * @param pointCats array to fill, or {@code null} to allocate one
     * @return {@code pointCats} (or the new array) holding cluster ids or {@code NOISE}
     * @throws IllegalArgumentException if {@code pointCats} length differs from the data set size
     */
    private int[] clusterWithBuiltIndex(DataSet dataSet, double eps, int minPts, boolean parallel, int[] pointCats) {
        final int n = dataSet.getSampleSize();
        if (pointCats == null) {
            pointCats = new int[n];
        } else if (pointCats.length != n) {
            throw new IllegalArgumentException("designations has length " + pointCats.length
                    + " but the data set has " + n + " points");
        }
        Arrays.fill(pointCats, UNCLASSIFIED);
        // Precompute every point's eps-neighborhood once; cluster expansion only reads these lists.
        List<List<? extends VecPaired<VecPaired<Vec, Integer>, Double>>> allNearestNeighbor = VectorCollectionUtils.allEpsNeighbors(this.vc, dataSet.getDataVectors(), eps, parallel);
        int curClusterID = 0;
        for (int i = 0; i < n; ++i) {
            if (pointCats[i] == UNCLASSIFIED && this.expandCluster(pointCats, i, curClusterID, minPts, allNearestNeighbor)) {
                ++curClusterID;
            }
        }
        return pointCats;
    }

    /**
     * Tries to start cluster {@code clId} at {@code point} and grows it through every point
     * density-reachable from it.
     *
     * @return {@code true} if {@code point} is a core point and the cluster was created;
     *         {@code false} if {@code point} was labelled {@code NOISE} instead
     */
    private boolean expandCluster(int[] pointCats, int point, int clId, int minPts, List<List<? extends VecPaired<VecPaired<Vec, Integer>, Double>>> allNearNeighbors) {
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> seeds = allNearNeighbors.get(point);
        if (seeds.size() < minPts) {
            // Not a core point. It may still be claimed later as a border point of some cluster.
            pointCats[point] = NOISE;
            return false;
        }
        pointCats[point] = clId;
        // Queue of row indices of newly claimed points whose neighborhoods still need to be examined.
        ArrayDeque<Integer> workQue = new ArrayDeque<Integer>();
        // Claim the seeds up front, so labelling them does not depend on 'point' appearing
        // in its own neighborhood list.
        claimNeighbors(seeds, pointCats, clId, workQue);
        while (!workQue.isEmpty()) {
            int current = workQue.poll();
            List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> results = allNearNeighbors.get(current);
            if (results.size() < minPts) {
                continue; // border point: it belongs to the cluster but does not extend it
            }
            claimNeighbors(results, pointCats, clId, workQue);
        }
        return true;
    }

    /**
     * Assigns every unclaimed neighbor to cluster {@code clId}. Only previously
     * {@code UNCLASSIFIED} neighbors are queued for expansion. A {@code NOISE} neighbor was
     * already found to have too small a neighborhood, so expanding it would add nothing.
     * Neighbors already in a cluster are left untouched.
     */
    private static void claimNeighbors(List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> neighbors, int[] pointCats, int clId, ArrayDeque<Integer> workQue) {
        for (VecPaired<VecPaired<Vec, Integer>, Double> neighbor : neighbors) {
            int idx = neighbor.getVector().getPair();
            if (pointCats[idx] == UNCLASSIFIED) {
                pointCats[idx] = clId;
                workQue.add(idx);
            } else if (pointCats[idx] == NOISE) {
                pointCats[idx] = clId;
            }
        }
    }
}

