/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.classification;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractNaiveBayes;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Naive Bayes classifier that treats every feature as a binary
 * present/absent indicator ({@link #isBinarized()} returns {@code true}).
 *
 * <p>Training counts, per class, in how many records each feature has a
 * strictly positive value, converts those counts into smoothed probabilities
 * and caches the per-class sum of {@code log(1 - p)} over all features.
 * Prediction starts from {@code logPrior + sum(log(1 - p))} and, for every
 * active feature of the record, adds {@code log(p) - log(1 - p)}.
 *
 * <p>Note: the map returned by {@code getLogLikelihoods()} holds plain
 * probabilities in this class, not logarithms; the logarithm is taken at
 * prediction time.
 */
public class BernoulliNaiveBayes
extends AbstractNaiveBayes<BernoulliNaiveBayes.ModelParameters, BernoulliNaiveBayes.TrainingParameters> {
    /**
     * Creates an instance that will be trained with the given parameters.
     *
     * @param trainingParameters the training parameters, forwarded to the superclass
     * @param configuration the framework configuration, forwarded to the superclass
     */
    protected BernoulliNaiveBayes(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Creates an instance bound to the given storage name.
     *
     * @param storageName the storage name, forwarded to the superclass
     * @param configuration the framework configuration, forwarded to the superclass
     */
    protected BernoulliNaiveBayes(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    /**
     * Always reports {@code true} for this classifier.
     *
     * @return {@code true}
     */
    @Override
    protected boolean isBinarized() {
        return true;
    }

    /**
     * Scores a single record against every known class.
     *
     * <p>The class is selected from the un-normalized scores via
     * {@code getSelectedClassFromClassScores}; the scores are then passed to
     * {@code Descriptives.normalizeExp} before being returned.
     *
     * @param r the record to classify
     * @return the selected class together with the normalized class scores
     * @throws java.util.NoSuchElementException if the model holds no classes
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record r) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        Map<List<Object>, Double> probabilities = modelParameters.getLogLikelihoods();
        Map<Object, Double> logPriors = modelParameters.getLogPriors();
        Set<Object> classesSet = modelParameters.getClasses();
        Map<Object, Double> sumOfLog1minusProb = modelParameters.getSumOfLog1minusProb();

        // Training stores a value for every (feature, class) pair, so probing a
        // single arbitrary class is enough to tell whether a feature is known.
        Object someClass = classesSet.iterator().next();

        // Baseline per class: log prior plus the "every feature absent" term.
        AssociativeArray predictionScores = new AssociativeArray(new HashMap<>(logPriors));
        for (Map.Entry<Object, Double> entry : sumOfLog1minusProb.entrySet()) {
            Object theClass = entry.getKey();
            Double baseline = predictionScores.getDouble(theClass);
            predictionScores.put(theClass, baseline + entry.getValue());
        }

        for (Map.Entry<?, ?> entry : r.getX().entrySet()) {
            Object feature = entry.getKey();
            if (!probabilities.containsKey(Arrays.asList(feature, someClass))) {
                continue; // feature was not seen during training
            }
            Double occurrences = TypeInference.toDouble(entry.getValue());
            if (occurrences == null || occurrences == 0.0) {
                continue; // absent feature: already covered by the baseline
            }
            // Present feature: swap its log(1 - p) contribution for log(p).
            for (Object theClass : classesSet) {
                Double probability = probabilities.get(Arrays.asList(feature, theClass));
                Double previousValue = predictionScores.getDouble(theClass);
                predictionScores.put(theClass, previousValue + Math.log(probability) - Math.log(1.0 - probability));
            }
        }

        // The class must be picked before the scores are normalized in place.
        Object predictedClass = this.getSelectedClassFromClassScores(predictionScores);
        Descriptives.normalizeExp(predictionScores);
        return new PredictParallelizable.Prediction(predictedClass, predictionScores);
    }

    /**
     * Estimates the model parameters from the training data.
     *
     * <p>Fills the class set, the log priors, the per-(feature, class)
     * smoothed probabilities and the per-class sum of {@code log(1 - p)} held
     * by the model parameters. Also forces the {@code multiProbabilityWeighted}
     * training parameter to {@code false}.
     *
     * @param trainingData the records to learn from
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        int n = trainingData.size();
        int d = trainingData.xColumnSize();
        ((TrainingParameters) this.knowledgeBase.getTrainingParameters()).setMultiProbabilityWeighted(false);

        Map<List<Object>, Double> likelihoods = modelParameters.getLogLikelihoods();
        Map<Object, Double> logPriors = modelParameters.getLogPriors();
        Set<Object> classesSet = modelParameters.getClasses();
        Map<Object, Double> sumOfLog1minusProb = modelParameters.getSumOfLog1minusProb();
        Map<Object, Integer> totalFeatureOccurrencesForEachClass = new HashMap<>();

        // Pass 1: discover the classes and count the records of each one
        // (logPriors temporarily holds raw counts).
        for (Record record : trainingData) {
            Object theClass = record.getY();
            if (classesSet.add(theClass)) {
                logPriors.put(theClass, 1.0);
                totalFeatureOccurrencesForEachClass.put(theClass, 0);
                sumOfLog1minusProb.put(theClass, 0.0);
            } else {
                logPriors.put(theClass, logPriors.get(theClass) + 1.0);
            }
        }

        // Give every (feature, class) pair a zero counter so later lookups never miss.
        this.streamExecutor.forEach(StreamMethods.stream(trainingData.getXDataTypes().keySet().stream(), this.isParallelized()), feature -> {
            for (Object theClass : classesSet) {
                likelihoods.put(Arrays.asList(feature, theClass), 0.0);
            }
        });

        // Pass 2: count, per class, the records in which each feature is active.
        this.streamExecutor.forEach(StreamMethods.stream(trainingData.stream(), this.isParallelized()), r -> {
            Object theClass = r.getY();
            int sumOfOccurrences = 0;
            for (Map.Entry<?, ?> entry : r.getX().entrySet()) {
                Double occurrences = TypeInference.toDouble(entry.getValue());
                if (occurrences == null || !(occurrences > 0.0)) {
                    continue;
                }
                List<Object> featureClassTuple = Arrays.asList(entry.getKey(), theClass);
                // Read-modify-write on a shared map: records of the same class
                // may be processed concurrently, so the increment must be guarded.
                synchronized (likelihoods) {
                    likelihoods.put(featureClassTuple, likelihoods.get(featureClassTuple) + 1.0);
                }
                ++sumOfOccurrences;
            }
            synchronized (totalFeatureOccurrencesForEachClass) {
                totalFeatureOccurrencesForEachClass.put(theClass, totalFeatureOccurrencesForEachClass.get(theClass) + sumOfOccurrences);
            }
        });

        // Turn the per-class record counts into log priors.
        for (Map.Entry<Object, Double> entry : logPriors.entrySet()) {
            Object theClass = entry.getKey();
            Double count = entry.getValue();
            logPriors.put(theClass, Math.log(count / (double) n));
        }

        // Replace the counters by smoothed probabilities and cache sum(log(1 - p)).
        // NOTE(review): the denominator is the class's total active-feature count
        // plus the number of features, not "records in class + 2" - confirm intended.
        for (Object theClass : classesSet) {
            double sumLog1minusP = this.streamExecutor.sum(StreamMethods.stream(trainingData.getXDataTypes().keySet().stream(), this.isParallelized()).mapToDouble(feature -> {
                List<Object> featureClassTuple = Arrays.asList(feature, theClass);
                Double occurrences = likelihoods.get(featureClassTuple);
                Double smoothedProbability = (occurrences + 1.0) / (double) (totalFeatureOccurrencesForEachClass.get(theClass) + d);
                likelihoods.put(featureClassTuple, smoothedProbability);
                return Math.log(1.0 - smoothedProbability);
            }));
            sumOfLog1minusProb.put(theClass, sumOfLog1minusProb.get(theClass) + sumLog1minusP);
        }
    }

    /**
     * Training parameters of the classifier; adds nothing to the inherited ones.
     */
    public static class TrainingParameters
    extends AbstractNaiveBayes.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;
    }

    /**
     * Learned parameters of the classifier: the inherited ones plus the cached
     * per-class sum of {@code log(1 - p)} over all features.
     */
    public static class ModelParameters
    extends AbstractNaiveBayes.AbstractModelParameters {
        private static final long serialVersionUID = 1L;
        /** Per-class sum of {@code log(1 - p)} over all features. */
        private Map<Object, Double> sumOfLog1minusProb = new HashMap<>();

        /**
         * Creates the parameters on top of the given storage engine.
         *
         * @param storageEngine the storage engine, forwarded to the superclass
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        /**
         * Returns the per-class sum of {@code log(1 - p)}.
         *
         * @return the live internal map, not a copy
         */
        public Map<Object, Double> getSumOfLog1minusProb() {
            return this.sumOfLog1minusProb;
        }

        /**
         * Replaces the per-class sum of {@code log(1 - p)}.
         *
         * @param sumOfLog1minusProb the new map, stored by reference
         */
        protected void setSumOfLog1minusProb(Map<Object, Double> sumOfLog1minusProb) {
            this.sumOfLog1minusProb = sumOfLog1minusProb;
        }
    }
}

