/*
 * Decompiled with CFR 0.152.
 */
package com.jujutsu.tsne;

import com.jujutsu.tsne.PrincipalComponentAnalysis;
import com.jujutsu.tsne.TSne;
import com.jujutsu.tsne.barneshut.TSneConfiguration;
import com.jujutsu.utils.BlasOps;
import com.jujutsu.utils.MatrixOps;
import java.util.Arrays;
import org.jblas.DoubleMatrix;

/**
 * Exact (O(n^2)) t-SNE whose optimisation loop is expressed with jblas {@link DoubleMatrix}
 * operations.
 *
 * <p>Pipeline:
 * <ul>
 *   <li>optional PCA pre-reduction;</li>
 *   <li>a per-point binary search for the Gaussian precision that matches the requested
 *       perplexity ({@link #x2p});</li>
 *   <li>symmetrisation and normalisation of the affinities;</li>
 *   <li>gradient descent with momentum and per-element adaptive gains.</li>
 * </ul>
 * P is multiplied by 4 ("early exaggeration") and divided back down at the end of iteration 100.
 *
 * <p>Progress is printed to {@code System.out}. {@link #abort()} may be called from another thread
 * to stop the descent loop early. Nothing in this class resets that flag, so later calls to
 * {@link #tsne} on the same instance skip the descent loop entirely.
 */
public class BlasTSne implements TSne {
    /** Momentum used for the first {@link #MOMENTUM_SWITCH_ITER} iterations. */
    private static final double INITIAL_MOMENTUM = 0.5;
    /** Momentum used from iteration {@link #MOMENTUM_SWITCH_ITER} onwards. */
    private static final double FINAL_MOMENTUM = 0.8;
    private static final int MOMENTUM_SWITCH_ITER = 20;
    /** Step size (eta) applied to the gain-scaled gradient. */
    private static final double LEARNING_RATE = 500.0;
    /** Lower clamp for the adaptive per-element gains. */
    private static final double MIN_GAIN = 0.01;
    /** Additive gain increase where the update direction flips. */
    private static final double GAIN_INCREMENT = 0.2;
    /** Multiplicative gain decay where the update direction is kept. */
    private static final double GAIN_DECAY = 0.8;
    /** Factor applied to P at start-up and removed after {@link #STOP_EXAGGERATION_ITER}. */
    private static final double EARLY_EXAGGERATION = 4.0;
    private static final int STOP_EXAGGERATION_ITER = 100;
    /** Floor applied to P and Q so later divisions and logarithms stay finite. */
    private static final double PROBABILITY_FLOOR = 1.0E-12;
    /** Entropy tolerance used for the per-point perplexity search. */
    private static final double PERPLEXITY_TOLERANCE = 1.0E-5;
    private static final int MAX_BINARY_SEARCH_TRIES = 50;

    /** Instance used for {@code MatrixOps.transpose}, which is called as an instance method. */
    MatrixOps mo = new MatrixOps();
    /** Set by {@link #abort()}; polled once per gradient iteration. */
    protected volatile boolean abort = false;

    /**
     * Embeds the rows of {@code config.getXin()} into {@code config.getOutputDims()} dimensions.
     *
     * @param config supplies the input matrix (one point per row), the output and initial
     *     dimensionality, the perplexity, the iteration budget and whether to run PCA first
     * @return the embedding, one row per input row
     * @throws IllegalArgumentException if the input matrix is {@code null} or has no rows
     */
    @Override
    public double[][] tsne(TSneConfiguration config) {
        double[][] X = config.getXin();
        if (X == null || X.length == 0) {
            throw new IllegalArgumentException("Input matrix X must contain at least one row");
        }
        int noDims = config.getOutputDims();
        int initialDims = config.getInitialDims();
        double perplexity = config.getPerplexity();
        int maxIter = config.getMaxIter();
        boolean usePca = config.usePca();
        String implementationName = this.getClass().getSimpleName();
        System.out.println("X:Shape is = " + X.length + " x " + X[0].length);
        System.out.println("Running " + implementationName + ".");
        if (usePca && X[0].length > initialDims && initialDims > 0) {
            PrincipalComponentAnalysis pca = new PrincipalComponentAnalysis();
            X = pca.pca(X, initialDims);
            System.out.println("X:Shape after PCA is = " + X.length + " x " + X[0].length);
        }
        int n = X.length;
        DoubleMatrix Y = DoubleMatrix.randn(n, noDims);
        // iY is the previous update (the momentum term); gains are per-element step multipliers.
        DoubleMatrix iY = DoubleMatrix.zeros(n, noDims);
        DoubleMatrix gains = DoubleMatrix.ones(n, noDims);
        DoubleMatrix P = computeExaggeratedP(X, perplexity);
        System.out.println("Y:Shape is = " + Y.rows + " x " + Y.columns);
        // Loop-invariant index list used to zero the diagonal of num on every iteration.
        int[] diagonal = MatrixOps.range(n);
        for (int iter = 0; iter < maxIter && !this.abort; ++iter) {
            // distSqPlusOne_ij = ||y_i||^2 + ||y_j||^2 - 2 y_i.y_j + 1 = 1 + ||y_i - y_j||^2.
            DoubleMatrix sumY = BlasOps.square(Y).rowSums().transpose();
            DoubleMatrix distSqPlusOne = Y.mmul(Y.transpose()).mul(-2.0)
                    .addRowVector(sumY).transpose().addRowVector(sumY).add(1.0);
            // Low-dimensional kernel 1 / (1 + d^2), assuming scalarInverse is the element-wise
            // reciprocal; self-similarities are zeroed.
            DoubleMatrix num = BlasOps.scalarInverse(distSqPlusOne);
            BlasOps.assignAtIndex(num, diagonal, diagonal, 0.0);
            DoubleMatrix Q = num.div(num.sum()).max(PROBABILITY_FLOOR);
            // Gradient: dY = 4 (diag(rowSums(L)) - L) Y, with L = (P - Q) .* num.
            DoubleMatrix L = P.sub(Q).mul(num);
            DoubleMatrix dY = DoubleMatrix.diag(L.rowSums()).sub(L).mmul(Y).mul(4.0);
            double momentum = iter < MOMENTUM_SWITCH_ITER ? INITIAL_MOMENTUM : FINAL_MOMENTUM;
            gains = updateGains(gains, dY, iY);
            iY = iY.mul(momentum).sub(gains.mul(dY).mul(LEARNING_RATE));
            Y = Y.add(iY);
            // Re-centre the embedding so every column has zero mean.
            Y = Y.sub(BlasOps.tile(Y.columnMeans(), n, 1));
            if (iter % 100 == 0) {
                System.out.println("Iteration " + iter + ": error is " + klDivergence(P, Q));
            } else if (iter % 10 == 0) {
                System.out.println("Iteration " + iter);
            }
            if (iter == STOP_EXAGGERATION_ITER) {
                P = P.div(EARLY_EXAGGERATION);
            }
        }
        return Y.toArray2();
    }

