/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear.vectorcollection;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.vectorcollection.IncrementalCollection;
import jsat.linear.vectorcollection.IndexDistPair;
import jsat.utils.BoundedSortedList;
import jsat.utils.DoubleList;
import jsat.utils.IndexTable;
import jsat.utils.IntList;
import jsat.utils.ModifiableCountDownLatch;
import jsat.utils.Pair;
import jsat.utils.SimpleList;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * A vantage-point tree (VP-tree) supporting k-nearest-neighbor and range queries
 * under a distance metric that satisfies the triangle inequality.
 * <p>
 * Internal nodes ({@link VPNode}) hold one vantage point and the distance intervals
 * covered by their left (near) and right (far) subtrees. Leaves ({@link VPLeaf}) hold
 * point indices together with each point's distance to the parent vantage point.
 * Searches use those distances with the triangle inequality to skip distance
 * evaluations. Points are referenced by index into an internal vector list, and new
 * points may be appended with {@link #insert}.
 * <p>
 * {@code insert} mutates unsynchronized state; callers must synchronize concurrent
 * inserts and searches themselves.
 *
 * @param <V> the vector type stored in the tree
 */
public class VPTree<V extends Vec> implements IncrementalCollection<V> {
    private static final long serialVersionUID = -7271540108746353762L;

    /**
     * Sentinel stored as a point's "distance to the parent vantage point" when the point
     * has no parent, i.e. it sits in a root leaf. Real distances are never negative, so
     * this value is never used for pruning.
     */
    private static final double NO_PARENT_DIST = -1.0;

    private static final Logger LOG = Logger.getLogger(VPTree.class.getName());

    /** Metric used for all distance computations; must be subadditive. */
    private DistanceMetric dm;
    /** Acceleration cache aligned with {@link #allVecs}, or null when the metric uses none. */
    private List<Double> distCache;
    /** Every stored vector; tree nodes refer to points by index into this list. */
    private List<V> allVecs;
    /** Source of randomness for vantage-point selection. */
    private Random rand;
    /** Number of points each candidate is compared against under {@link VPSelection#Sampling}. */
    private int sampleSize;
    /** Number of candidate vantage points evaluated under {@link VPSelection#Sampling}. */
    private int searchIterations;
    private volatile TreeNode root;
    private VPSelection vpSelection;
    private int size;
    /** Maximum number of points a leaf may hold when the tree is built in bulk. */
    private int maxLeafSize = 5;

    /**
     * Builds a tree over {@code list}.
     *
     * @param list the vectors to index (copied; later changes to {@code list} are not seen)
     * @param dm a subadditive distance metric
     * @param vpSelection how vantage points are chosen
     * @param rand randomness for vantage-point selection; a default source is used if null
     * @param sampleSize points sampled per candidate under {@link VPSelection#Sampling}
     * @param searchIterations candidates evaluated under {@link VPSelection#Sampling}
     * @param parallel whether to build subtrees concurrently
     * @throws RuntimeException if {@code dm} is not subadditive
     */
    public VPTree(List<V> list, DistanceMetric dm, VPSelection vpSelection, Random rand, int sampleSize, int searchIterations, boolean parallel) {
        this.vpSelection = vpSelection;
        this.rand = rand;
        this.sampleSize = sampleSize;
        this.searchIterations = searchIterations;
        this.build(parallel, list, dm);
    }

    /** Single-threaded variant of the full constructor. */
    public VPTree(List<V> list, DistanceMetric dm, VPSelection vpSelection, Random rand, int sampleSize, int searchIterations) {
        this(list, dm, vpSelection, rand, sampleSize, searchIterations, false);
    }

    /** Builds with the default random source, 80 samples and 40 candidate iterations. */
    public VPTree(List<V> list, DistanceMetric dm, VPSelection vpSelection) {
        this(list, dm, vpSelection, RandomUtil.getRandom(), 80, 40);
    }

    /** Builds with random vantage-point selection, optionally in parallel. */
    public VPTree(List<V> list, DistanceMetric dm, boolean parallel) {
        this(list, dm, VPSelection.Random, RandomUtil.getRandom(), 80, 40, parallel);
    }

    /** Builds single-threaded with random vantage-point selection. */
    public VPTree(List<V> list, DistanceMetric dm) {
        this(list, dm, VPSelection.Random);
    }

    /** Creates an empty tree using {@link EuclideanDistance}. */
    public VPTree() {
        this(new EuclideanDistance());
    }

    /** Creates an empty tree using {@code dm} and random vantage-point selection. */
    public VPTree(DistanceMetric dm) {
        this(dm, VPSelection.Random);
    }

    /**
     * Creates an empty tree to be filled by {@link #insert} or {@link #build}.
     *
     * @throws RuntimeException if {@code dm} is not subadditive
     */
    public VPTree(DistanceMetric dm, VPSelection sampling) {
        this.dm = dm;
        if (!dm.isSubadditive()) {
            throw new RuntimeException("VPTree only supports metrics that support the triangle inequality");
        }
        this.rand = RandomUtil.getRandom();
        this.sampleSize = 80;
        this.searchIterations = 40;
        this.size = 0;
        this.vpSelection = sampling;
        this.allVecs = new ArrayList<V>();
        if (dm.supportsAcceleration()) {
            this.distCache = new DoubleList();
        }
    }

    /** Copy constructor: deep-copies the node structure and shallow-copies the vector list. */
    protected VPTree(VPTree<V> toClone) {
        this.dm = toClone.dm.clone();
        this.rand = toClone.rand == null ? null : new Random(toClone.rand.nextInt());
        this.sampleSize = toClone.sampleSize;
        this.searchIterations = toClone.searchIterations;
        this.root = this.cloneChangeContext(toClone.root);
        this.vpSelection = toClone.vpSelection;
        this.size = toClone.size;
        this.maxLeafSize = toClone.maxLeafSize;
        if (toClone.allVecs != null) {
            this.allVecs = new ArrayList<V>(toClone.allVecs);
        }
        if (toClone.distCache != null) {
            this.distCache = new DoubleList(toClone.distCache);
        }
    }

    /**
     * Replaces the tree's contents with a freshly built tree over {@code list}.
     *
     * @param parallel whether to build subtrees concurrently; on interruption or a worker
     *        failure the build falls back to a single-threaded build
     * @param list the vectors to index (copied)
     * @param dm a subadditive distance metric
     * @throws RuntimeException if {@code dm} is not subadditive
     */
    @Override
    public void build(boolean parallel, List<V> list, DistanceMetric dm) {
        if (!dm.isSubadditive()) {
            throw new RuntimeException("VPTree only supports metrics that support the triangle inequality");
        }
        this.setDistanceMetric(dm);
        if (this.rand == null) {
            this.rand = RandomUtil.getRandom();
        }
        // Copy so later insert() calls never mutate (or fail on) the caller's list.
        this.allVecs = new ArrayList<V>(list);
        this.size = this.allVecs.size();
        this.distCache = dm.getAccelerationCache(this.allVecs, parallel);
        this.root = parallel ? this.buildParallel() : this.makeVPTree(this.freshIndexList());
    }

    /** Returns a mutable list pairing every stored point index with the no-parent sentinel. */
    private List<Pair<Double, Integer>> freshIndexList() {
        SimpleList<Pair<Double, Integer>> S = new SimpleList<Pair<Double, Integer>>(this.allVecs.size());
        for (int i = 0; i < this.allVecs.size(); ++i) {
            S.add(new Pair<Double, Integer>(NO_PARENT_DIST, i));
        }
        return S;
    }

    /**
     * Builds the tree with a thread pool.
     * <p>
     * Falls back to a single-threaded build (on a fresh index list, so straggling workers
     * cannot race with it) if the wait is interrupted or any worker throws.
     */
    private TreeNode buildParallel() {
        ExecutorService threadpool = ParallelUtils.getNewExecutor(true);
        List<Throwable> failures = Collections.synchronizedList(new ArrayList<Throwable>());
        try {
            ModifiableCountDownLatch mcdl = new ModifiableCountDownLatch(1);
            TreeNode newRoot = this.makeVPTree(this.freshIndexList(), threadpool, mcdl, failures);
            mcdl.countDown();
            mcdl.await();
            if (failures.isEmpty()) {
                return newRoot;
            }
            LOG.log(Level.SEVERE, null, failures.get(0));
        } catch (InterruptedException ex) {
            LOG.log(Level.SEVERE, null, ex);
            Thread.currentThread().interrupt();
        } finally {
            threadpool.shutdownNow();
        }
        LOG.warning("Falling back to single threaded VPTree constructor");
        return this.makeVPTree(this.freshIndexList());
    }

    @Override
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    @Override
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /** Deep-copies a node so that the copy's enclosing tree is {@code this}; null stays null. */
    private TreeNode cloneChangeContext(TreeNode toClone) {
        if (toClone == null) {
            return null;
        }
        if (toClone instanceof VPLeaf) {
            return new VPLeaf((VPLeaf) toClone);
        }
        return new VPNode((VPNode) toClone);
    }

    @Override
    public int size() {
        return this.size;
    }

    @Override
    public V get(int indx) {
        return this.allVecs.get(indx);
    }

    /**
     * Appends {@code x} and routes it into the tree.
     * <p>
     * A root leaf that grows beyond {@code maxLeafSize^2} points is rebuilt into a proper
     * subtree; deeper leaves are split by {@link VPNode} as they overflow.
     */
    @Override
    public void insert(V x) {
        int indx = this.size++;
        this.allVecs.add(x);
        if (this.distCache != null) {
            this.distCache.addAll(this.dm.getQueryInfo(x));
        }
        if (this.root == null) {
            this.root = new VPLeaf(Collections.singletonList(new Pair<Double, Integer>(NO_PARENT_DIST, indx)));
            return;
        }
        this.root.insert(indx, NO_PARENT_DIST);
        if (this.root instanceof VPLeaf) {
            VPLeaf leaf = (VPLeaf) this.root;
            int threshold = this.maxLeafSize * this.maxLeafSize;
            if (leaf.points.size() > threshold) {
                int origLeafSize = this.maxLeafSize;
                // Build with larger leaves so incremental growth does not re-split constantly.
                this.maxLeafSize = threshold;
                try {
                    List<Pair<Double, Integer>> S = new ArrayList<Pair<Double, Integer>>(leaf.points.size());
                    for (int i = 0; i < leaf.points.size(); ++i) {
                        S.add(new Pair<Double, Integer>(NO_PARENT_DIST, leaf.points.getI(i)));
                    }
                    this.root = this.makeVPTree(S);
                } finally {
                    this.maxLeafSize = origLeafSize;
                }
            }
        }
    }

    /**
     * Finds every stored point within distance {@code range} (inclusive) of {@code query}.
     * <p>
     * Matching indices and distances are appended to {@code neighbors} and
     * {@code distances}, and then both lists are sorted together by distance. An empty
     * tree adds nothing.
     */
    @Override
    public void search(Vec query, double range, List<Integer> neighbors, List<Double> distances) {
        TreeNode start = this.root;
        if (start == null) {
            return;
        }
        List<Double> qi = this.dm.getQueryInfo(query);
        start.searchRange(VecPaired.extractTrueVec(query), range, neighbors, distances, 0.0, qi);
        IndexTable it = new IndexTable(distances);
        it.apply(neighbors);
        it.apply(distances);
    }

    /**
     * Appends the {@code numNeighbors} nearest stored points to {@code query} in ascending
     * distance order.
     * <p>
     * An empty tree or {@code numNeighbors <= 0} adds nothing.
     */
    @Override
    public void search(Vec query, int numNeighbors, List<Integer> neighbors, List<Double> distances) {
        TreeNode start = this.root;
        if (start == null || numNeighbors <= 0) {
            return;
        }
        BoundedSortedList<IndexDistPair> best = new BoundedSortedList<IndexDistPair>(numNeighbors, numNeighbors);
        List<Double> qi = this.dm.getQueryInfo(query);
        start.searchKNN(VecPaired.extractTrueVec(query), numNeighbors, best, 0.0, qi);
        for (IndexDistPair pm : best) {
            neighbors.add(pm.getIndex());
            distances.add(pm.getDist());
        }
    }

    /** Distance of the current k-th best candidate, or {@code Double.MAX_VALUE} if there is none yet. */
    private static double worstDistance(BoundedSortedList<IndexDistPair> list) {
        return list.isEmpty() ? Double.MAX_VALUE : list.get(list.size() - 1).getDist();
    }

    /**
     * Tests whether the triangle inequality rules a point out of a search.
     * <p>
     * Returns true when the point, whose distance to its parent vantage point is
     * {@code bound}, is provably farther than {@code radius} from a query whose distance to
     * that vantage point is {@code x}. The proof uses {@code d(q, pt) >= |x - bound|}.
     * Sentinel bounds, which mean the point has no parent, never prune.
     */
    private static boolean canPrune(double bound, double x, double radius) {
        if (!(bound >= 0 && bound < Double.MAX_VALUE)) {
            return false;
        }
        return !(bound - radius <= x && x <= bound + radius);
    }

    /**
     * Sorts {@code S} by distance to {@code node.p} and records the split bounds on
     * {@code node}.
     *
     * @return the last index (within {@code S}) of the left/near half
     */
    private int sortSplitSet(List<Pair<Double, Integer>> S, VPNode node) {
        for (Pair<Double, Integer> entry : S) {
            entry.setFirstItem(this.dm.dist(node.p, entry.getSecondItem(), this.allVecs, this.distCache));
        }
        Collections.sort(S, (o1, o2) -> Double.compare(o1.getFirstItem(), o2.getFirstItem()));
        int splitIndex = this.splitListIndex(S);
        node.left_low = S.get(0).getFirstItem();
        node.left_high = S.get(splitIndex).getFirstItem();
        node.right_low = S.get(splitIndex + 1).getFirstItem();
        node.right_high = S.get(S.size() - 1).getFirstItem();
        return splitIndex;
    }

    /** Position in a distance-sorted list at which to split; the median by default. */
    protected int splitListIndex(List<Pair<Double, Integer>> S) {
        return S.size() / 2;
    }

    public int getMaxLeafSize() {
        return this.maxLeafSize;
    }

    /** Sets the bulk-build leaf capacity; values below 5 are raised to 5. */
    public void setMaxLeafSize(int maxLeafSize) {
        this.maxLeafSize = Math.max(5, maxLeafSize);
    }

    /** Chooses a vantage point for {@code S}, swaps it into {@code S[0]}, and returns its point index. */
    private int moveVantagePointToFront(List<Pair<Double, Integer>> S) {
        int vpIndex = this.selectVantagePointIndex(S);
        Collections.swap(S, 0, vpIndex);
        return S.get(0).getSecondItem();
    }

    /** Recursively builds a subtree over {@code S} (reordered in place) on the calling thread. */
    private TreeNode makeVPTree(List<Pair<Double, Integer>> S) {
        if (S.isEmpty()) {
            return null;
        }
        if (S.size() <= this.maxLeafSize) {
            return new VPLeaf(S);
        }
        VPNode node = new VPNode(this.moveVantagePointToFront(S));
        int splitIndex = this.sortSplitSet(S.subList(1, S.size()), node) + 1;
        node.right = this.makeVPTree(S.subList(splitIndex + 1, S.size()));
        node.left = this.makeVPTree(S.subList(1, splitIndex + 1));
        return node;
    }

    /**
     * Parallel variant of {@link #makeVPTree(List)}.
     * <p>
     * Each right subtree is built as a pool task. {@code mcdl} tracks outstanding tasks and
     * is always counted down, even when a task throws. Any task exception is recorded in
     * {@code failures}.
     */
    private TreeNode makeVPTree(List<Pair<Double, Integer>> S, ExecutorService threadpool, ModifiableCountDownLatch mcdl, List<Throwable> failures) {
        if (S.isEmpty()) {
            return null;
        }
        if (S.size() <= this.maxLeafSize) {
            return new VPLeaf(S);
        }
        VPNode node = new VPNode(this.moveVantagePointToFront(S));
        int splitIndex = this.sortSplitSet(S.subList(1, S.size()), node) + 1;
        List<Pair<Double, Integer>> rightS = S.subList(splitIndex + 1, S.size());
        List<Pair<Double, Integer>> leftS = S.subList(1, splitIndex + 1);
        mcdl.countUp();
        threadpool.submit(() -> {
            try {
                node.right = this.makeVPTree(rightS, threadpool, mcdl, failures);
            } catch (Throwable t) {
                failures.add(t);
            } finally {
                mcdl.countDown();
            }
        });
        node.left = this.makeVPTree(leftS, threadpool, mcdl, failures);
        return node;
    }

    /**
     * Returns positions into a list of size {@code n}: every position when
     * {@code n <= count}, otherwise {@code count} uniformly random positions drawn with
     * replacement.
     */
    private int[] samplePositions(int n, int count) {
        int[] positions;
        if (n <= count) {
            positions = new int[n];
            for (int i = 0; i < n; ++i) {
                positions[i] = i;
            }
        } else {
            positions = new int[count];
            for (int i = 0; i < count; ++i) {
                positions[i] = this.rand.nextInt(n);
            }
        }
        return positions;
    }

    /**
     * Picks the position in non-empty {@code S} of the next vantage point.
     * <p>
     * Under {@link VPSelection#Random} the pick is uniform. Otherwise several candidates
     * are evaluated against a random sample of {@code S}, and the candidate whose sampled
     * distances have the largest total absolute deviation from their median wins.
     */
    private int selectVantagePointIndex(List<Pair<Double, Integer>> S) {
        if (this.vpSelection == VPSelection.Random) {
            return this.rand.nextInt(S.size());
        }
        // Random rather than prefix samples: sub-lists are sorted by distance to the parent
        // vantage point, so a prefix would be a heavily biased sample.
        int[] samplePos = this.samplePositions(S.size(), Math.max(1, this.sampleSize));
        int[] candidatePos = this.samplePositions(S.size(), Math.max(1, this.searchIterations));
        double[] distances = new double[samplePos.length];
        int bestVP = candidatePos[0];
        double bestSpread = Double.NEGATIVE_INFINITY;
        for (int candIndx : candidatePos) {
            int candV = S.get(candIndx).getSecondItem();
            for (int j = 0; j < samplePos.length; ++j) {
                distances[j] = this.dm.dist(candV, S.get(samplePos[j]).getSecondItem(), this.allVecs, this.distCache);
            }
            Arrays.sort(distances);
            double median = distances[distances.length / 2];
            double spread = 0.0;
            for (double distance : distances) {
                spread += Math.abs(distance - median);
            }
            if (spread > bestSpread) {
                bestSpread = spread;
                bestVP = candIndx;
            }
        }
        return bestVP;
    }

    @Override
    public VPTree<V> clone() {
        return new VPTree<V>(this);
    }

    /** Leaf node: a flat list of points with their distances to the parent vantage point. */
    private class VPLeaf extends TreeNode {
        /** Indices into {@code allVecs} of the points in this leaf. */
        IntList points;
        /** {@code bounds.getD(i)} is the distance from point i to the parent vantage point, or a sentinel. */
        DoubleList bounds;

        public VPLeaf(List<Pair<Double, Integer>> points) {
            this.points = new IntList(points.size());
            this.bounds = new DoubleList(points.size());
            for (Pair<Double, Integer> entry : points) {
                this.points.add(entry.getSecondItem());
                this.bounds.add(entry.getFirstItem());
            }
        }

        public VPLeaf(VPLeaf toCopy) {
            this.bounds = new DoubleList(toCopy.bounds);
            this.points = new IntList(toCopy.points);
        }

        @Override
        public void insert(int x_indx, double dist_to_parent) {
            this.points.add(x_indx);
            this.bounds.add(dist_to_parent);
        }

        @Override
        public void searchKNN(Vec query, int k, BoundedSortedList<IndexDistPair> list, double x, List<Double> qi) {
            double tau = worstDistance(list);
            for (int i = 0; i < this.points.size(); ++i) {
                int point_i = this.points.getI(i);
                if (list.size() < k) {
                    // Result list not full yet: every point is a candidate.
                    list.add(new IndexDistPair(point_i, dm.dist(point_i, query, qi, allVecs, distCache)));
                    tau = worstDistance(list);
                    continue;
                }
                if (canPrune(this.bounds.getD(i), x, tau)) {
                    continue;
                }
                double dist = dm.dist(point_i, query, qi, allVecs, distCache);
                if (dist < tau) {
                    list.add(new IndexDistPair(point_i, dist));
                    tau = worstDistance(list);
                }
            }
        }

        @Override
        public void searchRange(Vec query, double range, List<Integer> neighbors, List<Double> distances, double x, List<Double> qi) {
            for (int i = 0; i < this.points.size(); ++i) {
                if (canPrune(this.bounds.getD(i), x, range)) {
                    continue;
                }
                int point_i = this.points.getI(i);
                double dist = dm.dist(point_i, query, qi, allVecs, distCache);
                // Inclusive, matching the vantage-point check in VPNode.searchRange.
                if (dist <= range) {
                    neighbors.add(point_i);
                    distances.add(dist);
                }
            }
        }

        @Override
        public TreeNode clone() {
            return new VPLeaf(this);
        }
    }

    /**
     * Internal node: vantage point {@code p} plus subtrees.
     * <p>
     * The left subtree holds points at distance {@code [left_low, left_high]} from
     * {@code p}; the right subtree holds those in {@code [right_low, right_high]}.
     */
    private class VPNode extends TreeNode {
        int p;
        double left_low;
        double left_high;
        double right_low;
        double right_high;
        TreeNode right;
        TreeNode left;

        public VPNode(int p) {
            this.p = p;
        }

        public VPNode(VPNode toCopy) {
            this(toCopy.p);
            this.left_low = toCopy.left_low;
            this.left_high = toCopy.left_high;
            this.right_low = toCopy.right_low;
            this.right_high = toCopy.right_high;
            this.left = VPTree.this.cloneChangeContext(toCopy.left);
            this.right = VPTree.this.cloneChangeContext(toCopy.right);
        }

        /**
         * Routes a new point to the nearer side of the split and widens that side's bounds.
         * {@code dist_to_parent} is unused; the distance to {@code p} is computed here.
         */
        @Override
        public void insert(int x_indx, double dist_to_parent) {
            double dist = dm.dist(this.p, x_indx, allVecs, distCache);
            TreeNode child;
            if (dist * 2.0 < this.left_high + this.right_low) {
                this.left_high = Math.max(this.left_high, dist);
                this.left_low = Math.min(this.left_low, dist);
                child = this.left = this.maybeExpandChild(this.left);
            } else {
                this.right_high = Math.max(this.right_high, dist);
                this.right_low = Math.min(this.right_low, dist);
                child = this.right = this.maybeExpandChild(this.right);
            }
            child.insert(x_indx, dist);
        }

        /** Splits {@code child} into a node with two leaves once it exceeds {@code maxLeafSize^2} points. */
        private TreeNode maybeExpandChild(TreeNode child) {
            if (!(child instanceof VPLeaf)) {
                return child;
            }
            IntList members = ((VPLeaf) child).points;
            if (members.size() <= maxLeafSize * maxLeafSize) {
                return child;
            }
            List<Pair<Double, Integer>> S = new ArrayList<Pair<Double, Integer>>(members.size());
            for (int i = 0; i < members.size(); ++i) {
                S.add(new Pair<Double, Integer>(NO_PARENT_DIST, members.getI(i)));
            }
            VPNode node = new VPNode(moveVantagePointToFront(S));
            int splitIndex = sortSplitSet(S.subList(1, S.size()), node) + 1;
            node.right = new VPLeaf(S.subList(splitIndex + 1, S.size()));
            node.left = new VPLeaf(S.subList(1, splitIndex + 1));
            return node;
        }

        private boolean searchInLeft(double x, double tau) {
            return this.left != null && this.left_low - tau <= x && x <= this.left_high + tau;
        }

        private boolean searchInRight(double x, double tau) {
            return this.right != null && this.right_low - tau <= x && x <= this.right_high + tau;
        }

        private void visitLeftKNN(Vec query, int k, BoundedSortedList<IndexDistPair> list, double x, List<Double> qi) {
            if (this.left != null && (list.size() < k || this.searchInLeft(x, worstDistance(list)))) {
                this.left.searchKNN(query, k, list, x, qi);
            }
        }

        private void visitRightKNN(Vec query, int k, BoundedSortedList<IndexDistPair> list, double x, List<Double> qi) {
            if (this.right != null && (list.size() < k || this.searchInRight(x, worstDistance(list)))) {
                this.right.searchKNN(query, k, list, x, qi);
            }
        }

        @Override
        public void searchKNN(Vec query, int k, BoundedSortedList<IndexDistPair> list, double distToParent, List<Double> qi) {
            double x = dm.dist(this.p, query, qi, allVecs, distCache);
            if (list.size() < k || x < worstDistance(list)) {
                list.add(new IndexDistPair(this.p, x));
            }
            // Visit the side the query most likely falls in first so tau shrinks
            // before deciding whether the other side can be skipped.
            if (x < (this.left_high + this.right_low) * 0.5) {
                this.visitLeftKNN(query, k, list, x, qi);
                this.visitRightKNN(query, k, list, x, qi);
            } else {
                this.visitRightKNN(query, k, list, x, qi);
                this.visitLeftKNN(query, k, list, x, qi);
            }
        }

        @Override
        public void searchRange(Vec query, double range, List<Integer> neighbors, List<Double> distances, double distToParent, List<Double> qi) {
            double x = dm.dist(this.p, query, qi, allVecs, distCache);
            if (x <= range) {
                neighbors.add(this.p);
                distances.add(x);
            }
            if (this.searchInLeft(x, range)) {
                this.left.searchRange(query, range, neighbors, distances, x, qi);
            }
            if (this.searchInRight(x, range)) {
                this.right.searchRange(query, range, neighbors, distances, x, qi);
            }
        }

        @Override
        public TreeNode clone() {
            return new VPNode(this);
        }
    }

    /** Common interface of tree nodes. */
    private abstract class TreeNode implements Cloneable, Serializable {
        private TreeNode() {
        }

        /** Adds point {@code pointIndex}, whose distance to this node's parent vantage point is {@code distToParent}. */
        public abstract void insert(int pointIndex, double distToParent);

        /** Offers this subtree's points to {@code list}; {@code queryDistToParent} is d(query, parent vantage point). */
        public abstract void searchKNN(Vec query, int k, BoundedSortedList<IndexDistPair> list, double queryDistToParent, List<Double> qi);

        /** Appends points within {@code range} of {@code query}; {@code queryDistToParent} is d(query, parent vantage point). */
        public abstract void searchRange(Vec query, double range, List<Integer> neighbors, List<Double> distances, double queryDistToParent, List<Double> qi);

        @Override
        public abstract TreeNode clone();
    }

    /** Strategy for choosing each node's vantage point. */
    public static enum VPSelection {
        /**
         * Evaluate several candidates against a sample of points and keep the one with the
         * largest spread of distances.
         */
        Sampling,
        /** Pick a uniformly random point. */
        Random;

    }
}

