/*
 * Decompiled with CFR 0.152.
 */
package com.jujutsu.tsne;

import com.jujutsu.tsne.PrincipalComponentAnalysis;
import com.jujutsu.tsne.TSne;
import com.jujutsu.tsne.barneshut.TSneConfiguration;
import com.jujutsu.utils.MatrixOps;

/**
 * Plain dense-matrix t-SNE implementation of the {@link TSne} interface.
 *
 * <p>All linear algebra is delegated to {@link MatrixOps}. Progress is reported
 * on {@code System.out}. A running {@link #tsne(TSneConfiguration)} call can be
 * stopped early from another thread through {@link #abort()}; note that this
 * class never clears the abort flag again, so once it has been set, later
 * {@code tsne} calls on the same instance skip the iteration loop.
 */
public class SimpleTSne
implements TSne {
    /** Tolerance on {@code |H - log(perplexity)|} handed to {@link #x2p(double[][], double, double)}. */
    private static final double PERPLEXITY_TOLERANCE = 1.0E-5;
    /** Upper bound on the number of beta adjustments per point in {@link #x2p(double[][], double, double)}. */
    private static final int MAX_BETA_TRIES = 50;
    /** Factor P is multiplied by before the first iteration and divided by again later. */
    private static final double EARLY_EXAGGERATION = 4.0;
    /** Iteration index after which the exaggeration factor is divided out of P again. */
    private static final int STOP_EXAGGERATION_ITER = 100;
    /** Lower bound applied element-wise to P and Q. */
    private static final double MIN_PROBABILITY = 1.0E-12;
    /** Momentum used while {@code iter < MOMENTUM_SWITCH_ITER}. */
    private static final double INITIAL_MOMENTUM = 0.5;
    /** Momentum used from iteration {@code MOMENTUM_SWITCH_ITER} onwards. */
    private static final double FINAL_MOMENTUM = 0.8;
    /** First iteration index that uses {@link #FINAL_MOMENTUM}. */
    private static final int MOMENTUM_SWITCH_ITER = 20;
    /** Step size applied to the gain-scaled gradient. */
    private static final int ETA = 500;
    /** Lower bound applied to every entry of the gains matrix. */
    private static final double MIN_GAIN = 0.01;

    MatrixOps mo = new MatrixOps();
    /** Set by {@link #abort()}; checked once per iteration of the optimisation loop. */
    protected volatile boolean abort = false;

    /**
     * Runs t-SNE on the input matrix supplied by {@code config}.
     *
     * <p>The input matrix must contain at least one row, because its first row
     * is inspected to report and test the input dimensionality.
     *
     * @param config supplies the input matrix, output dimensionality, initial
     *               (PCA) dimensionality, perplexity, iteration limit and the
     *               PCA switch
     * @return the embedding {@code Y}, with one row per input row and
     *         {@code config.getOutputDims()} columns
     */
    @Override
    public double[][] tsne(TSneConfiguration config) {
        double[][] X = config.getXin();
        int noDims = config.getOutputDims();
        int initialDims = config.getInitialDims();
        double perplexity = config.getPerplexity();
        int maxIter = config.getMaxIter();
        boolean usePca = config.usePca();
        String implementationName = this.getClass().getSimpleName();
        System.out.println("X:Shape is = " + X.length + " x " + X[0].length);
        System.out.println("Running " + implementationName + ".");

        // Optional PCA step: only when requested and when it would actually
        // reduce the number of columns.
        if (usePca && X[0].length > initialDims && initialDims > 0) {
            PrincipalComponentAnalysis pca = new PrincipalComponentAnalysis();
            X = pca.pca(X, initialDims);
            System.out.println("X:Shape after PCA is = " + X.length + " x " + X[0].length);
        }

        int n = X.length;
        double[][] Y = MatrixOps.rnorm(n, noDims);
        // iY holds the previous update step, gains the per-entry step multipliers.
        double[][] iY = MatrixOps.fillMatrix(n, noDims, 0.0);
        double[][] gains = MatrixOps.fillMatrix(n, noDims, 1.0);

        // Input affinities: add the transpose, normalise by the total sum,
        // exaggerate, then clamp from below.
        double[][] P = this.x2p(X, PERPLEXITY_TOLERANCE, perplexity).P;
        P = MatrixOps.plus(P, this.mo.transpose(P));
        P = MatrixOps.scalarDivide(P, MatrixOps.sum(P));
        P = MatrixOps.scalarMult(P, EARLY_EXAGGERATION);
        P = MatrixOps.maximum(P, MIN_PROBABILITY);
        System.out.println("Y:Shape is = " + Y.length + " x " + Y[0].length);

        for (int iter = 0; iter < maxIter && !this.abort; ++iter) {
            // num = 1 / (1 + (sum_Y - 2 * Y * Y^T + sum_Y^T)); presumably the
            // pairwise squared distances of the rows of Y - depends on the
            // MatrixOps row-vector helpers, which are not visible here.
            double[][] sum_Y = this.mo.transpose(MatrixOps.sum(MatrixOps.square(Y), 1));
            double[][] minusTwoGram = MatrixOps.scalarMult(MatrixOps.times(Y, this.mo.transpose(Y)), -2.0);
            double[][] sqDist = MatrixOps.addRowVector(this.mo.transpose(MatrixOps.addRowVector(minusTwoGram, sum_Y)), sum_Y);
            double[][] num = MatrixOps.scalarInverse(MatrixOps.scalarPlus(sqDist, 1.0));
            // Presumably zeroes the diagonal (index pairs (k, k) for k in 0..n-1).
            MatrixOps.assignAtIndex(num, MatrixOps.range(n), MatrixOps.range(n), 0.0);
            double[][] Q = MatrixOps.scalarDivide(num, MatrixOps.sum(num));
            Q = MatrixOps.maximum(Q, MIN_PROBABILITY);

            // Gradient: dY = 4 * (diag(rowSums(L)) - L) * Y with L = (P - Q) .* num.
            double[][] L = this.mo.scalarMultiply(this.mo.minus(P, Q), num);
            double[][] dY = MatrixOps.scalarMult(MatrixOps.times(this.mo.minus(MatrixOps.diag(MatrixOps.sum(L, 1)), L), Y), 4.0);

            double momentum = iter < MOMENTUM_SWITCH_ITER ? INITIAL_MOMENTUM : FINAL_MOMENTUM;

            // Gains update: where (dY > 0) and (iY > 0) disagree the gain is
            // increased by 0.2, where they agree it is multiplied by 0.8.
            double[][] increased = this.mo.scalarMultiply(MatrixOps.scalarPlus(gains, 0.2), MatrixOps.abs(MatrixOps.negate(MatrixOps.equal(MatrixOps.biggerThan(dY, 0.0), MatrixOps.biggerThan(iY, 0.0)))));
            double[][] decreased = this.mo.scalarMultiply(MatrixOps.scalarMult(gains, 0.8), MatrixOps.abs(MatrixOps.equal(MatrixOps.biggerThan(dY, 0.0), MatrixOps.biggerThan(iY, 0.0))));
            gains = MatrixOps.plus(increased, decreased);
            MatrixOps.assignAllLessThan(gains, MIN_GAIN, MIN_GAIN);

            // Momentum step, then subtract tile(mean(Y, 0)) from Y
            // (presumably re-centres every column on zero).
            iY = this.mo.minus(MatrixOps.scalarMult(iY, momentum), MatrixOps.scalarMult(this.mo.scalarMultiply(gains, dY), ETA));
            Y = MatrixOps.plus(Y, iY);
            Y = this.mo.minus(Y, MatrixOps.tile(MatrixOps.mean(Y, 0), n, 1));

            if (iter % 100 == 0) {
                // Reported error is sum(P .* log(P ./ Q)), with NaN terms replaced by 0.
                double[][] logdivide = MatrixOps.log(MatrixOps.scalarDivide(P, Q));
                logdivide = MatrixOps.replaceNaN(logdivide, 0.0);
                double C = MatrixOps.sum(this.mo.scalarMultiply(P, logdivide));
                System.out.println("Iteration " + (iter + 1) + ": error is " + C);
            } else if ((iter + 1) % 10 == 0) {
                System.out.println("Iteration " + (iter + 1));
            }

            // Undo the initial exaggeration of P exactly once.
            if (iter == STOP_EXAGGERATION_ITER) {
                P = MatrixOps.scalarDivide(P, EARLY_EXAGGERATION);
            }
        }
        return Y;
    }

    /**
     * Evaluates {@code P = exp(-D * beta)} and the associated quantity
     * {@code H = log(sum(P)) + beta * sum(D .* P) / sum(P)}.
     *
     * @param D    distance values to evaluate
     * @param beta scale factor applied to {@code D} before exponentiation
     * @return a result whose {@code H} field holds the value above and whose
     *         {@code P} field holds {@code exp(-D * beta)} divided by its sum
     */
    public TSne.R Hbeta(double[][] D, double beta) {
        double[][] P = MatrixOps.exp(MatrixOps.scalarMult(MatrixOps.scalarMult(D, beta), -1.0));
        double sumP = MatrixOps.sum(P);
        double H = Math.log(sumP) + beta * MatrixOps.sum(this.mo.scalarMultiply(D, P)) / sumP;
        P = MatrixOps.scalarDivide(P, sumP);
        TSne.R r = new TSne.R();
        r.H = H;
        r.P = P;
        return r;
    }

    /**
     * Builds the n x n matrix P row by row. For every row {@code i} the
     * per-point {@code beta[i]} is adjusted until {@link #Hbeta(double[][], double)}
     * of that row's distances (column {@code i} excluded) is within {@code tol}
     * of {@code log(perplexity)}, or {@code MAX_BETA_TRIES} adjustments were made.
     *
     * <p>Prints progress and the mean of {@code sqrt(1 / beta)} to {@code System.out}.
     *
     * @param X          input matrix, one point per row
     * @param tol        accepted absolute difference between {@code H} and
     *                   {@code log(perplexity)}
     * @param perplexity target perplexity
     * @return a result whose {@code P} field is the n x n matrix (entries
     *         {@code (i, i)} are left at 0) and whose {@code beta} field holds
     *         the final per-point beta values
     */
    public TSne.R x2p(double[][] X, double tol, double perplexity) {
        int n = X.length;
        // D = sum_X - 2 * X * X^T + sum_X^T; presumably the pairwise squared
        // distances between rows of X (relies on the MatrixOps vector helpers).
        double[][] sum_X = MatrixOps.sum(MatrixOps.square(X), 1);
        double[][] times = MatrixOps.scalarMult(MatrixOps.times(X, this.mo.transpose(X)), -2.0);
        double[][] prodSum = MatrixOps.addColumnVector(this.mo.transpose(times), sum_X);
        double[][] D = MatrixOps.addRowVector(prodSum, this.mo.transpose(sum_X));
        double[][] P = MatrixOps.fillMatrix(n, n, 0.0);
        // One beta per point, all starting at 1.0. Allocated directly as a
        // vector instead of taking the first row of a throw-away n x n matrix.
        double[] beta = new double[n];
        java.util.Arrays.fill(beta, 1.0);
        double logU = Math.log(perplexity);
        System.out.println("Starting x2p...");
        for (int i = 0; i < n; ++i) {
            if (i % 500 == 0) {
                System.out.println("Computing P-values for point " + i + " of " + n + "...");
            }
            // Search bracket for beta[i]; infinite means "no bound found yet".
            double betamin = Double.NEGATIVE_INFINITY;
            double betamax = Double.POSITIVE_INFINITY;
            // Row i of D restricted to the columns 0..i-1 and i+1..n-1.
            double[][] Di = MatrixOps.getValuesFromRow(D, i, MatrixOps.concatenate(MatrixOps.range(0, i), MatrixOps.range(i + 1, n)));
            TSne.R hbeta = this.Hbeta(Di, beta[i]);
            double H = hbeta.H;
            double[][] thisP = hbeta.P;
            double Hdiff = H - logU;
            for (int tries = 0; Math.abs(Hdiff) > tol && tries < MAX_BETA_TRIES; ++tries) {
                if (Hdiff > 0.0) {
                    // H too large: raise beta. Double it while there is no
                    // upper bound yet, otherwise move halfway towards it.
                    betamin = beta[i];
                    beta[i] = Double.isInfinite(betamax) ? beta[i] * 2.0 : (beta[i] + betamax) / 2.0;
                } else {
                    // H too small: lower beta. Halve it while there is no
                    // lower bound yet, otherwise move halfway towards it.
                    betamax = beta[i];
                    beta[i] = Double.isInfinite(betamin) ? beta[i] / 2.0 : (beta[i] + betamin) / 2.0;
                }
                hbeta = this.Hbeta(Di, beta[i]);
                H = hbeta.H;
                thisP = hbeta.P;
                Hdiff = H - logU;
            }
            // Write the row back into the same columns it was read from.
            MatrixOps.assignValuesToRow(P, i, MatrixOps.concatenate(MatrixOps.range(0, i), MatrixOps.range(i + 1, n)), thisP[0]);
        }
        TSne.R r = new TSne.R();
        r.P = P;
        r.beta = beta;
        double sigma = MatrixOps.mean(MatrixOps.sqrt(MatrixOps.scalarInverse(beta)));
        System.out.println("Mean value of sigma: " + sigma);
        return r;
    }

    /**
     * Requests that a running {@link #tsne(TSneConfiguration)} loop stop before
     * its next iteration. The flag is not reset by this class.
     */
    @Override
    public void abort() {
        this.abort = true;
    }
}

