/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.featureselection;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.DataframeMatrix;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.OpenMapRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.FastMath;

/**
 * Feature selector that projects the X columns of a {@link Dataframe} onto the
 * eigenvectors of their sample covariance matrix (Principal Component Analysis).
 *
 * <p>Training centers every column, builds the covariance matrix, performs an
 * eigen-decomposition and stores the (optionally whitened and truncated)
 * components in the {@link ModelParameters}. Transforming replaces the X data
 * of every record with its projection on those components, keyed by the
 * zero-based component index.</p>
 */
public class PCA
extends AbstractFeatureSelector<ModelParameters, TrainingParameters> {
    /**
     * Creates a selector that will be trained with the given parameters.
     *
     * @param trainingParameters the training parameters of the algorithm
     * @param configuration      the framework configuration
     */
    protected PCA(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Creates a selector from a storage name.
     *
     * @param storageName   the storage name handed to the parent constructor
     * @param configuration the framework configuration
     */
    protected PCA(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    /**
     * Verifies that every X column has a supported data type and then delegates
     * the training to the parent class.
     *
     * @param trainingData the data used for training
     * @throws IllegalArgumentException if an unsupported X data type is found
     */
    @Override
    public void fit(Dataframe trainingData) {
        Set<TypeInference.DataType> supportedXDataTypes = this.getSupportedXDataTypes();
        for (TypeInference.DataType dataType : trainingData.getXDataTypes().values()) {
            if (!supportedXDataTypes.contains(dataType)) {
                throw new IllegalArgumentException("A DataType that is not supported by this method was detected in the Dataframe.");
            }
        }
        super.fit(trainingData);
    }

    /**
     * Estimates the principal components of the training data and stores the
     * column means, the eigenvalues and the components in the model parameters.
     *
     * @param trainingData the data used for training
     * @throws IllegalArgumentException if the Dataframe has fewer than 2 records
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        ModelParameters modelParameters = (ModelParameters)this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters)this.knowledgeBase.getTrainingParameters();
        int n = trainingData.size();
        int d = trainingData.xColumnSize();
        if (n < 2) {
            // The covariance below is scaled by 1/(n-1); with fewer than two
            // records that factor is infinite or NaN and poisons the matrix.
            throw new IllegalArgumentException("PCA requires at least 2 records to estimate the covariance matrix but the Dataframe contains " + n + ".");
        }

        // Convert the data into a matrix.
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();
        DataframeMatrix matrixDataset = DataframeMatrix.newInstance(trainingData, false, null, featureIds);
        RealMatrix X = matrixDataset.getX();

        // Calculate the mean of every column and subtract it from the matrix in place.
        RealVector meanValues = new OpenMapRealVector(d);
        for (Integer columnId : featureIds.values()) {
            int column = columnId;
            double mean = 0.0;
            for (int row = 0; row < n; ++row) {
                mean += X.getEntry(row, column);
            }
            mean /= n;
            for (int row = 0; row < n; ++row) {
                X.addToEntry(row, column, -mean);
            }
            meanValues.setEntry(column, mean);
        }
        modelParameters.setMean(meanValues);

        // Sample covariance matrix of the centered data: X'X / (n-1).
        RealMatrix covarianceDD = X.transpose().multiply(X).scalarMultiply(1.0 / (n - 1.0));
        EigenDecomposition decomposition = new EigenDecomposition(covarianceDD);
        // Declared as RealVector because getSubVector() below returns a RealVector.
        RealVector eigenValues = new ArrayRealVector(decomposition.getRealEigenvalues(), false);
        RealMatrix components = decomposition.getV();

        if (trainingParameters.isWhitened()) {
            // Scale each component by the square root of its eigenvalue.
            RealMatrix sqrtEigenValues = new DiagonalMatrix(d);
            for (int i = 0; i < d; ++i) {
                // Clamp at zero: round-off can yield slightly negative eigenvalues
                // whose square root would turn the whole component into NaN.
                sqrtEigenValues.setEntry(i, i, FastMath.sqrt(FastMath.max(0.0, eigenValues.getEntry(i))));
            }
            components = components.multiply(sqrtEigenValues);
        }

        // NOTE(review): keeping the leading entries assumes the decomposition
        // returns the eigenvalues in descending order - confirm against the
        // Commons Math EigenDecomposition documentation.
        Integer maxDimensions = trainingParameters.getMaxDimensions();
        Double variancePercentageThreshold = trainingParameters.getVariancePercentageThreshold();
        if (variancePercentageThreshold != null && variancePercentageThreshold <= 1.0) {
            int varCounter = countComponentsForVariance(eigenValues, d, variancePercentageThreshold);
            if (maxDimensions == null || maxDimensions > varCounter) {
                maxDimensions = varCounter;
            }
        }

        if (maxDimensions != null && maxDimensions < d) {
            // Keep only the selected eigenvalues and their eigenvectors (columns).
            eigenValues = eigenValues.getSubVector(0, maxDimensions);
            components = components.getSubMatrix(0, components.getRowDimension() - 1, 0, maxDimensions - 1);
        }

        modelParameters.setEigenValues(eigenValues);
        modelParameters.setComponents(components);
    }

    /**
     * Counts how many leading eigenvalues are needed for their cumulative share
     * of the total to reach the given threshold.
     *
     * @param eigenValues the eigenvalues in the order they should be consumed
     * @param d           the number of eigenvalues to consider
     * @param threshold   the share of the total that must be reached
     * @return the number of leading eigenvalues, or {@code d} if the threshold is never reached
     */
    private static int countComponentsForVariance(RealVector eigenValues, int d, double threshold) {
        double totalVariance = 0.0;
        for (int i = 0; i < d; ++i) {
            totalVariance += eigenValues.getEntry(i);
        }
        double cumulative = 0.0;
        for (int i = 0; i < d; ++i) {
            cumulative += eigenValues.getEntry(i) / totalVariance;
            if (cumulative >= threshold) {
                return i + 1;
            }
        }
        return d;
    }

    /**
     * Replaces the X data of every record with its projection on the stored
     * components; the Y and prediction fields of the records are carried over.
     *
     * <p>NOTE(review): the mean stored during training is not subtracted from
     * the data before the projection - confirm whether this is intended.</p>
     *
     * @param newData the data to transform in place
     */
    @Override
    protected void _transform(Dataframe newData) {
        ModelParameters modelParameters = (ModelParameters)this.knowledgeBase.getModelParameters();

        // Convert the data into a matrix; recordIdsReference is handed to the
        // parser and later used to look up the matrix row of every record id.
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();
        HashMap<Integer, Integer> recordIdsReference = new HashMap<>();
        DataframeMatrix matrixDataset = DataframeMatrix.parseDataset(newData, recordIdsReference, featureIds);

        // Multiply the data with the components.
        final RealMatrix projected = matrixDataset.getX().multiply(modelParameters.getComponents());

        this.streamExecutor.forEach(StreamMethods.stream(newData.entries(), this.isParallelized()), e -> {
            Integer rId = (Integer)e.getKey();
            Record r = (Record)e.getValue();
            int rowId = recordIdsReference.get(rId);

            // The new X data is keyed by the zero-based component index.
            AssociativeArray xData = new AssociativeArray();
            double[] row = projected.getRow(rowId);
            for (int componentId = 0; componentId < row.length; ++componentId) {
                xData.put((Object)componentId, (Object)row[componentId]);
            }

            Record newR = new Record(xData, r.getY(), r.getYPredicted(), r.getYPredictedProbabilities());
            newData._unsafe_set(rId, newR);
        });

        // The columns changed, so the meta data must be recalculated.
        newData.recalculateMeta();
    }

    /**
     * Returns the X data types accepted by {@link #fit(Dataframe)}.
     *
     * @return a new set containing BOOLEAN and NUMERICAL
     */
    @Override
    protected Set<TypeInference.DataType> getSupportedXDataTypes() {
        return new HashSet<TypeInference.DataType>(Arrays.asList(TypeInference.DataType.BOOLEAN, TypeInference.DataType.NUMERICAL));
    }

    /**
     * Returns the supported Y data types.
     *
     * @return always {@code null}
     */
    @Override
    protected Set<TypeInference.DataType> getSupportedYDataTypes() {
        return null;
    }

    /**
     * Training parameters of {@link PCA}.
     */
    public static class TrainingParameters
    extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;
        // When true the components are scaled by the square roots of the eigenvalues.
        private boolean whitened = false;
        // Upper bound on the number of kept components; null means no bound.
        private Integer maxDimensions = null;
        // Share of the total variance to retain; ignored when null or above 1.
        private Double variancePercentageThreshold = null;

        /**
         * Tells whether the components are whitened during training.
         *
         * @return {@code true} if whitening is enabled
         */
        public boolean isWhitened() {
            return this.whitened;
        }

        /**
         * Enables or disables the whitening of the components.
         *
         * @param whitened {@code true} to enable whitening
         */
        public void setWhitened(boolean whitened) {
            this.whitened = whitened;
        }

        /**
         * Returns the maximum number of components to keep.
         *
         * @return the maximum number of components, or {@code null} if unset
         */
        public Integer getMaxDimensions() {
            return this.maxDimensions;
        }

        /**
         * Sets the maximum number of components to keep.
         *
         * @param maxDimensions the maximum number of components, or {@code null} for no bound
         */
        public void setMaxDimensions(Integer maxDimensions) {
            this.maxDimensions = maxDimensions;
        }

        /**
         * Returns the share of the total variance that the kept components must reach.
         *
         * @return the threshold, or {@code null} if unset
         */
        public Double getVariancePercentageThreshold() {
            return this.variancePercentageThreshold;
        }

        /**
         * Sets the share of the total variance that the kept components must
         * reach. Training ignores the threshold when it is {@code null} or
         * greater than 1.
         *
         * @param variancePercentageThreshold the threshold, or {@code null} to disable it
         */
        public void setVariancePercentageThreshold(Double variancePercentageThreshold) {
            this.variancePercentageThreshold = variancePercentageThreshold;
        }
    }

    /**
     * Model parameters learned by {@link PCA}.
     */
    public static class ModelParameters
    extends AbstractTrainer.AbstractModelParameters {
        private static final long serialVersionUID = 2L;
        // Maps every feature to the matrix column that holds its values.
        @BigMap(keyClass=Object.class, valueClass=Integer.class, mapType=StorageEngine.MapType.HASHMAP, storageHint=StorageEngine.StorageHint.IN_MEMORY, concurrent=false)
        private Map<Object, Integer> featureIds;
        // Mean of every column of the training data.
        private RealVector mean;
        // Eigenvalues of the kept components.
        private RealVector eigenValues;
        // Kept components, one per column.
        private RealMatrix components;

        /**
         * Creates the model parameters on top of the given storage engine.
         *
         * @param storageEngine the storage engine handed to the parent constructor
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        /**
         * Returns the mapping between features and matrix columns.
         *
         * @return the feature to column id map
         */
        public Map<Object, Integer> getFeatureIds() {
            return this.featureIds;
        }

        /**
         * Sets the mapping between features and matrix columns.
         *
         * @param featureIds the feature to column id map
         */
        protected void setFeatureIds(Map<Object, Integer> featureIds) {
            this.featureIds = featureIds;
        }

        /**
         * Returns the column means of the training data.
         *
         * @return the vector of means
         */
        public RealVector getMean() {
            return this.mean;
        }

        /**
         * Sets the column means of the training data.
         *
         * @param mean the vector of means
         */
        protected void setMean(RealVector mean) {
            this.mean = mean;
        }

        /**
         * Returns the eigenvalues of the kept components.
         *
         * @return the vector of eigenvalues
         */
        public RealVector getEigenValues() {
            return this.eigenValues;
        }

        /**
         * Sets the eigenvalues of the kept components.
         *
         * @param eigenValues the vector of eigenvalues
         */
        protected void setEigenValues(RealVector eigenValues) {
            this.eigenValues = eigenValues;
        }

        /**
         * Returns the kept components.
         *
         * @return the matrix whose columns are the components
         */
        public RealMatrix getComponents() {
            return this.components;
        }

        /**
         * Sets the kept components.
         *
         * @param components the matrix whose columns are the components
         */
        protected void setComponents(RealMatrix components) {
            this.components = components;
        }
    }
}

