/*
 * Decompiled with CFR 0.152.
 */
package com.jujutsu.tsne;

import com.jujutsu.tsne.FastTSne;
import com.jujutsu.tsne.PrincipalComponentAnalysis;
import com.jujutsu.tsne.barneshut.TSneConfiguration;
import com.jujutsu.utils.EjmlOps;
import com.jujutsu.utils.MatrixOps;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.data.RowD1Matrix64F;
import org.ejml.ops.CommonOps;

/**
 * Variant of {@link FastTSne} that keeps memory usage down by allocating every working matrix
 * once, before the gradient-descent loop, and reusing it on each iteration.
 *
 * <p>Apart from the input data, the loop holds four {@code n x n} matrices (P, Q, the pairwise
 * kernel and one scratch matrix) plus a handful of {@code n x outputDims} buffers.
 */
public class MemOptimizedTSne extends FastTSne {
    /**
     * Embeds the rows of {@code config.getXin()} into {@code config.getOutputDims()} dimensions
     * with t-SNE.
     *
     * <p>When {@code config.usePca()} is set and the input has more than
     * {@code config.getInitialDims()} (which must be &gt; 0) columns, the data is first reduced
     * with {@link PrincipalComponentAnalysis}. P is symmetrised, normalised and multiplied by 4
     * until iteration 100. The optimiser uses momentum 0.5 for the first 20 iterations and 0.8
     * afterwards, a learning rate of 500, and per-element gains floored at 0.01. Progress is
     * printed to standard output every 10 iterations, together with the KL cost every 50.
     *
     * @param config run configuration; its input matrix must contain at least one row
     * @return the embedding, one row per (possibly PCA-reduced) input row
     */
    @Override
    public double[][] tsne(TSneConfiguration config) {
        double[][] X = config.getXin();
        int noDims = config.getOutputDims();
        int initialDims = config.getInitialDims();
        double perplexity = config.getPerplexity();
        int maxIter = config.getMaxIter();
        boolean usePca = config.usePca();
        System.out.println("X:Shape is = " + X.length + " x " + X[0].length);
        System.out.println("Running " + getClass().getSimpleName() + ".");
        long start = System.currentTimeMillis();
        if (usePca && X[0].length > initialDims && initialDims > 0) {
            PrincipalComponentAnalysis pca = new PrincipalComponentAnalysis();
            X = pca.pca(X, initialDims);
            System.out.println("X:Shape after PCA is = " + X.length + " x " + X[0].length);
            System.out.println(MatrixOps.doubleArrayToPrintString(X, 10, 10));
        }

        final int n = X.length;
        final double initialMomentum = 0.5;
        final double finalMomentum = 0.8;
        final int eta = 500;
        final double minGain = 0.01;

        // Optimisation state. DenseMatrix64F(rows, cols) is zero-filled, so dY and iY start at 0.
        DenseMatrix64F Y = new DenseMatrix64F(MatrixOps.rnorm(n, noDims));
        DenseMatrix64F dY = new DenseMatrix64F(n, noDims);
        DenseMatrix64F iY = new DenseMatrix64F(n, noDims);
        DenseMatrix64F gains = new DenseMatrix64F(MatrixOps.fillMatrix(n, noDims, 1.0));

        // Per-iteration buffers: allocated once here and fully overwritten on every pass.
        DenseMatrix64F btNeg = new DenseMatrix64F(n, noDims);
        DenseMatrix64F bt = new DenseMatrix64F(n, noDims);
        DenseMatrix64F gainsSmall = new DenseMatrix64F(n, noDims);
        DenseMatrix64F gainsBig = new DenseMatrix64F(n, noDims);
        DenseMatrix64F gainsdY = new DenseMatrix64F(n, noDims);
        DenseMatrix64F sqed = new DenseMatrix64F(Y.numRows, Y.numCols);
        DenseMatrix64F sumY = new DenseMatrix64F(1, Y.numRows);
        DenseMatrix64F rowSum = new DenseMatrix64F(n, 1);

        // Symmetrise and normalise P, multiply it by 4 (undone after iteration 100) and floor
        // it at 1e-12. "work" is the single n x n scratch matrix reused throughout.
        DenseMatrix64F P = new DenseMatrix64F(this.x2p(X, 1.0E-5, perplexity).P);
        DenseMatrix64F work = new DenseMatrix64F(P.numRows, P.numCols);
        CommonOps.transpose(P, work);
        CommonOps.addEquals(P, work);
        CommonOps.divide(P, CommonOps.elementSum(P));
        EjmlOps.replaceNaN(P, Double.MIN_VALUE);
        CommonOps.scale(4.0, P);
        EjmlOps.maximize(P, 1.0E-12);
        System.out.println("Using perplexity: " + perplexity);
        System.out.println("Y:Shape is = " + Y.getNumRows() + " x " + Y.getNumCols());

        // num: pairwise squared distances, then the kernel 1 / (1 + d_ij).
        DenseMatrix64F num = new DenseMatrix64F(Y.numRows, Y.numRows);
        DenseMatrix64F Q = new DenseMatrix64F(P.numRows, P.numCols);

        for (int iter = 0; iter < maxIter; ++iter) {
            // d_ij = |y_i|^2 + |y_j|^2 - 2 y_i . y_j. multTransB OVERWRITES num: the
            // accumulating multAddTransB would add in the previous iteration's kernel values,
            // because num is reused across iterations.
            CommonOps.elementPower(Y, 2.0, sqed);
            CommonOps.sumRows(sqed, sumY);
            CommonOps.multTransB(-2.0, Y, Y, num);
            EjmlOps.addRowVector(num, sumY);
            CommonOps.transpose(num);
            EjmlOps.addRowVector(num, sumY);

            // Kernel 1 / (1 + d_ij) with a zeroed diagonal, normalised into Q and floored.
            CommonOps.add(num, 1.0);
            CommonOps.divide(1.0, num);
            EjmlOps.assignAtIndex(num, MatrixOps.range(n), MatrixOps.range(n), 0.0);
            CommonOps.divide(num, CommonOps.elementSum(num), Q);
            EjmlOps.maximize(Q, 1.0E-12);

            // Gradient dY = 4 * (diag(rowSum(W)) - W) * Y with W = (P - Q) .* num. The
            // diagonal term is added in place instead of through a separate dense n x n matrix.
            CommonOps.subtract(P, Q, work);
            CommonOps.elementMult(work, num);
            CommonOps.sumRows(work, rowSum);
            CommonOps.scale(-1.0, work);
            for (int i = 0; i < n; ++i) {
                work.set(i, i, work.get(i, i) + rowSum.get(i, 0));
            }
            CommonOps.mult(work, Y, dY);
            CommonOps.scale(4.0, dY);

            double momentum = iter < 20 ? initialMomentum : finalMomentum;

            // Adaptive gains: gains = (gains + 0.2) .* btNeg + (gains * 0.8) .* bt, floored at
            // minGain. NOTE(review): bt/btNeg are presumably the 0/1 masks where (dY > 0) does /
            // does not equal (iY > 0), as built by the MatrixOps/EjmlOps helpers — confirm.
            boolean[][] sameSign = MatrixOps.equal(EjmlOps.biggerThan(dY, 0.0), EjmlOps.biggerThan(iY, 0.0));
            EjmlOps.setData(btNeg, MatrixOps.abs(MatrixOps.negate(sameSign)));
            EjmlOps.setData(bt, MatrixOps.abs(sameSign));
            CommonOps.add(gains, 0.2, gainsSmall);
            CommonOps.scale(0.8, gains, gainsBig);
            CommonOps.elementMult(gainsSmall, btNeg);
            CommonOps.elementMult(gainsBig, bt);
            CommonOps.add(gainsSmall, gainsBig, gains);
            EjmlOps.assignAllLessThan(gains, minGain, minGain);

            // iY = momentum * iY - eta * (gains .* dY); Y += iY; then re-centre Y on its column means.
            CommonOps.scale(momentum, iY);
            CommonOps.elementMult(gains, dY, gainsdY);
            CommonOps.scale(eta, gainsdY);
            CommonOps.subtractEquals(iY, gainsdY);
            CommonOps.addEquals(Y, iY);
            DenseMatrix64F colMeanY = EjmlOps.colMean(Y, 0);
            CommonOps.subtractEquals(Y, EjmlOps.tile(colMeanY, n, 1));

            if (iter % 50 == 0) {
                // Cost sum(P .* log(P ./ Q)) using P as currently scaled. It is computed in the
                // work matrix, which is rebuilt from scratch at the start of the next gradient step.
                CommonOps.elementDiv(P, Q, work);
                CommonOps.elementLog(work, work);
                EjmlOps.replaceNaN(work, Double.MIN_VALUE);
                CommonOps.elementMult(work, P);
                EjmlOps.replaceNaN(work, Double.MIN_VALUE);
                double C = CommonOps.elementSum(work);
                long end = System.currentTimeMillis();
                // NOTE: start is reset at every 10-iteration checkpoint, so this elapsed time
                // covers the iterations since the previous checkpoint, not 50 iterations.
                System.out.printf("Iteration %d: error is %f (50 iterations in %4.2f seconds)\n", iter, C, (double)(end - start) / 1000.0);
                if (C < 0.0) {
                    System.err.println("Warning: Error is negative, this is usually a very bad sign!");
                }
                start = System.currentTimeMillis();
            } else if (iter % 10 == 0) {
                long end = System.currentTimeMillis();
                System.out.printf("Iteration %d: (10 iterations in %4.2f seconds)\n", iter, (double)(end - start) / 1000.0);
                start = System.currentTimeMillis();
            }

            if (iter == 100) {
                // Undo the factor-4 scaling applied to P before the loop.
                CommonOps.divide(P, 4.0);
            }
        }
        return MatrixOps.extractDoubleArray(Y);
    }
}

