/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.classification;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Maximum Entropy classifier trained with an iterative-scaling procedure (implemented in
 * {@code IIS}).
 *
 * <p>The model keeps one weight ("lambda") per (feature, class) pair. A record's score for a
 * class is the sum of the lambdas of every feature whose value in the record is non-zero, so
 * features act as binary indicators and their magnitude is ignored. Scores are turned into class
 * probabilities with {@code Descriptives.normalizeExp}.
 *
 * <p>Training and prediction run through a {@link ForkJoinStream}. Parallel execution is on by
 * default and can be switched off with {@link #setParallelized(boolean)}.
 */
public class MaximumEntropy
extends AbstractClassifier<MaximumEntropy.ModelParameters, MaximumEntropy.TrainingParameters>
implements PredictParallelizable,
TrainParallelizable {
    /** Whether streams are executed in parallel; defaults to {@code true}. */
    private boolean parallelized = true;

    /** Stream executor configured from the knowledge base's concurrency configuration. */
    protected final ForkJoinStream streamExecutor = new ForkJoinStream(this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());

    /**
     * Creates a new, untrained classifier.
     *
     * @param trainingParameters the training parameters (e.g. number of iterations)
     * @param configuration the framework configuration
     */
    protected MaximumEntropy(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Creates a classifier bound to an existing storage name (presumably to load a previously
     * saved model — behavior is defined by the superclass).
     *
     * @param storageName the name of the model's storage
     * @param configuration the framework configuration
     */
    protected MaximumEntropy(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    /** {@inheritDoc} */
    @Override
    public boolean isParallelized() {
        return this.parallelized;
    }

    /** {@inheritDoc} */
    @Override
    public void setParallelized(boolean parallelized) {
        this.parallelized = parallelized;
    }

    /**
     * Predicts every record of {@code newData} by delegating to the parallel dataset prediction
     * helper, which calls {@link #_predictRecord(Record)} per record.
     */
    @Override
    protected void _predict(Dataframe newData) {
        this._predictDatasetParallel(newData, this.knowledgeBase.getStorageEngine(), this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Scores a single record against every known class.
     *
     * <p>The predicted class is selected from the raw (un-normalized) scores. The returned score
     * map is then normalized in place with {@code Descriptives.normalizeExp}.
     *
     * @param r the record to classify
     * @return the predicted class together with the normalized per-class scores
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record r) {
        Set<Object> classesSet = ((ModelParameters)this.knowledgeBase.getModelParameters()).getClasses();
        AssociativeArray xData = r.getX();
        AssociativeArray predictionScores = new AssociativeArray();
        for (Object theClass : classesSet) {
            predictionScores.put(theClass, this.calculateClassScore(xData, theClass));
        }
        Object predictedClass = this.getSelectedClassFromClassScores(predictionScores);
        Descriptives.normalizeExp(predictionScores);
        return new PredictParallelizable.Prediction(predictedClass, predictionScores);
    }

    /**
     * Trains the model on {@code trainingData}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Register every class.</li>
     *   <li>Compute {@code Cmax}, the largest number of positive-valued features in any single
     *       record.</li>
     *   <li>Reset the lambda of every (feature, class) pair to 0.</li>
     *   <li>Compute the observed expectation of each pair: each record adds {@code 1/n} for every
     *       positive-valued feature paired with its own class.</li>
     *   <li>Run {@code IIS}.</li>
     * </ol>
     * The temporary observed-expectation map is always dropped, even if training fails.
     *
     * @param trainingData the training records; its size is used as {@code n}
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        ModelParameters modelParameters = (ModelParameters)this.knowledgeBase.getModelParameters();
        int n = trainingData.size();
        Map<List<Object>, Double> lambdas = modelParameters.getLambdas();
        Set<Object> classesSet = modelParameters.getClasses();

        // Register all classes and find Cmax (used as the divisor of every lambda update in IIS).
        double Cmax = 0.0;
        for (Record rec : trainingData) {
            classesSet.add(rec.getY());
            int activeFeatures = (int)rec.getX().values().stream()
                    .filter(e -> e != null && TypeInference.toDouble((Object)e) > 0.0)
                    .count();
            Cmax = Math.max(Cmax, activeFeatures);
        }

        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();
        @SuppressWarnings("unchecked") // List.class cannot carry its type argument; keys are always List<Object> tuples.
        Class<List<Object>> tupleClass = (Class<List<Object>>)(Class<?>)List.class;
        Map<List<Object>, Double> tmp_EpFj_observed = storageEngine.getBigMap("tmp_EpFj_observed", tupleClass, Double.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, true, true);
        try {
            // Every (feature, class) pair must exist in both maps so that all classes are scored
            // comparably and the observed map can be incremented without null checks below.
            this.streamExecutor.forEach(StreamMethods.stream(trainingData.getXDataTypes().keySet().stream(), this.isParallelized()), feature -> {
                for (Object theClass : classesSet) {
                    List<Object> featureClassTuple = Arrays.asList(feature, theClass);
                    tmp_EpFj_observed.put(featureClassTuple, 0.0);
                    lambdas.put(featureClassTuple, 0.0);
                }
            });

            // Observed expectation: fraction of records in which the feature is positive and the
            // record belongs to the class.
            double increment = 1.0 / n;
            this.streamExecutor.forEach(StreamMethods.stream(trainingData.stream(), this.isParallelized()), r -> {
                Object theClass = r.getY();
                for (Map.Entry<?, ?> entry : r.getX().entrySet()) {
                    Double occurrences = TypeInference.toDouble(entry.getValue());
                    if (occurrences != null && occurrences > 0.0) {
                        Object feature = entry.getKey();
                        List<Object> featureClassTuple = Arrays.asList(feature, theClass);
                        // The read-modify-write must be atomic across worker threads.
                        synchronized (tmp_EpFj_observed) {
                            tmp_EpFj_observed.put(featureClassTuple, tmp_EpFj_observed.get(featureClassTuple) + increment);
                        }
                    }
                }
            });

            this.IIS(trainingData, tmp_EpFj_observed, Cmax);
        }
        finally {
            storageEngine.dropBigMap("tmp_EpFj_observed", tmp_EpFj_observed);
        }
    }

    /**
     * Runs {@code totalIterations} rounds of iterative scaling over the lambdas.
     *
     * <p>Each round does three things:
     * <ul>
     *   <li>Computes the model's expectation of every active (feature, class) pair under the
     *       current lambdas.</li>
     *   <li>Adds {@code log(observed / model) / Cmax} to each lambda whose observed and model
     *       expectations differ by more than {@code 1e-8}.</li>
     *   <li>Replaces any resulting infinite lambdas with the extreme finite weights.</li>
     * </ul>
     * The per-iteration temporary map is always dropped, even if the iteration fails.
     *
     * @param trainingData the training records
     * @param EpFj_observed observed expectation of each (feature, class) pair
     * @param Cmax the maximum number of positive-valued features in any training record
     */
    private void IIS(Dataframe trainingData, Map<List<Object>, Double> EpFj_observed, double Cmax) {
        ModelParameters modelParameters = (ModelParameters)this.knowledgeBase.getModelParameters();
        int totalIterations = ((TrainingParameters)this.knowledgeBase.getTrainingParameters()).getTotalIterations();
        Set<Object> classesSet = modelParameters.getClasses();
        Map<List<Object>, Double> lambdas = modelParameters.getLambdas();
        int n = trainingData.size();
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();
        @SuppressWarnings("unchecked") // List.class cannot carry its type argument; keys are always List<Object> tuples.
        Class<List<Object>> tupleClass = (Class<List<Object>>)(Class<?>)List.class;

        for (int iteration = 0; iteration < totalIterations; ++iteration) {
            this.logger.debug("Iteration {}", iteration);
            Map<List<Object>, Double> tmp_EpFj_model = storageEngine.getBigMap("tmp_EpFj_model", tupleClass, Double.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, false, true);
            try {
                this.accumulateModelExpectations(trainingData, classesSet, n, tmp_EpFj_model);
                if (this.updateLambdas(tmp_EpFj_model, EpFj_observed, lambdas, Cmax)) {
                    this.replaceInfiniteLambdas(lambdas);
                }
            }
            finally {
                storageEngine.dropBigMap("tmp_EpFj_model", tmp_EpFj_model);
            }
        }
    }

    /**
     * Adds to {@code EpFj_model} the model expectation of each (feature, class) pair.
     *
     * <p>For every record, the class probabilities are computed from the current lambdas. Then
     * {@code P(class | record) / n} is added for every non-zero feature of the record.
     *
     * <p>NOTE(review): this counts any non-zero feature value, whereas the observed expectation
     * and {@code Cmax} in {@code _fit} only count values {@code > 0}. For negative-valued
     * features the observed expectation stays 0, and their lambdas are driven to the minimum
     * weight. Confirm that features are expected to be non-negative.
     */
    private void accumulateModelExpectations(Dataframe trainingData, Set<Object> classesSet, int n, Map<List<Object>, Double> EpFj_model) {
        this.streamExecutor.forEach(StreamMethods.stream(trainingData.stream(), this.isParallelized()), r -> {
            AssociativeArray xData = r.getX();
            AssociativeArray classScores = new AssociativeArray();
            for (Object theClass : classesSet) {
                classScores.put(theClass, this.calculateClassScore(xData, theClass));
            }
            Descriptives.normalizeExp(classScores);

            for (Map.Entry<?, ?> classEntry : classScores.entrySet()) {
                Object theClass = classEntry.getKey();
                double probabilityFraction = TypeInference.toDouble(classEntry.getValue()) / n;
                // The map is not concurrent, so all writes from worker threads are serialized.
                synchronized (EpFj_model) {
                    for (Map.Entry<?, ?> featureEntry : xData.entrySet()) {
                        Double occurrences = TypeInference.toDouble(featureEntry.getValue());
                        if (occurrences == null || occurrences == 0.0) {
                            continue;
                        }
                        Object feature = featureEntry.getKey();
                        List<Object> featureClassTuple = Arrays.asList(feature, theClass);
                        EpFj_model.put(featureClassTuple, EpFj_model.getOrDefault(featureClassTuple, 0.0) + probabilityFraction);
                    }
                }
            }
        });
    }

    /**
     * Applies one scaling step to every lambda that has a model expectation.
     *
     * <p>For each pair the step depends on the two expectations:
     * <ul>
     *   <li>Observed is 0: the lambda is set to negative infinity.</li>
     *   <li>Model is 0: the lambda is set to positive infinity.</li>
     *   <li>Otherwise: the lambda is increased by {@code log(observed / model) / Cmax}.</li>
     * </ul>
     * Pairs whose difference is at most {@code 1e-8} (or is NaN) are left unchanged.
     *
     * @return {@code true} if any lambda was set to an infinite value
     */
    private boolean updateLambdas(Map<List<Object>, Double> EpFj_model, Map<List<Object>, Double> EpFj_observed, Map<List<Object>, Double> lambdas, double Cmax) {
        AtomicBoolean infiniteValuesDetected = new AtomicBoolean(false);
        this.streamExecutor.forEach(StreamMethods.stream(EpFj_model.entrySet().stream(), this.isParallelized()), featureClassCounts -> {
            List<Object> tp = featureClassCounts.getKey();
            double EpFj_observed_value = EpFj_observed.get(tp);
            double EpFj_model_value = featureClassCounts.getValue();
            if (Math.abs(EpFj_observed_value - EpFj_model_value) > 1.0E-8) {
                if (EpFj_observed_value == 0.0) {
                    lambdas.put(tp, Double.NEGATIVE_INFINITY);
                    infiniteValuesDetected.set(true);
                }
                else if (EpFj_model_value == 0.0) {
                    lambdas.put(tp, Double.POSITIVE_INFINITY);
                    infiniteValuesDetected.set(true);
                }
                else {
                    double deltaJ = Math.log(EpFj_observed_value / EpFj_model_value) / Cmax;
                    lambdas.put(tp, lambdas.get(tp) + deltaJ);
                }
            }
        });
        return infiniteValuesDetected.get();
    }

    /**
     * Clamps infinite lambdas to the extreme finite weights.
     *
     * <p>Negative infinity becomes the minimum finite lambda and positive infinity becomes the
     * maximum finite lambda. If no finite lambda exists, both fall back to 0.0 instead of
     * throwing.
     */
    private void replaceInfiniteLambdas(Map<List<Object>, Double> lambdas) {
        Double minimumNonInfiniteLambdaWeight = (Double)this.streamExecutor.min(StreamMethods.stream(lambdas.values().stream(), this.isParallelized()).filter(v -> Double.isFinite(v)), Double::compare).orElse(0.0);
        Double maximumNonInfiniteLambdaWeight = (Double)this.streamExecutor.max(StreamMethods.stream(lambdas.values().stream(), this.isParallelized()).filter(v -> Double.isFinite(v)), Double::compare).orElse(0.0);
        this.streamExecutor.forEach(StreamMethods.stream(lambdas.entrySet().stream(), this.isParallelized()), e -> {
            double value = e.getValue();
            if (Double.isInfinite(value)) {
                lambdas.put(e.getKey(), value < 0.0 ? minimumNonInfiniteLambdaWeight : maximumNonInfiniteLambdaWeight);
            }
        });
    }

    /**
     * Computes the un-normalized score of {@code theClass} for the feature vector {@code x}.
     *
     * <p>The score is the sum of the lambdas of every non-null, non-zero feature. Features without
     * a stored lambda for this class contribute nothing.
     *
     * @param x the record's feature values
     * @param theClass the class to score
     * @return the summed lambda weight (0.0 when no feature contributes)
     */
    private Double calculateClassScore(AssociativeArray x, Object theClass) {
        Map<List<Object>, Double> lambdas = ((ModelParameters)this.knowledgeBase.getModelParameters()).getLambdas();
        double score = 0.0;
        for (Map.Entry<?, ?> entry : x.entrySet()) {
            Double value = TypeInference.toDouble(entry.getValue());
            if (value == null || value == 0.0) {
                continue;
            }
            Object feature = entry.getKey();
            Double lambdaWeight = lambdas.get(Arrays.asList(feature, theClass));
            if (lambdaWeight != null) {
                score += lambdaWeight;
            }
        }
        return score;
    }

    /** Training parameters of {@link MaximumEntropy}. */
    public static class TrainingParameters
    extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;

        /** Number of iterative-scaling rounds; defaults to 100. */
        private int totalIterations = 100;

        /** @return the number of iterative-scaling rounds run by training */
        public int getTotalIterations() {
            return this.totalIterations;
        }

        /** @param totalIterations the number of iterative-scaling rounds to run during training */
        public void setTotalIterations(int totalIterations) {
            this.totalIterations = totalIterations;
        }
    }

    /** Model parameters of {@link MaximumEntropy}: the learned (feature, class) weights. */
    public static class ModelParameters
    extends AbstractClassifier.AbstractModelParameters {
        private static final long serialVersionUID = 1L;

        /**
         * Lambda weight per {@code [feature, class]} tuple. Presumably allocated by the framework
         * from the {@code @BigMap} annotation (concurrent, in-memory hash map) — confirm in the
         * storage layer.
         */
        @BigMap(keyClass=List.class, valueClass=Double.class, mapType=StorageEngine.MapType.HASHMAP, storageHint=StorageEngine.StorageHint.IN_MEMORY, concurrent=true)
        private Map<List<Object>, Double> lambdas;

        /** @param storageEngine the storage engine backing the model's maps */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        /** @return the lambda weights keyed by {@code [feature, class]} tuples */
        public Map<List<Object>, Double> getLambdas() {
            return this.lambdas;
        }

        /** @param lambdas the lambda weights keyed by {@code [feature, class]} tuples */
        protected void setLambdas(Map<List<Object>, Double> lambdas) {
            this.lambdas = lambdas;
        }
    }
}

