/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear.vectorcollection;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.ChebyshevDistance;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.ManhattanDistance;
import jsat.linear.distancemetrics.MinkowskiDistance;
import jsat.linear.vectorcollection.IncrementalCollection;
import jsat.linear.vectorcollection.IndexDistPair;
import jsat.math.FastMath;
import jsat.math.OnLineStatistics;
import jsat.utils.BoundedSortedList;
import jsat.utils.DoubleList;
import jsat.utils.IndexTable;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.ModifiableCountDownLatch;
import jsat.utils.concurrent.ParallelUtils;

/**
 * A KD-tree over vectors supporting k-nearest-neighbor and radius queries, plus incremental
 * insertion.
 * <p>
 * Only the Euclidean, Chebyshev, Manhattan and Minkowski metrics are accepted. The search prunes
 * a subtree using the absolute difference along the node's split coordinate as a lower bound on
 * the full distance.
 *
 * @param <V> the vector type stored in the tree
 */
public class KDTree<V extends Vec>
implements IncrementalCollection<V> {
    private static final long serialVersionUID = -7401342201406776463L;
    private DistanceMetric distanceMetric;
    private KDNode root;
    private PivotSelection pvSelection;
    private int size;
    private int leaf_node_size = 20;
    private List<V> allVecs;
    private List<Double> distCache;

    /**
     * Builds a tree over {@code vecs}.
     *
     * @param vecs the vectors to index (copied into an internal list)
     * @param distanceMetric the metric to use; must be supported by KD-trees
     * @param pvSelection how split dimensions / values are chosen
     * @param parallel whether to build subtrees on a thread pool
     */
    public KDTree(List<V> vecs, DistanceMetric distanceMetric, PivotSelection pvSelection, boolean parallel) {
        this.distanceMetric = distanceMetric;
        this.pvSelection = pvSelection;
        this.build(parallel, vecs, distanceMetric);
    }

    /** Builds a tree serially over {@code vecs} with the given pivot selection. */
    public KDTree(List<V> vecs, DistanceMetric distanceMetric, PivotSelection pvSelection) {
        this(vecs, distanceMetric, pvSelection, false);
    }

    /** Builds a tree serially over {@code vecs} using {@link PivotSelection#SPREAD_MEDOID}. */
    public KDTree(List<V> vecs, DistanceMetric distanceMetric) {
        this(vecs, distanceMetric, PivotSelection.SPREAD_MEDOID);
    }

    private KDTree(DistanceMetric distanceMetric, PivotSelection pvSelection) {
        this.setDistanceMetric(distanceMetric);
        this.pvSelection = pvSelection;
    }

    /** Creates an empty Euclidean tree that is populated through {@link #insert} or {@link #build}. */
    public KDTree(PivotSelection pivotSelection) {
        this(new EuclideanDistance(), pivotSelection);
    }

    /** Creates an empty Euclidean tree using {@link PivotSelection#SPREAD_MEDOID}. */
    public KDTree() {
        this(PivotSelection.SPREAD_MEDOID);
    }

    /**
     * Sets the maximum number of points kept in a leaf when (re)building.
     *
     * @param leaf_size the new leaf size
     * @throws IllegalArgumentException if {@code leaf_size < 2}
     */
    public void setLeafSize(int leaf_size) {
        if (leaf_size < 2) {
            throw new IllegalArgumentException("The leaf size must be >= 2 to support all splitting methods");
        }
        this.leaf_node_size = leaf_size;
    }

    /** @return the maximum number of points per leaf used when building */
    public int getLeafSize() {
        return this.leaf_node_size;
    }

    /**
     * Sets the distance metric.
     *
     * @throws ArithmeticException if the metric is not Euclidean, Chebyshev, Manhattan or Minkowski
     */
    @Override
    public void setDistanceMetric(DistanceMetric dm) {
        if (!(dm instanceof EuclideanDistance || dm instanceof ChebyshevDistance || dm instanceof ManhattanDistance || dm instanceof MinkowskiDistance)) {
            throw new ArithmeticException("KD Trees are not compatible with the given distance metric.");
        }
        this.distanceMetric = dm;
    }

    @Override
    public DistanceMetric getDistanceMetric() {
        return this.distanceMetric;
    }

    /**
     * Replaces the tree's contents with a freshly built tree over {@code vecs}.
     * <p>
     * If {@code parallel} is true, right subtrees are built on {@link ParallelUtils#CACHED_THREAD_POOL}
     * and this method waits for all of them. If that wait is interrupted, the interrupt flag is
     * restored and the tree is rebuilt serially, so it is complete when this method returns.
     */
    @Override
    public void build(boolean parallel, List<V> vecs, DistanceMetric dm) {
        this.setDistanceMetric(dm);
        this.size = vecs.size();
        vecs = new ArrayList<V>(vecs);
        this.allVecs = vecs;
        this.distCache = this.distanceMetric.getAccelerationCache(vecs, parallel);
        IntList vecIndices = new IntList(this.size);
        ListUtils.addRange(vecIndices, 0, this.size, 1);
        if (!parallel) {
            this.root = this.buildTree(vecIndices, 0, null, null);
        } else {
            ModifiableCountDownLatch mcdl = new ModifiableCountDownLatch(1);
            this.root = this.buildTree(vecIndices, 0, ParallelUtils.CACHED_THREAD_POOL, mcdl);
            try {
                mcdl.await();
            }
            catch (InterruptedException ex) {
                // Preserve the interrupt for our caller, but still leave a usable tree behind.
                Thread.currentThread().interrupt();
                this.root = this.buildTree(vecIndices, 0, null, null);
            }
        }
    }

    /**
     * Adds {@code x} to the tree. A leaf that reaches twice the leaf size is rebuilt into a subtree.
     *
     * @param x the vector to add
     */
    @Override
    public void insert(V x) {
        if (this.allVecs == null) {
            this.allVecs = new ArrayList<V>();
            this.distCache = this.distanceMetric.getAccelerationCache(this.allVecs);
            this.size = 0;
        }
        if (this.root == null) {
            // Either never built, or built from an empty list: start from a single empty leaf.
            this.root = new KDLeaf(0, new IntList());
        }
        int indx = this.size++;
        this.allVecs.add(x);
        if (this.distCache != null) {
            this.distCache.addAll(this.distanceMetric.getQueryInfo(x));
        }
        if (this.root.insert(indx)) {
            this.root = this.buildTree(IntList.range(this.size), 0, null, null);
        }
    }

    /**
     * Recursively builds a (sub)tree over the vector indices in {@code data}. {@code data} is
     * reordered in place.
     * <p>
     * When {@code threadpool} is non-null, each call that produces a leaf or {@code null} counts
     * {@code mcdl} down once. Each internal node counts it up once before handing its right half
     * to the pool, so the latch reaches zero exactly when every subtree has finished.
     *
     * @return the subtree root, or {@code null} if {@code data} is null or empty
     */
    private KDNode buildTree(List<Integer> data, int depth, ExecutorService threadpool, ModifiableCountDownLatch mcdl) {
        if (data == null || data.isEmpty()) {
            if (threadpool != null) {
                mcdl.countDown();
            }
            return null;
        }
        int mod = this.allVecs.get(0).length();
        if (data.size() <= this.leaf_node_size) {
            if (threadpool != null) {
                mcdl.countDown();
            }
            return new KDLeaf(depth % mod, data);
        }
        int pivot = -1;
        double pivot_val = Double.NaN;
        switch (this.pvSelection) {
            case VARIANCE: {
                OnLineStatistics[] allStats = new OnLineStatistics[mod];
                for (int j = 0; j < mod; ++j) {
                    allStats[j] = new OnLineStatistics();
                }
                for (int i : data) {
                    V vec = this.get(i);
                    for (int j = 0; j < mod; ++j) {
                        allStats[j].add(vec.get(j));
                    }
                }
                double maxVariance = -1.0;
                for (int j = 0; j < mod; ++j) {
                    double variance = allStats[j].getVarance();
                    if (variance > maxVariance) {
                        maxVariance = variance;
                        pivot = j;
                    }
                }
                if (pivot < 0) {
                    // No variance compared greater than -1 (e.g. all NaN): cycle by depth instead.
                    pivot = depth % mod;
                }
                break;
            }
            case SPREAD_MEDOID: {
                double[] mins = new double[mod];
                double[] maxs = new double[mod];
                int[] visited = new int[mod];
                Arrays.fill(mins, Double.POSITIVE_INFINITY);
                Arrays.fill(maxs, Double.NEGATIVE_INFINITY);
                for (int i : data) {
                    for (IndexValue iv : this.get(i)) {
                        int d = iv.getIndex();
                        double val = iv.getValue();
                        mins[d] = Math.min(mins[d], val);
                        maxs[d] = Math.max(maxs[d], val);
                        visited[d]++;
                    }
                }
                // Iterating a Vec may skip zero entries (sparse storage). A dimension not visited
                // for every vector therefore also holds an implicit 0.
                for (int d = 0; d < mod; ++d) {
                    if (visited[d] < data.size()) {
                        mins[d] = Math.min(mins[d], 0.0);
                        maxs[d] = Math.max(maxs[d], 0.0);
                    }
                }
                int maxSpreadDim = 0;
                double maxSpreadVal = maxs[0] - mins[0];
                for (int d = 1; d < mod; ++d) {
                    double spread = maxs[d] - mins[d];
                    if (spread > maxSpreadVal) {
                        maxSpreadDim = d;
                        maxSpreadVal = spread;
                    }
                }
                pivot = maxSpreadDim;
                // Split at the value (among this node's points) closest to the centre of the widest range.
                double midPoint = (maxs[maxSpreadDim] + mins[maxSpreadDim]) / 2.0;
                double closestVal = maxs[maxSpreadDim];
                for (int i : data) {
                    double val = this.get(i).get(maxSpreadDim);
                    if (Math.abs(midPoint - val) < Math.abs(midPoint - closestVal)) {
                        closestVal = val;
                    }
                }
                pivot_val = closestVal;
                break;
            }
            default: {
                pivot = depth % mod;
            }
        }
        KDNode node = new KDNode(pivot);
        int splitIndex = -1;
        if (!Double.isNaN(pivot_val)) {
            // Partition in place: indices whose pivot coordinate is <= pivot_val move to the front.
            int front = 0;
            for (int i = 0; i < data.size(); ++i) {
                if (this.get(data.get(i)).get(pivot) <= pivot_val) {
                    ListUtils.swap(data, front++, i);
                }
            }
            // A split with an empty side would recurse on the same data forever, so reject it.
            // The unbalanced-split heuristic keeps the original operator grouping.
            boolean degenerate = front == 0 || front == data.size();
            boolean unbalanced = (FastMath.floor_log2(this.allVecs.size()) >= depth && front < this.leaf_node_size / 3)
                    || data.size() - front < this.leaf_node_size / 3;
            if (degenerate || unbalanced) {
                pivot_val = Double.NaN;
            } else {
                splitIndex = front - 1;
                node.pivot_s = pivot_val;
            }
        }
        if (Double.isNaN(pivot_val)) {
            Collections.sort(data, new VecIndexComparator(pivot));
            splitIndex = this.getMedianIndex(data, pivot);
            if (splitIndex == data.size() - 1) {
                // Every value from the median onward ties, so no non-empty right side exists.
                // This is still a leaf and must be counted down like the other leaf exits.
                if (threadpool != null) {
                    mcdl.countDown();
                }
                return new KDLeaf(depth % mod, data);
            }
            node.pivot_s = this.get(data.get(splitIndex)).get(pivot);
        }
        if (threadpool == null) {
            node.setLeft(this.buildTree(data.subList(0, splitIndex + 1), depth + 1, null, null));
            node.setRight(this.buildTree(data.subList(splitIndex + 1, data.size()), depth + 1, null, null));
        } else {
            mcdl.countUp();
            // Copies, so the asynchronous right build never shares a backing list with this thread.
            IntList data_l = new IntList(data.subList(0, splitIndex + 1));
            IntList data_r = new IntList(data.subList(splitIndex + 1, data.size()));
            threadpool.submit(() -> node.setRight(this.buildTree(data_r, depth + 1, threadpool, mcdl)));
            node.setLeft(this.buildTree(data_l, depth + 1, threadpool, mcdl));
        }
        return node;
    }

    /**
     * Returns a split index for {@code data}, which must already be sorted by coordinate {@code pivot}.
     * <p>
     * It starts at {@code data.size() / 2} and advances past entries equal to their successor in
     * that coordinate. All equal values therefore land on the left. The result is
     * {@code data.size() - 1} when every value from the median onward ties.
     */
    public int getMedianIndex(List<Integer> data, int pivot) {
        int medianIndex = data.size() / 2;
        while (medianIndex < data.size() - 1
                && this.allVecs.get(data.get(medianIndex)).get(pivot) == this.allVecs.get(data.get(medianIndex + 1)).get(pivot)) {
            ++medianIndex;
        }
        return medianIndex;
    }

    /**
     * Finds up to {@code numNeighbors} nearest neighbors of {@code query}.
     * <p>
     * {@code neighbors} and {@code distances} are cleared and then filled in ascending distance
     * order. An empty tree yields empty lists.
     *
     * @throws RuntimeException if {@code numNeighbors < 1}
     */
    @Override
    public void search(Vec query, int numNeighbors, List<Integer> neighbors, List<Double> distances) {
        if (numNeighbors < 1) {
            throw new RuntimeException("Invalid number of neighbors to search for");
        }
        neighbors.clear();
        distances.clear();
        if (this.root == null) {
            return;
        }
        BoundedSortedList<IndexDistPair> knns = new BoundedSortedList<IndexDistPair>(numNeighbors);
        this.root.searchK(numNeighbors, knns, query, this.distanceMetric.getQueryInfo(query));
        for (int i = 0; i < knns.size(); ++i) {
            IndexDistPair pm = (IndexDistPair)knns.get(i);
            neighbors.add(pm.getIndex());
            distances.add(pm.getDist());
        }
    }

    @Override
    public int size() {
        return this.size;
    }

    @Override
    public V get(int indx) {
        return this.allVecs.get(indx);
    }

    /**
     * Finds every stored vector within distance {@code range} (inclusive) of {@code query}.
     * <p>
     * Results are sorted by distance. An empty tree yields empty lists.
     *
     * @throws RuntimeException if {@code range <= 0}
     */
    @Override
    public void search(Vec query, double range, List<Integer> neighbors, List<Double> distances) {
        if (range <= 0.0) {
            throw new RuntimeException("Range must be a positive number");
        }
        neighbors.clear();
        distances.clear();
        if (this.root == null) {
            return;
        }
        List<Double> qi = this.distanceMetric.getQueryInfo(query);
        this.root.searchR(range, neighbors, distances, query, qi);
        IndexTable it = new IndexTable(distances);
        it.apply(neighbors);
        it.apply(distances);
    }

    /**
     * Returns an independent copy of this tree.
     * <p>
     * The vector list and distance cache are shallow-copied, and the node structure is re-created
     * so that it belongs to the clone.
     */
    @Override
    public KDTree<V> clone() {
        KDTree<V> clone = new KDTree<V>(this.distanceMetric, this.pvSelection);
        clone.leaf_node_size = this.leaf_node_size;
        if (this.distCache != null) {
            clone.distCache = new DoubleList(this.distCache);
        }
        if (this.allVecs != null) {
            clone.allVecs = new ArrayList<V>(this.allVecs);
        }
        clone.size = this.size;
        if (this.root != null) {
            // Nodes are inner-class instances tied to their enclosing tree. Re-create them against
            // the clone so it reads its own vectors, cache, metric and leaf size.
            clone.root = this.root.copyFor(clone);
        }
        return clone;
    }

    /** Orders vector indices by one coordinate of the referenced vectors. */
    private class VecIndexComparator
    implements Comparator<Integer> {
        private final int index;

        public VecIndexComparator(int index) {
            this.index = index;
        }

        @Override
        public int compare(Integer o1, Integer o2) {
            return Double.compare(KDTree.this.allVecs.get(o1).get(this.index), KDTree.this.allVecs.get(o2).get(this.index));
        }
    }

    /** A terminal node that stores vector indices and scans them all on each query. */
    private class KDLeaf
    extends KDNode {
        protected IntList owned;

        public KDLeaf(int axis, List<Integer> toOwn) {
            super(axis);
            this.owned = new IntList(toOwn);
        }

        public KDLeaf(KDLeaf toCopy) {
            super(toCopy);
            this.owned = new IntList(toCopy.owned);
        }

        @Override
        protected void searchK(int k, BoundedSortedList<IndexDistPair> knn, Vec target, List<Double> qi) {
            for (int i : this.owned) {
                double dist = KDTree.this.distanceMetric.dist(i, target, qi, KDTree.this.allVecs, KDTree.this.distCache);
                knn.add(new IndexDistPair(i, dist));
            }
        }

        @Override
        protected void searchR(double radius, List<Integer> vecsInRange, List<Double> distVecsInRange, Vec target, List<Double> qi) {
            for (int i : this.owned) {
                double dist = KDTree.this.distanceMetric.dist(i, target, qi, KDTree.this.allVecs, KDTree.this.distCache);
                if (dist <= radius) {
                    vecsInRange.add(i);
                    distVecsInRange.add(dist);
                }
            }
        }

        /** Adds the index; returns true once the leaf holds at least twice the leaf size. */
        @Override
        protected boolean insert(int x_indx) {
            this.owned.add(x_indx);
            return this.owned.size() >= KDTree.this.leaf_node_size * 2;
        }

        @Override
        protected KDLeaf clone() {
            return new KDLeaf(this);
        }

        @Override
        protected KDNode copyFor(KDTree<V> owner) {
            KDLeaf copy = owner.new KDLeaf(this.axis, this.owned);
            copy.pivot_s = this.pivot_s;
            return copy;
        }
    }

    /**
     * An internal node that splits on coordinate {@code axis}. Points with coordinate
     * {@code <= pivot_s} are in {@code left}; the rest are in {@code right}.
     */
    private class KDNode
    implements Cloneable,
    Serializable {
        protected int axis;
        protected double pivot_s;
        protected KDNode left;
        protected KDNode right;

        public KDNode(int axis) {
            this.axis = axis;
        }

        /** Deep-copies {@code toCopy}; the copy stays bound to the same enclosing tree. */
        public KDNode(KDNode toCopy) {
            this(toCopy.axis);
            this.pivot_s = toCopy.pivot_s;
            if (toCopy.left != null) {
                this.left = toCopy.left.clone();
            }
            if (toCopy.right != null) {
                this.right = toCopy.right.clone();
            }
        }

        public void setAxis(int axis) {
            this.axis = axis;
        }

        public void setLeft(KDNode left) {
            this.left = left;
        }

        public void setRight(KDNode right) {
            this.right = right;
        }

        public int getAxis() {
            return this.axis;
        }

        public KDNode getLeft() {
            return this.left;
        }

        public KDNode getRight() {
            return this.right;
        }

        @Override
        protected KDNode clone() {
            return new KDNode(this);
        }

        /** Deep-copies this subtree into nodes whose enclosing instance is {@code owner}. */
        protected KDNode copyFor(KDTree<V> owner) {
            KDNode copy = owner.new KDNode(this.axis);
            copy.pivot_s = this.pivot_s;
            if (this.left != null) {
                copy.left = this.left.copyFor(owner);
            }
            if (this.right != null) {
                copy.right = this.right.copyFor(owner);
            }
            return copy;
        }

        /**
         * Searches the child on the target's side first. The other child is visited only if the
         * current k-th best distance exceeds the target's distance to the split plane.
         */
        protected void searchK(int k, BoundedSortedList<IndexDistPair> knn, Vec target, List<Double> qi) {
            double target_s = target.get(this.axis);
            boolean target_in_left = target_s <= this.pivot_s;
            KDNode nearKD = target_in_left ? this.left : this.right;
            KDNode farKD = target_in_left ? this.right : this.left;
            if (nearKD != null) {
                nearKD.searchK(k, knn, target, qi);
            }
            double maxDistSoFar = Double.MAX_VALUE;
            if (knn.size() >= k) {
                maxDistSoFar = ((IndexDistPair)knn.get(k - 1)).getDist();
            }
            if (farKD != null && maxDistSoFar > Math.abs(target_s - this.pivot_s)) {
                farKD.searchK(k, knn, target, qi);
            }
        }

        protected void searchR(double radius, List<Integer> vecsInRange, List<Double> distVecsInRange, Vec target, List<Double> qi) {
            double target_s = target.get(this.axis);
            // Left may hold points exactly on the split plane (coordinate == pivot_s). Use >= so
            // such points are not pruned, matching the leaves' inclusive "dist <= radius" test.
            if (this.left != null && radius >= target_s - this.pivot_s) {
                this.left.searchR(radius, vecsInRange, distVecsInRange, target, qi);
            }
            // Right holds coordinates strictly greater than pivot_s.
            if (this.right != null && radius > this.pivot_s - target_s) {
                this.right.searchR(radius, vecsInRange, distVecsInRange, target, qi);
            }
        }

        /**
         * Routes the index to the matching child and rebuilds that child if it is a leaf that
         * became too large.
         *
         * @return always false; only leaves request a rebuild from their parent
         */
        protected boolean insert(int x_indx) {
            double target_s = KDTree.this.get(x_indx).get(this.axis);
            // NOTE(review): axis + 1 is passed as the depth for the rebuilt subtree, as before.
            // It is only a true depth for INCREMENTAL pivot selection.
            if (target_s <= this.pivot_s) {
                if (this.left.insert(x_indx)) {
                    this.left = KDTree.this.buildTree(((KDLeaf)this.left).owned, this.axis + 1, null, null);
                }
            } else if (this.right.insert(x_indx)) {
                this.right = KDTree.this.buildTree(((KDLeaf)this.right).owned, this.axis + 1, null, null);
            }
            return false;
        }
    }

    /** Strategy for choosing each internal node's split dimension and value. */
    public static enum PivotSelection {
        /** Cycle through dimensions by depth; split at the median. */
        INCREMENTAL,
        /** Split the dimension of largest variance at the median. */
        VARIANCE,
        /**
         * Split the widest dimension at the value nearest the centre of its range. Falls back to a
         * median split when that split would be too unbalanced.
         */
        SPREAD_MEDOID;

    }
}

