/*
 * Decompiled with CFR 0.152.
 */
package com.jujutsu.tsne.barneshut;

import com.jujutsu.tsne.barneshut.AtomicDouble;
import com.jujutsu.tsne.barneshut.BHTSne;
import com.jujutsu.tsne.barneshut.DataPoint;
import com.jujutsu.tsne.barneshut.ParallelSPTree;
import com.jujutsu.tsne.barneshut.ParallelVpTree;
import com.jujutsu.tsne.barneshut.SPTree;
import com.jujutsu.tsne.barneshut.TSneConfiguration;
import com.jujutsu.utils.MatrixOps;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveAction;

/**
 * Multi-threaded variant of {@link BHTSne}.
 *
 * <p>Overrides the gradient update, the gradient computation and the input-similarity
 * computation of the base class so that they run on thread pools. The pools live only for the
 * duration of a single {@link #run(TSneConfiguration)} call.
 *
 * <p>NOTE(review): the pools are stored in instance fields. One instance should therefore not
 * execute two {@code run} calls concurrently.
 */
public class ParallelBHTsne
extends BHTSne {
    /** Fork/join pool used by {@link #updateGradient} and by the vantage-point tree; created per run. */
    private ForkJoinPool gradientPool;
    /** Fixed pool used by {@link #computeGradient} for the per-row force tasks; created per run. */
    private ExecutorService gradientCalculationPool;

    /**
     * Creates the worker pools, delegates to {@link BHTSne#run}, and shuts the pools down again.
     *
     * <p>Shutdown happens in a {@code finally} block. The threads of
     * {@link Executors#newFixedThreadPool(int)} are non-daemon and would otherwise keep the JVM
     * alive after a failed run.
     *
     * @param config the configuration, passed through unchanged to the base implementation
     * @return the result of {@link BHTSne#run}
     */
    @Override
    double[][] run(TSneConfiguration config) {
        int threads = Runtime.getRuntime().availableProcessors();
        this.gradientPool = new ForkJoinPool(threads);
        this.gradientCalculationPool = Executors.newFixedThreadPool(threads);
        try {
            return super.run(config);
        } finally {
            this.gradientPool.shutdown();
            this.gradientCalculationPool.shutdown();
        }
    }

    /**
     * Applies one optimisation step to {@code Y}, {@code uY} and {@code gains} in parallel.
     *
     * <p>The flat index range {@code [0, N * no_dims)} is split recursively by
     * {@link RecursiveGradientUpdater}. Every element is updated independently of the others,
     * so the split does not influence the result.
     *
     * @param N number of points
     * @param no_dims number of output dimensions
     * @param Y current embedding (flat, {@code N * no_dims} entries), updated in place
     * @param momentum momentum factor
     * @param eta learning rate
     * @param dY gradient (flat, {@code N * no_dims} entries)
     * @param uY previous update (flat), updated in place
     * @param gains per-element gains (flat), updated in place
     */
    @Override
    void updateGradient(int N, int no_dims, double[] Y, double momentum, double eta, double[] dY, double[] uY, double[] gains) {
        // The leaf size must be at least 1. With 0 (i.e. N < 10 * cores) a one-element range
        // would be split into an empty half and an identical one-element half forever.
        int leafSize = Math.max(1, N / (Runtime.getRuntime().availableProcessors() * 10));
        RecursiveGradientUpdater updater = new RecursiveGradientUpdater(N, no_dims, Y, momentum, eta, dY, uY, gains, 0, N * no_dims, leafSize);
        this.gradientPool.invoke(updater);
    }

    /**
     * Computes the gradient {@code dC} using a Barnes-Hut space-partitioning tree.
     *
     * <p>The result is {@code dC[i*D+j] = pos_f[i*D+j] - neg_f[i][j] / sum_Q}. Here:
     * <ul>
     *   <li>{@code pos_f} comes from {@code ParallelSPTree.computeEdgeForces} over the sparse
     *       input similarities;</li>
     *   <li>{@code neg_f} and {@code sum_Q} come from one {@code computeNonEdgeForces} task per
     *       row, run on {@link #gradientCalculationPool}.</li>
     * </ul>
     * Presumably these are the attractive and repulsive t-SNE terms; confirm against SPTree.
     *
     * @param P unused by this implementation
     * @param inp_row_P row pointers of the sparse similarity matrix
     * @param inp_col_P column indices of the sparse similarity matrix
     * @param inp_val_P values of the sparse similarity matrix
     * @param Y current embedding (flat, {@code N * D} entries)
     * @param N number of points
     * @param D number of output dimensions
     * @param dC output gradient (flat, {@code N * D} entries), overwritten
     * @param theta Barnes-Hut accuracy/speed trade-off passed to the tree
     * @throws IllegalStateException if a force task fails or the calling thread is interrupted
     */
    @Override
    void computeGradient(double[] P, int[] inp_row_P, int[] inp_col_P, double[] inp_val_P, double[] Y, int N, int D, double[] dC, double theta) {
        ParallelSPTree tree = new ParallelSPTree(D, Y, N);
        double[] pos_f = new double[N * D];
        double[][] neg_f = new double[N][D];
        tree.computeEdgeForces(inp_row_P, inp_col_P, inp_val_P, N, pos_f);

        List<ParallelGradientCalculator> calculators = new ArrayList<>(N);
        for (int n = 0; n < N; ++n) {
            calculators.add(new ParallelGradientCalculator(tree, neg_f, theta, n, 20));
        }

        double sum_Q = 0.0;
        try {
            // invokeAll returns futures in submission order, so the summation order (and hence
            // the floating-point result) is deterministic.
            List<Future<Double>> results = this.gradientCalculationPool.invokeAll(calculators);
            for (Future<Double> result : results) {
                sum_Q += result.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while computing non-edge forces", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Non-edge force computation failed", e);
        }

        for (int i = 0; i < N; ++i) {
            for (int j = 0; j < D; ++j) {
                dC[i * D + j] = pos_f[i * D + j] - neg_f[i][j] / sum_Q;
            }
        }
    }

    /**
     * Builds the sparse conditional input similarities for each of the {@code N} points.
     *
     * <p>The output uses compressed-sparse-row layout. Row {@code n} occupies
     * {@code [row_P[n], row_P[n] + K)} of {@code col_P}/{@code val_P}, and every row has exactly
     * {@code K} entries.
     *
     * <p>Steps:
     * <ol>
     *   <li>The {@code K + 1} nearest neighbours of every point are searched in parallel with a
     *       {@code ParallelVpTree}.</li>
     *   <li>For each point, a Gaussian precision is binary-searched to match the requested
     *       perplexity (see {@link #fitGaussianRow}).</li>
     *   <li>The resulting weights are normalised to sum to one.</li>
     * </ol>
     *
     * <p>If {@code perplexity > K} a warning is printed and the computation continues.
     *
     * @param X input data, read row by row via {@code MatrixOps.extractRowFromFlatMatrix}
     *          (presumably row-major {@code N x D}; confirm)
     * @param N number of points
     * @param D input dimensionality
     * @param _row_P output row pointers; needs at least {@code N + 1} entries
     * @param _col_P output neighbour indices; needs at least {@code N * K} entries
     * @param _val_P output similarities; needs at least {@code N * K} entries
     * @param perplexity target perplexity
     * @param K neighbours per point; must not exceed {@code N - 1}
     * @throws IllegalStateException if a neighbour search fails or the thread is interrupted
     */
    @Override
    void computeGaussianPerplexity(double[] X, int N, int D, int[] _row_P, int[] _col_P, double[] _val_P, double perplexity, int K) {
        if (perplexity > (double)K) {
            System.out.println("Perplexity should be lower than K!");
        }
        int[] row_P = _row_P;
        int[] col_P = _col_P;
        double[] val_P = _val_P;
        double[] cur_P = new double[N - 1];
        row_P[0] = 0;
        for (int n = 0; n < N; ++n) {
            row_P[n + 1] = row_P[n] + K;
        }

        ParallelVpTree tree = new ParallelVpTree(this.gradientPool, this.distance);
        DataPoint[] obj_X = new DataPoint[N];
        for (int n = 0; n < N; ++n) {
            obj_X[n] = new DataPoint(D, n, MatrixOps.extractRowFromFlatMatrix(X, n, D));
        }
        tree.create(obj_X);

        List<Future<ParallelVpTree.TreeSearchResult>> results = tree.searchMultiple(tree, obj_X, K + 1);
        for (Future<ParallelVpTree.TreeSearchResult> result : results) {
            ParallelVpTree.ParallelTreeNode.TreeSearchResult res;
            try {
                res = (ParallelVpTree.ParallelTreeNode.TreeSearchResult)result.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while searching nearest neighbours", e);
            } catch (ExecutionException e) {
                // Previously swallowed, which led to a NullPointerException further down.
                throw new IllegalStateException("Nearest-neighbour search failed", e);
            }
            List<Double> distances = res.getDistances();
            List<DataPoint> indices = res.getIndices();
            int n = res.getIndex();

            double sum_P = fitGaussianRow(distances, K, perplexity, cur_P);
            for (int m = 0; m < K; ++m) {
                cur_P[m] /= sum_P;
                // Neighbour 0 is skipped, matching fitGaussianRow.
                col_P[row_P[n] + m] = indices.get(m + 1).index();
                val_P[row_P[n] + m] = cur_P[m];
            }
        }
    }

    /**
     * Binary-searches the Gaussian precision {@code beta} for one point.
     *
     * <p>The goal is that the entropy {@code H} of the point's weights over its {@code K}
     * neighbours matches {@code log(perplexity)} to within {@code 1e-5}. The search stops after
     * at most 200 iterations. {@code beta} is doubled or halved until a bracket is found, then
     * bisected.
     *
     * <p>Entry 0 of {@code distances} is skipped. This is presumably the query point itself,
     * since the search asked for {@code K + 1} neighbours; confirm against ParallelVpTree.
     *
     * @param distances neighbour distances; entries {@code 1..K} are read
     * @param K number of neighbours to weight
     * @param perplexity target perplexity
     * @param cur_P output; its first {@code K} entries receive the unnormalised weights
     * @return the sum of the weights, seeded with {@code Double.MIN_VALUE} so it is never zero
     */
    private static double fitGaussianRow(List<Double> distances, int K, double perplexity, double[] cur_P) {
        final double tol = 1.0E-5;
        double beta = 1.0;
        double min_beta = -Double.MAX_VALUE;
        double max_beta = Double.MAX_VALUE;
        double sum_P = 0.0;
        for (int iter = 0; iter < 200; ++iter) {
            sum_P = Double.MIN_VALUE;
            double H = 0.0;
            for (int m = 0; m < K; ++m) {
                double dist = distances.get(m + 1);
                cur_P[m] = Math.exp(-beta * dist);
                sum_P += cur_P[m];
                H += beta * (dist * cur_P[m]);
            }
            H = H / sum_P + Math.log(sum_P);
            double Hdiff = H - Math.log(perplexity);
            if (Hdiff < tol && -Hdiff < tol) {
                break;
            }
            if (Hdiff > 0.0) {
                // Entropy too high: increase beta (narrower Gaussian).
                min_beta = beta;
                if (max_beta == Double.MAX_VALUE || max_beta == -Double.MAX_VALUE) {
                    beta *= 2.0;
                } else {
                    beta = (beta + max_beta) / 2.0;
                }
            } else {
                // Entropy too low: decrease beta (wider Gaussian).
                max_beta = beta;
                if (min_beta == -Double.MAX_VALUE || min_beta == Double.MAX_VALUE) {
                    beta /= 2.0;
                } else {
                    beta = (beta + min_beta) / 2.0;
                }
            }
        }
        return sum_P;
    }

    /**
     * Fork/join task that applies the gains/momentum update to the flat index range
     * {@code [startIdx, endIdx)}.
     *
     * <p>Ranges larger than {@code limit} are halved recursively. A {@code limit} below 1 is
     * treated as 1 so that the recursion always terminates.
     */
    class RecursiveGradientUpdater
    extends RecursiveAction {
        static final long serialVersionUID = 1L;
        int startIdx = -1;
        int endIdx = -1;
        int limit = 100;
        int N;
        int no_dims;
        double[] Y;
        double momentum;
        double eta;
        double[] dY;
        double[] uY;
        double[] gains;

        public RecursiveGradientUpdater(int n, int no_dims, double[] Y, double momentum, double eta, double[] dY, double[] uY, double[] gains, int startIdx, int endIdx, int limit) {
            this.startIdx = startIdx;
            this.endIdx = endIdx;
            this.limit = limit;
            this.N = n;
            this.no_dims = no_dims;
            this.Y = Y;
            this.momentum = momentum;
            this.eta = eta;
            this.dY = dY;
            this.uY = uY;
            this.gains = gains;
        }

        @Override
        protected void compute() {
            int range = this.endIdx - this.startIdx;
            // A leaf size below 1 would never terminate: a one-element range splits into an
            // empty half and an identical one-element half.
            if (range <= Math.max(1, this.limit)) {
                for (int n = this.startIdx; n < this.endIdx; ++n) {
                    // Grow the gain when gradient and previous update differ in sign, else decay it.
                    this.gains[n] = BHTSne.sign_tsne(this.dY[n]) != BHTSne.sign_tsne(this.uY[n]) ? this.gains[n] + 0.2 : this.gains[n] * 0.8;
                    if (this.gains[n] < 0.01) {
                        this.gains[n] = 0.01;
                    }
                    // NOTE(review): Y is moved by the previous step's uY before uY is recomputed.
                    // The reference Barnes-Hut t-SNE recomputes uY first. Confirm this ordering
                    // matches BHTSne.updateGradient and is intended.
                    this.Y[n] = this.Y[n] + this.uY[n];
                    this.uY[n] = this.momentum * this.uY[n] - this.eta * this.gains[n] * this.dY[n];
                }
                return;
            }
            int mid = this.startIdx + range / 2;
            RecursiveGradientUpdater.invokeAll(
                    new RecursiveGradientUpdater(this.N, this.no_dims, this.Y, this.momentum, this.eta, this.dY, this.uY, this.gains, this.startIdx, mid, this.limit),
                    new RecursiveGradientUpdater(this.N, this.no_dims, this.Y, this.momentum, this.eta, this.dY, this.uY, this.gains, mid, this.endIdx, this.limit));
        }
    }

    /**
     * Task computing the non-edge forces of one row of the tree into {@code neg_f[row]}.
     *
     * <p>Returns the per-row contribution that {@link #computeGradient} sums into
     * {@code sum_Q}. The {@code limit} field is stored but not used by this task.
     */
    class ParallelGradientCalculator
    implements Callable<Double> {
        static final long serialVersionUID = 1L;
        int row = -1;
        int limit = 100;
        ParallelSPTree tree;
        double[][] neg_f;
        double theta;

        public ParallelGradientCalculator(ParallelSPTree tree, double[][] neg_f, double theta, int row, int ll) {
            this.limit = ll;
            this.row = row;
            this.tree = tree;
            this.neg_f = neg_f;
            this.theta = theta;
        }

        @Override
        public Double call() {
            return this.tree.computeNonEdgeForces(this.row, this.theta, this.neg_f[this.row], 0.0);
        }
    }

    /**
     * Fork/join alternative to {@link ParallelGradientCalculator}.
     *
     * <p>Computes the non-edge forces for the rows {@code [startRow, endRow)} and accumulates
     * into a shared {@code AtomicDouble}. It is not referenced elsewhere in this file. A
     * {@code limit} below 1 is treated as 1 so that the recursion always terminates.
     */
    class RecursiveGradientCalculator
    extends RecursiveAction {
        static final long serialVersionUID = 1L;
        int startRow = -1;
        int endRow = -1;
        int limit = 100;
        SPTree tree;
        double[][] neg_f;
        double theta;
        AtomicDouble sum_Q;

        public RecursiveGradientCalculator(SPTree tree, double[][] neg_f, double theta, AtomicDouble sum_Q, int startRow, int endRow, int ll) {
            this.limit = ll;
            this.startRow = startRow;
            this.endRow = endRow;
            this.tree = tree;
            this.neg_f = neg_f;
            this.theta = theta;
            this.sum_Q = sum_Q;
        }

        @Override
        protected void compute() {
            int range = this.endRow - this.startRow;
            if (range <= Math.max(1, this.limit)) {
                for (int row = this.startRow; row < this.endRow; ++row) {
                    this.tree.computeNonEdgeForces(row, this.theta, this.neg_f[row], this.sum_Q);
                }
                return;
            }
            int mid = this.startRow + range / 2;
            RecursiveGradientCalculator.invokeAll(
                    new RecursiveGradientCalculator(this.tree, this.neg_f, this.theta, this.sum_Q, this.startRow, mid, this.limit),
                    new RecursiveGradientCalculator(this.tree, this.neg_f, this.theta, this.sum_Q, mid, this.endRow, this.limit));
        }
    }
}