    /**
     * Builds the optimisation target P: the conditional affinities from {@link #x2p} are
     * symmetrised (P + P^T), normalised to sum to 1, multiplied by {@link #EARLY_EXAGGERATION}
     * and finally floored at {@link #PROBABILITY_FLOOR} (in that order).
     */
    private DoubleMatrix computeExaggeratedP(double[][] X, double perplexity) {
        DoubleMatrix P = new DoubleMatrix(this.x2p(X, PERPLEXITY_TOLERANCE, perplexity).P);
        P = P.add(P.transpose());
        P = P.div(P.sum());
        P = P.mul(EARLY_EXAGGERATION);
        return P.max(PROBABILITY_FLOOR);
    }

    /**
     * Returns the adapted gains as a new matrix; {@code gains} itself is not modified.
     *
     * <p>Where {@code (dY > 0)} differs from {@code (iY > 0)} the gain grows by
     * {@link #GAIN_INCREMENT}. Where they agree it is multiplied by {@link #GAIN_DECAY}. The
     * result is clamped below at {@link #MIN_GAIN}. {@code BlasOps.abs(...)} is relied on to turn
     * each comparison mask into a 0/1 matrix (assumption; see BlasOps).
     */
    private static DoubleMatrix updateGains(DoubleMatrix gains, DoubleMatrix dY, DoubleMatrix iY) {
        DoubleMatrix directionFlipped = BlasOps.abs(MatrixOps.negate(
                MatrixOps.equal(BlasOps.biggerThan(dY, 0.0), BlasOps.biggerThan(iY, 0.0))));
        DoubleMatrix directionKept = BlasOps.abs(
                MatrixOps.equal(BlasOps.biggerThan(dY, 0.0), BlasOps.biggerThan(iY, 0.0)));
        DoubleMatrix updated = gains.add(GAIN_INCREMENT).mul(directionFlipped)
                .add(gains.mul(GAIN_DECAY).mul(directionKept));
        // The return value is ignored: assignAllLessThan is relied on to clamp 'updated' in place.
        BlasOps.assignAllLessThan(updated, MIN_GAIN, MIN_GAIN);
        return updated;
    }

    /**
     * Returns sum(P .* log(P ./ Q)), with NaN entries of the log term replaced by 0.
     *
     * <p>This is KL(P || Q) when P sums to 1. During early exaggeration P is scaled up, so the
     * reported value is inflated accordingly.
     */
    private static double klDivergence(DoubleMatrix P, DoubleMatrix Q) {
        DoubleMatrix logRatio = BlasOps.replaceNaN(BlasOps.log(P.div(Q)), 0.0);
        return P.mul(logRatio).sum();
    }

