/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.common.abstracts.algorithms;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Shared training and prediction logic for Naive Bayes classifiers that store
 * one log-prior per class and one log-likelihood per (feature, class) pair.
 *
 * <p>Subclasses decide through {@link #isBinarized()} whether positive feature
 * values are collapsed to {@code 1.0} before they are counted or scored.
 *
 * @param <MP> concrete model-parameter type
 * @param <TP> concrete training-parameter type
 */
public abstract class AbstractNaiveBayes<MP extends AbstractNaiveBayes.AbstractModelParameters, TP extends AbstractNaiveBayes.AbstractTrainingParameters>
extends AbstractClassifier<MP, TP>
implements PredictParallelizable,
TrainParallelizable {
    /** Whether the stream-based passes of training and prediction are requested to run in parallel. */
    private boolean parallelized = true;

    /** Executor used to run the stream-based passes; built from the knowledge base's concurrency configuration. */
    protected final ForkJoinStream streamExecutor = new ForkJoinStream(this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());

    /**
     * Creates a trainer from explicit training parameters.
     *
     * @param trainingParameters the training parameters, forwarded to the superclass
     * @param configuration      the framework configuration, forwarded to the superclass
     */
    protected AbstractNaiveBayes(TP trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Creates a trainer from a storage name.
     *
     * @param storageName   the storage name, forwarded to the superclass
     * @param configuration the framework configuration, forwarded to the superclass
     */
    protected AbstractNaiveBayes(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    /**
     * Tells whether positive feature values must be treated as {@code 1.0}
     * both when counting during training and when scoring during prediction.
     *
     * @return {@code true} to collapse positive values to {@code 1.0}
     */
    protected abstract boolean isBinarized();

    /** {@inheritDoc} */
    @Override
    public boolean isParallelized() {
        return this.parallelized;
    }

    /** {@inheritDoc} */
    @Override
    public void setParallelized(boolean parallelized) {
        this.parallelized = parallelized;
    }

    /**
     * Predicts the whole dataset by delegating to {@code _predictDatasetParallel}
     * with the knowledge base's storage engine and concurrency configuration.
     *
     * @param newData the records to predict
     */
    @Override
    protected void _predict(Dataframe newData) {
        this._predictDatasetParallel(newData, this.knowledgeBase.getStorageEngine(), this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Scores a single record against every known class.
     *
     * <p>Each class starts from its log-prior; for every feature of the record
     * that has a stored log-likelihood, {@code occurrences * logLikelihood} is
     * added to the class score. The class is selected from the raw log scores
     * and the scores are then passed to {@code Descriptives.normalizeExp}.
     *
     * @param r the record to score
     * @return the selected class together with the per-class scores
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record r) {
        AbstractModelParameters modelParameters = (AbstractModelParameters)this.knowledgeBase.getModelParameters();
        Map<List<Object>, Double> logLikelihoods = modelParameters.getLogLikelihoods();
        Map<Object, Double> logPriors = modelParameters.getLogPriors();
        Set<Object> classesSet = modelParameters.getClasses();

        // Any class works as a probe: training stores an entry for every (feature, class) pair.
        Object someClass = classesSet.iterator().next();

        // Loop-invariant: positive values are collapsed to 1 unless occurrence weighting is enabled and the model is not binarized.
        boolean collapseToPresence = !((AbstractTrainingParameters)this.knowledgeBase.getTrainingParameters()).isMultiProbabilityWeighted() || this.isBinarized();

        // Work on a copy so the stored priors are never modified by prediction.
        AssociativeArray predictionScores = new AssociativeArray(new HashMap<Object, Object>(logPriors));

        for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
            Object feature = entry.getKey();
            if (!logLikelihoods.containsKey(Arrays.asList(feature, someClass))) {
                continue; // no stored likelihood for this feature
            }

            Double occurrences = TypeInference.toDouble(entry.getValue());
            if (occurrences == null) {
                continue; // a missing value contributes nothing, mirroring how training skips it
            }
            if (collapseToPresence && occurrences > 0.0) {
                occurrences = 1.0;
            }

            for (Object theClass : classesSet) {
                Double logScore = logLikelihoods.get(Arrays.asList(feature, theClass));
                Double previousValue = predictionScores.getDouble(theClass);
                predictionScores.put(theClass, previousValue + occurrences * logScore);
            }
        }

        // Select the class before the scores are normalised.
        Object predictedClass = this.getSelectedClassFromClassScores(predictionScores);
        Descriptives.normalizeExp(predictionScores);
        return new PredictParallelizable.Prediction(predictedClass, predictionScores);
    }

    /**
     * Estimates log-priors and add-one smoothed log-likelihoods from the training data.
     *
     * <p>Passes, in order: (1) count records per class; (2) initialise every
     * (feature, class) counter to zero; (3) accumulate positive feature values
     * per (feature, class) and per class; (4) turn class counts into log-priors;
     * (5) turn counters into {@code log((count + 1) / (classTotal + d))}.
     *
     * @param trainingData the labelled records to learn from
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        AbstractModelParameters modelParameters = (AbstractModelParameters)this.knowledgeBase.getModelParameters();
        int n = trainingData.size();
        int d = trainingData.xColumnSize();
        Map<List<Object>, Double> logLikelihoods = modelParameters.getLogLikelihoods();
        Map<Object, Double> logPriors = modelParameters.getLogPriors();
        Set<Object> classesSet = modelParameters.getClasses();
        boolean isBinarized = this.isBinarized();
        Map<Object, Double> totalFeatureOccurrencesForEachClass = new HashMap<Object, Double>();

        // Pass 1: per-class record counts (temporarily kept in logPriors).
        // The loop variable is scoped to the loop so the lambdas below can declare their own 'theClass'.
        for (Record record : trainingData) {
            Object theClass = record.getY();
            if (classesSet.add(theClass)) {
                logPriors.put(theClass, 1.0);
                totalFeatureOccurrencesForEachClass.put(theClass, 0.0);
            } else {
                logPriors.put(theClass, logPriors.get(theClass) + 1.0);
            }
        }

        // Pass 2: zero-initialise every (feature, class) counter; each task writes distinct keys.
        this.streamExecutor.forEach(StreamMethods.stream(trainingData.getXDataTypes().keySet().stream(), this.isParallelized()), feature -> {
            for (Object theClass : classesSet) {
                logLikelihoods.put(Arrays.asList(feature, theClass), 0.0);
            }
        });

        // Pass 3: accumulate occurrences.
        this.streamExecutor.forEach(StreamMethods.stream(trainingData.stream(), this.isParallelized()), r -> {
            Object theClass = r.getY();
            double sumOfOccurrences = 0.0;
            for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
                Double occurrences = TypeInference.toDouble(entry.getValue());
                if (occurrences == null || !(occurrences > 0.0)) {
                    continue; // only positive, non-missing values are counted
                }
                if (isBinarized) {
                    occurrences = 1.0;
                }
                // Different records can hit the same (feature, class) key, so a separate
                // get-then-put can lose updates in parallel mode. merge() performs the
                // update in one call and is atomic on ConcurrentMap implementations.
                // NOTE(review): the field is annotated concurrent=true; confirm the
                // storage engine actually returns a ConcurrentMap.
                logLikelihoods.merge(Arrays.asList(entry.getKey(), theClass), occurrences, Double::sum);
                sumOfOccurrences += occurrences;
            }
            // Plain HashMap shared between tasks: guard the read-modify-write.
            synchronized (totalFeatureOccurrencesForEachClass) {
                totalFeatureOccurrencesForEachClass.merge(theClass, sumOfOccurrences, Double::sum);
            }
        });

        // Pass 4: counts -> log-priors. Only values of existing keys are replaced.
        for (Map.Entry<Object, Double> entry : logPriors.entrySet()) {
            Object theClass = entry.getKey();
            Double count = entry.getValue();
            logPriors.put(theClass, Math.log(count / (double)n));
        }

        // Pass 5: counters -> add-one smoothed log-likelihoods.
        this.streamExecutor.forEach(StreamMethods.stream(logLikelihoods.entrySet().stream(), this.isParallelized()), featureClassCounts -> {
            List<Object> featureClassTuple = featureClassCounts.getKey();
            Object theClass = featureClassTuple.get(1);
            Double occurrences = featureClassCounts.getValue();
            if (occurrences == null) {
                occurrences = 0.0;
            }
            double smoothedProbability = (occurrences + 1.0) / (totalFeatureOccurrencesForEachClass.get(theClass) + (double)d);
            logLikelihoods.put(featureClassTuple, Math.log(smoothedProbability));
        });
    }

    /**
     * Training parameters common to the Naive Bayes variants.
     */
    public static abstract class AbstractTrainingParameters
    extends AbstractTrainer.AbstractTrainingParameters {
        /** When {@code false}, prediction treats every positive feature value as {@code 1.0}. */
        private boolean multiProbabilityWeighted = false;

        /**
         * @return whether prediction weights log-likelihoods by the feature value
         */
        public boolean isMultiProbabilityWeighted() {
            return this.multiProbabilityWeighted;
        }

        /**
         * @param multiProbabilityWeighted whether prediction weights log-likelihoods by the feature value
         */
        public void setMultiProbabilityWeighted(boolean multiProbabilityWeighted) {
            this.multiProbabilityWeighted = multiProbabilityWeighted;
        }
    }

    /**
     * Model parameters learned by the Naive Bayes variants.
     */
    public static abstract class AbstractModelParameters
    extends AbstractClassifier.AbstractModelParameters {
        /** Log-prior of each class. */
        private Map<Object, Double> logPriors = new HashMap<Object, Double>();

        /** Log-likelihood keyed by a two-element list {@code [feature, class]}; not initialised in this class. */
        @BigMap(keyClass=List.class, valueClass=Double.class, mapType=StorageEngine.MapType.HASHMAP, storageHint=StorageEngine.StorageHint.IN_MEMORY, concurrent=true)
        private Map<List<Object>, Double> logLikelihoods;

        /**
         * @param storageEngine the storage engine, forwarded to the superclass
         */
        protected AbstractModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        /**
         * @return the live map of class log-priors (not a copy)
         */
        public Map<Object, Double> getLogPriors() {
            return this.logPriors;
        }

        /**
         * @param logPriors the map of class log-priors to use
         */
        protected void setLogPriors(Map<Object, Double> logPriors) {
            this.logPriors = logPriors;
        }

        /**
         * @return the live map of (feature, class) log-likelihoods (not a copy)
         */
        public Map<List<Object>, Double> getLogLikelihoods() {
            return this.logLikelihoods;
        }

        /**
         * @param logLikelihoods the map of (feature, class) log-likelihoods to use
         */
        protected void setLogLikelihoods(Map<List<Object>, Double> logLikelihoods) {
            this.logLikelihoods = logLikelihoods;
        }
    }
}

