/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.classification;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.mathematics.regularization.ElasticNetRegularizer;
import com.datumbox.framework.core.mathematics.regularization.L1Regularizer;
import com.datumbox.framework.core.mathematics.regularization.L2Regularizer;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Multinomial logistic (softmax) regression classifier trained with batch gradient descent
 * and an adaptive learning rate.
 *
 * <p>The model keeps one weight ("thita") per {@code (feature, class)} pair plus a per-class
 * intercept stored under the key {@code ("~CONSTANT", class)}. Linear class scores are turned
 * into probabilities with {@link Descriptives#normalizeExp}, both during training and during
 * prediction. Optional L1, L2 or Elastic Net regularization is applied after each gradient step.
 */
public class SoftMaxRegression
extends AbstractClassifier<SoftMaxRegression.ModelParameters, SoftMaxRegression.TrainingParameters>
implements PredictParallelizable, TrainParallelizable {

    /** Pseudo-feature name under which each class intercept is stored in the thitas map. */
    private static final String CONSTANT_FEATURE = "~CONSTANT";

    /** Storage name of the temporary map that holds the candidate weights of one iteration. */
    private static final String TMP_NEW_THITAS = "tmp_newThitas";

    private boolean parallelized = true;

    /** Runs the (optionally parallel) streams using the knowledge base's concurrency configuration. */
    protected final ForkJoinStream streamExecutor = new ForkJoinStream(this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());

    /**
     * Creates a new, untrained model.
     *
     * @param trainingParameters the hyper-parameters used by {@code fit}
     * @param configuration the framework configuration
     */
    protected SoftMaxRegression(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Creates a model bound to the given storage name.
     *
     * @param storageName name of the storage the model is associated with
     * @param configuration the framework configuration
     */
    protected SoftMaxRegression(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    /** {@inheritDoc} */
    @Override
    public boolean isParallelized() {
        return this.parallelized;
    }

    /** {@inheritDoc} */
    @Override
    public void setParallelized(boolean parallelized) {
        this.parallelized = parallelized;
    }

    /** Predicts every record of {@code newData} using the parallel prediction helper. */
    @Override
    protected void _predict(Dataframe newData) {
        this._predictDatasetParallel(newData, this.knowledgeBase.getStorageEngine(), this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Scores a single record against every known class.
     *
     * @param r the record to classify
     * @return the selected class together with the per-class probabilities
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record r) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        Set<Object> classesSet = modelParameters.getClasses();
        Map<List<Object>, Double> thitas = modelParameters.getThitas();

        AssociativeArray predictionScores = new AssociativeArray();
        for (Object theClass : classesSet) {
            predictionScores.put(theClass, this.calculateClassScore(r.getX(), theClass, thitas));
        }

        // The class is selected from the raw scores, before they are normalized in place.
        Object predictedClass = this.getSelectedClassFromClassScores(predictionScores);
        Descriptives.normalizeExp(predictionScores);
        return new PredictParallelizable.Prediction(predictedClass, predictionScores);
    }

    /**
     * Trains the weights with {@code totalIterations} steps of batch gradient descent.
     *
     * <p>Each step is computed on a temporary copy of the weights. The step is kept only when
     * the (regularized) error does not increase. On acceptance the learning rate grows by 5%;
     * on rejection it is halved.
     *
     * @param trainingData the training records
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        Map<List<Object>, Double> thitas = modelParameters.getThitas();
        Set<Object> classesSet = modelParameters.getClasses();

        this.initializeWeights(trainingData, classesSet, thitas);

        double minError = Double.POSITIVE_INFINITY;
        double learningRate = trainingParameters.getLearningRate();
        int totalIterations = trainingParameters.getTotalIterations();
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();

        for (int iteration = 0; iteration < totalIterations; ++iteration) {
            this.logger.debug("Iteration {}", iteration);

            @SuppressWarnings("unchecked") // getBigMap can only be given the raw List.class key type
            Map<List<Object>, Double> newThitas = (Map<List<Object>, Double>) (Map<?, ?>) storageEngine.getBigMap(TMP_NEW_THITAS, List.class, Double.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, true, true);
            try {
                newThitas.putAll(thitas);
                this.batchGradientDescent(trainingData, newThitas, learningRate);
                double newError = this.calculateError(trainingData, newThitas);

                // Reject the step when the error grows, and also when it is NaN: "NaN > x" is
                // false, so a NaN error would otherwise be accepted and poison minError for
                // every following iteration.
                if (Double.isNaN(newError) || newError > minError) {
                    learningRate /= 2.0;
                } else {
                    learningRate *= 1.05;
                    minError = newError;
                    thitas.clear();
                    thitas.putAll(newThitas);
                }
            } finally {
                // Always release the temporary map, even if an iteration fails.
                storageEngine.dropBigMap(TMP_NEW_THITAS, newThitas);
            }
        }
    }

    /**
     * Registers every class seen in the training data. It then sets each class intercept to
     * zero and adds a zero weight for each {@code (feature, class)} pair not already present.
     */
    private void initializeWeights(Dataframe trainingData, Set<Object> classesSet, Map<List<Object>, Double> thitas) {
        for (Record r : trainingData) {
            classesSet.add(r.getY());
        }
        for (Object theClass : classesSet) {
            thitas.put(Arrays.asList(CONSTANT_FEATURE, theClass), 0.0);
        }
        this.streamExecutor.forEach(StreamMethods.stream(trainingData.getXDataTypes().keySet().stream(), this.isParallelized()), feature -> {
            for (Object theClass : classesSet) {
                thitas.putIfAbsent(Arrays.asList(feature, theClass), 0.0);
            }
        });
    }

    /**
     * Performs one batch gradient step and then applies the configured regularizer.
     *
     * <p>The gradient is evaluated with the current model weights. The update is accumulated
     * into {@code newThitas}, which must already contain every weight key.
     *
     * @param trainingData the training records
     * @param newThitas candidate weights, updated in place
     * @param learningRate step size of this iteration
     */
    private void batchGradientDescent(Dataframe trainingData, Map<List<Object>, Double> newThitas, double learningRate) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        double multiplier = learningRate / (double) trainingData.size();
        Map<List<Object>, Double> thitas = modelParameters.getThitas();
        Set<Object> classesSet = modelParameters.getClasses();

        this.streamExecutor.forEach(StreamMethods.stream(trainingData.stream(), this.isParallelized()), r -> {
            AssociativeArray xData = r.getX();
            AssociativeArray classProbabilities = this.hypothesisFunction(xData, thitas);
            for (Object theClass : classesSet) {
                double score = classProbabilities.getDouble(theClass);
                // (indicator(y == class) - p(class | x)) is the cross-entropy gradient of a softmax.
                double error = r.getY().equals(theClass) ? 1.0 - score : -score;
                double errorMultiplier = multiplier * error;

                // get-then-put is not atomic, so updates from records processed in parallel
                // must be serialized on the shared map.
                synchronized (newThitas) {
                    for (Map.Entry<?, ?> entry : xData.entrySet()) {
                        Double value = TypeInference.toDouble(entry.getValue());
                        List<Object> featureClassTuple = Arrays.asList(entry.getKey(), theClass);
                        newThitas.put(featureClassTuple, newThitas.get(featureClassTuple) + errorMultiplier * value);
                    }
                    List<Object> constantClassTuple = Arrays.asList(CONSTANT_FEATURE, theClass);
                    newThitas.put(constantClassTuple, newThitas.get(constantClassTuple) + errorMultiplier);
                }
            }
        });

        double l1 = trainingParameters.getL1();
        double l2 = trainingParameters.getL2();
        if (l1 > 0.0 && l2 > 0.0) {
            ElasticNetRegularizer.updateWeights(l1, l2, learningRate, thitas, newThitas);
        } else if (l1 > 0.0) {
            L1Regularizer.updateWeights(l1, learningRate, thitas, newThitas);
        } else if (l2 > 0.0) {
            L2Regularizer.updateWeights(l2, learningRate, thitas, newThitas);
        }
    }

    /**
     * Computes the linear score of {@code theClass}: its intercept plus the weighted sum of the
     * record's features. Features with no weight for this class are ignored.
     *
     * @param x the record's features
     * @param theClass the class to score
     * @param thitas the weights to use
     * @return the unnormalized class score
     */
    private double calculateClassScore(AssociativeArray x, Object theClass, Map<List<Object>, Double> thitas) {
        double score = thitas.get(Arrays.asList(CONSTANT_FEATURE, theClass));
        for (Map.Entry<?, ?> entry : x.entrySet()) {
            Double value = TypeInference.toDouble(entry.getValue());
            Double thitaWeight = thitas.get(Arrays.asList(entry.getKey(), theClass));
            if (thitaWeight == null) {
                continue;
            }
            score += thitaWeight * value;
        }
        return score;
    }

    /**
     * Computes the average negative log-likelihood of the true classes plus the configured
     * regularization penalty.
     *
     * @param trainingData the training records
     * @param thitas the weights to evaluate
     * @return the regularized training error
     */
    private double calculateError(Dataframe trainingData, Map<List<Object>, Double> thitas) {
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();

        double error = this.streamExecutor.sum(StreamMethods.stream(trainingData.stream(), this.isParallelized()).mapToDouble(r -> {
            AssociativeArray classProbabilities = this.hypothesisFunction(r.getX(), thitas);
            Double score = classProbabilities.getDouble(r.getY());
            return Math.log(score);
        }));
        error = -error / (double) trainingData.size();

        double l1 = trainingParameters.getL1();
        double l2 = trainingParameters.getL2();
        if (l1 > 0.0 && l2 > 0.0) {
            error += ElasticNetRegularizer.estimatePenalty(l1, l2, thitas);
        } else if (l1 > 0.0) {
            error += L1Regularizer.estimatePenalty(l1, thitas);
        } else if (l2 > 0.0) {
            error += L2Regularizer.estimatePenalty(l2, thitas);
        }
        return error;
    }

    /**
     * Computes the class probabilities of a record under the given weights.
     *
     * <p>This uses the same {@link Descriptives#normalizeExp} normalization as
     * {@link #_predictRecord}. Training therefore optimizes the probabilities that prediction
     * reports, and the {@code (y - p)} update in {@link #batchGradientDescent} matches the
     * log-likelihood in {@link #calculateError}.
     *
     * @param x the record's features
     * @param thitas the weights to use
     * @return a map from class to probability
     */
    private AssociativeArray hypothesisFunction(AssociativeArray x, Map<List<Object>, Double> thitas) {
        Set<Object> classesSet = ((ModelParameters) this.knowledgeBase.getModelParameters()).getClasses();
        AssociativeArray predictionProbabilities = new AssociativeArray();
        for (Object theClass : classesSet) {
            predictionProbabilities.put(theClass, this.calculateClassScore(x, theClass, thitas));
        }
        Descriptives.normalizeExp(predictionProbabilities);
        return predictionProbabilities;
    }

    /**
     * Hyper-parameters of {@link SoftMaxRegression}.
     *
     * <p>Defaults: 100 iterations, learning rate 0.1, and no regularization
     * ({@code l1 = l2 = 0}). Setting both {@code l1} and {@code l2} above zero selects Elastic
     * Net; setting only one selects L1 or L2 respectively.
     */
    public static class TrainingParameters
    extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;
        private int totalIterations = 100;
        private double learningRate = 0.1;
        private double l1 = 0.0;
        private double l2 = 0.0;

        /** @return the number of gradient descent iterations */
        public int getTotalIterations() {
            return this.totalIterations;
        }

        /** @param totalIterations the number of gradient descent iterations */
        public void setTotalIterations(int totalIterations) {
            this.totalIterations = totalIterations;
        }

        /** @return the initial learning rate */
        public double getLearningRate() {
            return this.learningRate;
        }

        /** @param learningRate the initial learning rate */
        public void setLearningRate(double learningRate) {
            this.learningRate = learningRate;
        }

        /** @return the L1 regularization weight */
        public double getL1() {
            return this.l1;
        }

        /** @param l1 the L1 regularization weight */
        public void setL1(double l1) {
            this.l1 = l1;
        }

        /** @return the L2 regularization weight */
        public double getL2() {
            return this.l2;
        }

        /** @param l2 the L2 regularization weight */
        public void setL2(double l2) {
            this.l2 = l2;
        }
    }

    /**
     * Learned state of {@link SoftMaxRegression}.
     *
     * <p>{@code thitas} maps {@code [feature, class]} lists to weights. Each class intercept is
     * stored under {@code ["~CONSTANT", class]}.
     */
    public static class ModelParameters
    extends AbstractClassifier.AbstractModelParameters {
        private static final long serialVersionUID = 1L;
        @BigMap(keyClass=List.class, valueClass=Double.class, mapType=StorageEngine.MapType.HASHMAP, storageHint=StorageEngine.StorageHint.IN_MEMORY, concurrent=true)
        private Map<List<Object>, Double> thitas;

        /** @param storageEngine the storage engine backing the model's maps */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        /** @return the weight map keyed by {@code [feature, class]} */
        public Map<List<Object>, Double> getThitas() {
            return this.thitas;
        }

        /** @param thitas the weight map keyed by {@code [feature, class]} */
        protected void setThitas(Map<List<Object>, Double> thitas) {
            this.thitas = thitas;
        }
    }
}

