/*
 * Decompiled with CFR 0.152.
 */
package com.jujutsu.tsne;

import com.jujutsu.tsne.PrincipalComponentAnalysis;
import com.jujutsu.tsne.TSne;
import com.jujutsu.tsne.barneshut.TSneConfiguration;
import com.jujutsu.utils.EjmlOps;
import com.jujutsu.utils.MatrixOps;
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Arrays;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Exact (O(n^2)) t-SNE whose per-iteration matrix algebra is done with EJML
 * {@link DenseMatrix64F} operations.
 *
 * <p>Input affinities are computed once by {@link #x2p}. The embedding is then refined by
 * gradient descent with momentum and per-element adaptive gains. Every pairwise quantity is held
 * as a dense {@code n x n} matrix, so memory use grows quadratically with the number of points.
 *
 * <p>{@link #abort()} sets a volatile flag that is checked once per optimisation iteration. The
 * flag is never cleared, so on an aborted instance later {@code tsne} calls skip the
 * optimisation loop entirely.
 */
public class FastTSne implements TSne {
    MatrixOps mo = new MatrixOps();
    protected volatile boolean abort = false;

    /**
     * Reads a {@code rows x columns} matrix of doubles stored row by row in a binary file.
     *
     * <p>Each value is read with {@link DataInputStream#readDouble()}, i.e. as an 8-byte
     * big-endian IEEE 754 double; the file has no header.
     *
     * @param rows    number of rows to read
     * @param columns number of values per row
     * @param fn      path of the file to read
     * @return the matrix that was read
     * @throws FileNotFoundException if the file cannot be opened
     * @throws IOException if reading fails, including {@link java.io.EOFException} when the file
     *         holds fewer than {@code rows * columns} doubles
     */
    public static double[][] readBinaryDoubleMatrix(int rows, int columns, String fn) throws FileNotFoundException, IOException {
        double[][] matrix = new double[rows][columns];
        File matrixFile = new File(fn);
        try (DataInputStream dis = new DataInputStream(
                new BufferedInputStream(new FileInputStream(matrixFile.getAbsolutePath())))) {
            for (double[] row : matrix) {
                for (int j = 0; j < columns; ++j) {
                    row[j] = dis.readDouble();
                }
            }
        }
        return matrix;
    }

    /**
     * Computes a t-SNE embedding of {@code config.getXin()}.
     *
     * <p>If {@code config.usePca()} is set, {@code config.getInitialDims()} is positive and the
     * data has more columns than that, the input is first reduced with PCA. Affinities are then
     * computed with {@link #x2p}, symmetrised, normalised and multiplied by 4. The multiplication
     * is undone at the end of iteration 100. After that, {@code config.getMaxIter()} iterations of
     * gradient descent with momentum and adaptive gains are run. Progress is printed to standard
     * output every 10 iterations, and the cost {@code sum(P .* log(P ./ Q))} is printed every 100.
     *
     * @param config input data and t-SNE hyper-parameters
     * @return the embedding, one row per input row and {@code config.getOutputDims()} columns
     * @throws IllegalArgumentException if the input matrix is {@code null} or has no rows
     */
    @Override
    public double[][] tsne(TSneConfiguration config) {
        double[][] X = config.getXin();
        int no_dims = config.getOutputDims();
        int initial_dims = config.getInitialDims();
        double perplexity = config.getPerplexity();
        int max_iter = config.getMaxIter();
        boolean use_pca = config.usePca();
        if (X == null || X.length == 0) {
            throw new IllegalArgumentException("Input matrix X must contain at least one row");
        }
        String implementationName = this.getClass().getSimpleName();
        System.out.println("X:Shape is = " + X.length + " x " + X[0].length);
        System.out.println("Running " + implementationName + ".");
        if (use_pca && X[0].length > initial_dims && initial_dims > 0) {
            PrincipalComponentAnalysis pca = new PrincipalComponentAnalysis();
            X = pca.pca(X, initial_dims);
            System.out.println("X:Shape after PCA is = " + X.length + " x " + X[0].length);
        }
        int n = X.length;
        double initial_momentum = 0.5;
        double final_momentum = 0.8;
        double eta = 500.0;
        double min_gain = 0.01;

        // Random initial embedding plus optimiser state. DenseMatrix64F(rows, cols) is
        // zero-filled, which is the required starting value for dY and iY.
        DenseMatrix64F Y = new DenseMatrix64F(MatrixOps.rnorm(n, no_dims));
        DenseMatrix64F dY = new DenseMatrix64F(n, no_dims);
        DenseMatrix64F iY = new DenseMatrix64F(n, no_dims);
        DenseMatrix64F gains = new DenseMatrix64F(MatrixOps.fillMatrix(n, no_dims, 1.0));

        // Symmetric input affinities: P = (P + P^T) / sum(P + P^T), exaggerated by 4.
        DenseMatrix64F P = new DenseMatrix64F(this.x2p(X, 1.0E-5, perplexity).P);
        DenseMatrix64F Ptr = new DenseMatrix64F(P.numRows, P.numCols);
        CommonOps.transpose(P, Ptr);
        CommonOps.addEquals(P, Ptr);
        CommonOps.divide(P, CommonOps.elementSum(P));
        EjmlOps.replaceNaN(P, Double.MIN_VALUE);
        CommonOps.scale(4.0, P);
        // NOTE(review): presumably an element-wise lower bound (like numpy.maximum) - confirm in EjmlOps.
        EjmlOps.maximize(P, 1.0E-12);
        System.out.println("Y:Shape is = " + Y.getNumRows() + " x " + Y.getNumCols());

        // Workspaces allocated once and reused by every iteration of the loop below.
        DenseMatrix64F sqed = new DenseMatrix64F(Y.numRows, Y.numCols);
        DenseMatrix64F sum_Y = new DenseMatrix64F(1, Y.numRows);
        DenseMatrix64F Ysqlmul = new DenseMatrix64F(Y.numRows, Y.numRows);
        DenseMatrix64F num = new DenseMatrix64F(Y.numRows, Y.numRows);
        DenseMatrix64F Q = new DenseMatrix64F(P.numRows, P.numCols);
        DenseMatrix64F L = new DenseMatrix64F(P.numRows, P.numCols);
        DenseMatrix64F diag = new DenseMatrix64F(P.numRows, P.numCols);
        DenseMatrix64F rowsum = new DenseMatrix64F(P.numRows, 1);
        double[] rsum = new double[P.numRows];
        DenseMatrix64F logdivide = new DenseMatrix64F(P.numRows, P.numCols);
        DenseMatrix64F btNeg = new DenseMatrix64F(n, no_dims);
        DenseMatrix64F bt = new DenseMatrix64F(n, no_dims);
        DenseMatrix64F gainsSmall = new DenseMatrix64F(n, no_dims);
        DenseMatrix64F gainsBig = new DenseMatrix64F(n, no_dims);
        DenseMatrix64F gainsdY = new DenseMatrix64F(n, no_dims);

        long start = System.currentTimeMillis();
        for (int iter = 0; iter < max_iter && !this.abort; ++iter) {
            // Student-t kernel num_ij = 1 / (1 + ||y_i - y_j||^2) built from
            // ||y_i||^2 + ||y_j||^2 - 2 y_i.y_j.
            CommonOps.elementPower((D1Matrix64F)Y, 2.0, (D1Matrix64F)sqed);
            CommonOps.sumRows(sqed, sum_Y);
            // Ysqlmul still holds the previous iteration's kernel, so the product must
            // overwrite it (multTransB) rather than accumulate into it (multAddTransB).
            CommonOps.multTransB(-2.0, Y, Y, Ysqlmul);
            EjmlOps.addRowVector(Ysqlmul, sum_Y);
            CommonOps.transpose(Ysqlmul);
            EjmlOps.addRowVector(Ysqlmul, sum_Y);
            CommonOps.add(Ysqlmul, 1.0);
            CommonOps.divide(1.0, Ysqlmul);
            num.set(Ysqlmul);
            // NOTE(review): presumably zeroes the diagonal (numpy num[range(n), range(n)] = 0) - confirm.
            EjmlOps.assignAtIndex(num, MatrixOps.range(n), MatrixOps.range(n), 0.0);

            // Output affinities Q = num / sum(num), bounded below via EjmlOps.maximize.
            CommonOps.divide(num, CommonOps.elementSum(num), (D1Matrix64F)Q);
            EjmlOps.maximize(Q, 1.0E-12);

            // Gradient dY = 4 * (diag(rowsums of M) - M) * Y with M = (P - Q) .* num;
            // setDiag presumably writes rsum onto the diagonal of the otherwise-zero diag matrix.
            CommonOps.subtract((D1Matrix64F)P, (D1Matrix64F)Q, (D1Matrix64F)L);
            CommonOps.elementMult(L, num);
            CommonOps.sumRows(L, rowsum);
            for (int i = 0; i < rsum.length; ++i) {
                rsum[i] = rowsum.get(i, 0);
            }
            EjmlOps.setDiag(diag, rsum);
            CommonOps.subtract((D1Matrix64F)diag, (D1Matrix64F)L, (D1Matrix64F)L);
            CommonOps.mult(L, Y, dY);
            CommonOps.scale(4.0, dY);

            // Adaptive gains: +0.2 where (dY > 0) != (iY > 0), *0.8 where they agree,
            // then bounded below by min_gain. bt/btNeg presumably hold 0/1 masks.
            double momentum = iter < 20 ? initial_momentum : final_momentum;
            boolean[][] boolMtrx = MatrixOps.equal(EjmlOps.biggerThan(dY, 0.0), EjmlOps.biggerThan(iY, 0.0));
            EjmlOps.setData(btNeg, MatrixOps.abs(MatrixOps.negate(boolMtrx)));
            EjmlOps.setData(bt, MatrixOps.abs(boolMtrx));
            gainsSmall.set(gains);
            gainsBig.set(gains);
            CommonOps.add(gainsSmall, 0.2);
            CommonOps.scale(0.8, gainsBig);
            CommonOps.elementMult(gainsSmall, btNeg);
            CommonOps.elementMult(gainsBig, bt);
            CommonOps.add((D1Matrix64F)gainsSmall, gainsBig, (D1Matrix64F)gains);
            EjmlOps.assignAllLessThan(gains, min_gain, min_gain);

            // Momentum step: iY = momentum * iY - eta * (gains .* dY); Y += iY.
            CommonOps.scale(momentum, iY);
            CommonOps.elementMult(gains, dY, gainsdY);
            CommonOps.scale(eta, gainsdY);
            CommonOps.subtractEquals(iY, gainsdY);
            CommonOps.addEquals(Y, iY);

            // Subtract the column means from Y so the embedding stays centred.
            DenseMatrix64F colMeanY = EjmlOps.colMean(Y, 0);
            DenseMatrix64F meanTile = EjmlOps.tile(colMeanY, n, 1);
            CommonOps.subtractEquals(Y, meanTile);

            // The timer restarts at every 10th iteration, so each report after iteration 0
            // covers the last 10 iterations.
            if (iter % 100 == 0) {
                // Cost sum(P .* log(P ./ Q)); P still carries the factor 4 up to iteration 100.
                DenseMatrix64F Pdiv = new DenseMatrix64F(P);
                CommonOps.elementDiv(Pdiv, Q);
                CommonOps.elementLog(Pdiv, logdivide);
                EjmlOps.replaceNaN(logdivide, Double.MIN_VALUE);
                CommonOps.elementMult(logdivide, P);
                EjmlOps.replaceNaN(logdivide, Double.MIN_VALUE);
                double C = CommonOps.elementSum(logdivide);
                long end = System.currentTimeMillis();
                System.out.printf("Iteration %d: error is %f (10 iterations in %4.2f seconds)\n", iter, C, (double)(end - start) / 1000.0);
                if (C < 0.0) {
                    System.err.println("Warning: Error is negative, this is usually a very bad sign!");
                }
                start = System.currentTimeMillis();
            } else if (iter % 10 == 0) {
                long end = System.currentTimeMillis();
                System.out.printf("Iteration %d: (10 iterations in %4.2f seconds)\n", iter, (double)(end - start) / 1000.0);
                start = System.currentTimeMillis();
            }

            // Undo the factor-4 exaggeration applied to P before the loop.
            if (iter == 100) {
                CommonOps.divide(P, 4.0);
            }
        }
        return EjmlOps.extractDoubleArray(Y);
    }

    /**
     * Evaluates a Gaussian kernel row for one precision value.
     *
     * <p>Computes {@code P = exp(-beta * D) / sum(exp(-beta * D))} and the natural-log entropy
     * {@code H = log(sumP) + beta * sum(D .* exp(-beta * D)) / sumP} of that distribution.
     *
     * @param D    distances from one point to the others (a single row as used by {@link #x2p})
     * @param beta precision of the Gaussian kernel
     * @return an {@link TSne.R} whose {@code H} is the entropy and {@code P} the normalised row
     */
    public TSne.R Hbeta(double[][] D, double beta) {
        DenseMatrix64F P = new DenseMatrix64F(D);
        CommonOps.scale(-beta, P);
        CommonOps.elementExp(P, P);
        double sumP = CommonOps.elementSum(P);
        if (sumP == 0.0) {
            // Every exp(-beta * d) underflowed. Without this guard H and P become NaN. The binary
            // search in x2p then stops at once, and the NaN row later turns all of P into NaN when
            // tsne normalises by elementSum(P). A tiny denominator instead yields P = 0 and a very
            // negative H, which steers the search towards a smaller beta.
            sumP = 1.0E-8;
        }
        DenseMatrix64F Dd = new DenseMatrix64F(D);
        CommonOps.elementMult(Dd, P);
        double H = Math.log(sumP) + beta * CommonOps.elementSum(Dd) / sumP;
        CommonOps.scale(1.0 / sumP, P);
        TSne.R r = new TSne.R();
        r.H = H;
        r.P = EjmlOps.extractDoubleArray(P);
        return r;
    }

    /**
     * Computes perplexity-calibrated conditional affinities for every row of {@code X}.
     *
     * <p>Squared Euclidean distances are formed as {@code ||x_i||^2 + ||x_j||^2 - 2 x_i.x_j}. For
     * each point i, a precision {@code beta[i]} (initially 1.0) is then adjusted by bisection, or
     * by doubling/halving while one bound is still infinite. The search runs for at most 50 steps
     * and stops once the entropy from {@link #Hbeta} is within {@code tol} of
     * {@code log(perplexity)}. Point i itself is excluded, so the diagonal of the result stays 0.
     *
     * @param X          input data, one point per row
     * @param tol        tolerance on the entropy difference
     * @param perplexity target perplexity
     * @return an {@link TSne.R} with {@code P} (n x n affinities) and {@code beta} (per-point precisions)
     */
    public TSne.R x2p(double[][] X, double tol, double perplexity) {
        int n = X.length;
        double[][] sum_X = MatrixOps.sum(MatrixOps.square(X), 1);
        double[][] times = MatrixOps.scalarMult(MatrixOps.times(X, this.mo.transpose(X)), -2.0);
        double[][] prodSum = MatrixOps.addColumnVector(this.mo.transpose(times), sum_X);
        double[][] D = MatrixOps.addRowVector(prodSum, this.mo.transpose(sum_X));
        double[][] P = MatrixOps.fillMatrix(n, n, 0.0);
        // A length-n vector of ones; allocating it directly avoids an n x n temporary.
        double[] beta = new double[n];
        Arrays.fill(beta, 1.0);
        double logU = Math.log(perplexity);
        System.out.println("Starting x2p...");
        for (int i = 0; i < n; ++i) {
            if (i % 500 == 0) {
                System.out.println("Computing P-values for point " + i + " of " + n + "...");
            }
            double betamin = Double.NEGATIVE_INFINITY;
            double betamax = Double.POSITIVE_INFINITY;
            double[][] Di = MatrixOps.getValuesFromRow(D, i, MatrixOps.concatenate(MatrixOps.range(0, i), MatrixOps.range(i + 1, n)));
            TSne.R hbeta = this.Hbeta(Di, beta[i]);
            double H = hbeta.H;
            double[][] thisP = hbeta.P;
            double Hdiff = H - logU;
            for (int tries = 0; Math.abs(Hdiff) > tol && tries < 50; ++tries) {
                if (Hdiff > 0.0) {
                    // Entropy too high: the kernel is too wide, so increase the precision.
                    betamin = beta[i];
                    beta[i] = Double.isInfinite(betamax) ? beta[i] * 2.0 : (beta[i] + betamax) / 2.0;
                } else {
                    // Entropy too low: the kernel is too narrow, so decrease the precision.
                    betamax = beta[i];
                    beta[i] = Double.isInfinite(betamin) ? beta[i] / 2.0 : (beta[i] + betamin) / 2.0;
                }
                hbeta = this.Hbeta(Di, beta[i]);
                H = hbeta.H;
                thisP = hbeta.P;
                Hdiff = H - logU;
            }
            MatrixOps.assignValuesToRow(P, i, MatrixOps.concatenate(MatrixOps.range(0, i), MatrixOps.range(i + 1, n)), thisP[0]);
        }
        TSne.R r = new TSne.R();
        r.P = P;
        r.beta = beta;
        double sigma = MatrixOps.mean(MatrixOps.sqrt(MatrixOps.scalarInverse(beta)));
        System.out.println("Mean value of sigma: " + sigma);
        return r;
    }

    /** Requests that a running {@link #tsne} stop after its current iteration. */
    @Override
    public void abort() {
        this.abort = true;
    }
}