    /**
     * Computes a normalised Gaussian distribution over distances and its entropy.
     *
     * <p>With distances d and precision beta: P_j = exp(-beta d_j) / Z and
     * H = log Z + beta * sum_j d_j P_j. This equals the Shannon entropy -sum_j P_j log P_j in nats.
     *
     * @param D distances from one point to the others. {@link #x2p} passes a single row of squared
     *     distances (presumably shaped 1 x (n-1)).
     * @param beta precision applied to the distances
     * @return an R with {@code H} set to the entropy and {@code P} set to the normalised
     *     probabilities, in the same shape as {@code D}. Its {@code beta} field is not set here.
     */
    public TSne.R Hbeta(double[][] D, double beta) {
        DoubleMatrix Dd = new DoubleMatrix(D);
        // Shift by the smallest distance before exponentiating. P and H are mathematically
        // unchanged, because the factor exp(-beta * shift) cancels in the normalisation and the
        // entropy term. The nearest neighbour now contributes exp(0) = 1, so sumP can no longer
        // underflow to 0. Without the shift, large distances or a large beta would turn H and P
        // into NaN.
        double shift = Dd.length > 0 ? Dd.min() : 0.0;
        DoubleMatrix shifted = Dd.sub(shift);
        DoubleMatrix P = BlasOps.exp(shifted.mul(-beta));
        double sumP = P.sum();
        double H = Math.log(sumP) + beta * shifted.mul(P).sum() / sumP;
        P = P.div(sumP);
        TSne.R r = new TSne.R();
        r.H = H;
        r.P = P.toArray2();
        return r;
    }

    /**
     * Computes, for every row of X, Gaussian conditional affinities to all other rows. Each point's
     * precision beta is found by binary search so that its entropy H satisfies
     * |H - log(perplexity)| <= tol. The search gives up after 50 tries.
     *
     * <p>Progress and the mean of sqrt(1 / beta) are printed to {@code System.out}.
     *
     * @param X input points, one per row
     * @param tol entropy tolerance for the binary search
     * @param perplexity target perplexity
     * @return an R whose {@code P} is the n x n matrix of conditional probabilities (row i holds
     *     point i's distribution, diagonal 0) and whose {@code beta} holds the chosen precisions.
     *     {@code H} is not set.
     */
    public TSne.R x2p(double[][] X, double tol, double perplexity) {
        int n = X.length;
        // Squared distances D_ij = ||x_i||^2 + ||x_j||^2 - 2 x_i.x_j. This assumes
        // MatrixOps.sum(.., 1) yields per-row sums as an n x 1 column.
        double[][] sumX = MatrixOps.sum(MatrixOps.square(X), 1);
        double[][] times = MatrixOps.scalarMult(MatrixOps.times(X, this.mo.transpose(X)), -2.0);
        double[][] prodSum = MatrixOps.addColumnVector(this.mo.transpose(times), sumX);
        double[][] D = MatrixOps.addRowVector(prodSum, this.mo.transpose(sumX));
        double[][] P = MatrixOps.fillMatrix(n, n, 0.0);
        // One precision per point, all starting at 1 (O(n) storage, no n x n scratch matrix).
        double[] beta = new double[n];
        Arrays.fill(beta, 1.0);
        double logU = Math.log(perplexity);
        System.out.println("Starting x2p...");
        for (int i = 0; i < n; ++i) {
            if (i % 500 == 0) {
                System.out.println("Computing P-values for point " + i + " of " + n + "...");
            }
            // Every index except i: the self-affinity is excluded and stays 0 in P.
            int[] others = MatrixOps.concatenate(MatrixOps.range(0, i), MatrixOps.range(i + 1, n));
            double[][] Di = MatrixOps.getValuesFromRow(D, i, others);
            double betamin = Double.NEGATIVE_INFINITY;
            double betamax = Double.POSITIVE_INFINITY;
            TSne.R hbeta = this.Hbeta(Di, beta[i]);
            double Hdiff = hbeta.H - logU;
            double[][] thisP = hbeta.P;
            for (int tries = 0; Math.abs(Hdiff) > tol && tries < MAX_BINARY_SEARCH_TRIES; ++tries) {
                if (Hdiff > 0.0) {
                    // Entropy too high: increase the precision (narrower kernel).
                    betamin = beta[i];
                    beta[i] = Double.isInfinite(betamax) ? beta[i] * 2.0 : (beta[i] + betamax) / 2.0;
                } else {
                    // Entropy too low: decrease the precision (wider kernel).
                    betamax = beta[i];
                    beta[i] = Double.isInfinite(betamin) ? beta[i] / 2.0 : (beta[i] + betamin) / 2.0;
                }
                hbeta = this.Hbeta(Di, beta[i]);
                Hdiff = hbeta.H - logU;
                thisP = hbeta.P;
            }
            MatrixOps.assignValuesToRow(P, i, others, thisP[0]);
        }
        TSne.R r = new TSne.R();
        r.P = P;
        r.beta = beta;
        double sigma = MatrixOps.mean(MatrixOps.sqrt(MatrixOps.scalarInverse(beta)));
        System.out.println("Mean value of sigma: " + sigma);
        return r;
    }

    /**
     * Requests that any running {@link #tsne} call stop after its current gradient iteration.
     * The flag is volatile, so a write from another thread is visible to the loop. This class
     * never clears it.
     */
    @Override
    public void abort() {
        this.abort = true;
    }
}

