/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.featureselection;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractCountBasedFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractScoreBasedFeatureSelector;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * Feature selector that scores each feature by the mutual information, measured in bits,
 * between two binary events: "a record contains the feature" and "a record belongs to a
 * class". It keeps the largest value over all classes.
 *
 * <p>This class only implements the scoring step, {@link #estimateFeatureScores}. The count
 * maps it consumes are presumably produced by {@link AbstractCountBasedFeatureSelector};
 * confirm against the superclass.
 */
public class MutualInformation
extends AbstractCountBasedFeatureSelector<ModelParameters, TrainingParameters> {

    /** Natural log of 2, used to convert natural logarithms to base-2 (bits). */
    private static final double LOG_2 = Math.log(2.0);

    /**
     * Creates a selector from training parameters.
     *
     * @param trainingParameters parameters forwarded unchanged to the superclass
     * @param configuration configuration forwarded unchanged to the superclass
     */
    protected MutualInformation(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Creates a selector bound to a named storage.
     *
     * @param storageName storage name forwarded unchanged to the superclass
     * @param configuration configuration forwarded unchanged to the superclass
     */
    protected MutualInformation(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    /**
     * Estimates a mutual-information score for every feature in {@code featureCounts}.
     *
     * <p>For a feature f and a class c, the 2x2 contingency table has four cells:
     * <ul>
     *   <li>N11: records that have f and are in c.</li>
     *   <li>N10: records that have f and are not in c.</li>
     *   <li>N01: records that lack f and are in c.</li>
     *   <li>N00: records that lack f and are not in c.</li>
     * </ul>
     * The cells are derived from the marginals: N1. (records with f), N.1 (records in c),
     * and the total N. The score for (f, c) is
     * {@code sum over cells of (Nij / N) * log2(N * Nij / (Ni. * N.j))}. Cells with a
     * non-positive count are skipped, because the log would be undefined there and the term
     * tends to 0 in the limit.
     *
     * <p>A feature's final score is the maximum over all classes. If {@code classCounts} is
     * empty, the score is {@link Double#NEGATIVE_INFINITY}.
     *
     * @param featureScores output map that receives one score per feature. When the selector
     *     is parallelized it may be written from several threads, so it presumably must be
     *     safe for concurrent puts; confirm with the caller.
     * @param N total number of records
     * @param classCounts number of records per class
     * @param featureClassCounts number of records per (feature, class) pair, keyed by the
     *     two-element list {@code [feature, class]}; a missing key counts as 0
     * @param featureCounts number of records that contain each feature
     */
    @Override
    protected void estimateFeatureScores(Map<Object, Double> featureScores, int N,
                                         Map<Object, Integer> classCounts,
                                         Map<List<Object>, Integer> featureClassCounts,
                                         Map<Object, Double> featureCounts) {
        this.logger.debug("estimateFeatureScores()");
        final double total = N;

        this.streamExecutor.forEach(
                StreamMethods.stream(featureCounts.entrySet().stream(), this.isParallelized()),
                featureCount -> {
                    Object feature = featureCount.getKey();
                    double withFeature = featureCount.getValue();   // N1.
                    double withoutFeature = total - withFeature;    // N0.

                    double bestScore = Double.NEGATIVE_INFINITY;
                    for (Map.Entry<Object, Integer> classCount : classCounts.entrySet()) {
                        Object theClass = classCount.getKey();
                        double inClass = classCount.getValue();     // N.1
                        double notInClass = total - inClass;        // N.0

                        Integer featureClassC = featureClassCounts.get(Arrays.asList(feature, theClass));
                        double n11 = featureClassC != null ? featureClassC.doubleValue() : 0.0;
                        double n01 = inClass - n11;          // lacks feature, in class
                        double n00 = withoutFeature - n01;   // lacks feature, not in class
                        double n10 = withFeature - n11;      // has feature, not in class

                        // Accumulate in the same order as the textbook sum so results are reproducible.
                        double score = 0.0;
                        score += cellContribution(n11, withFeature, inClass, total);
                        score += cellContribution(n01, withoutFeature, inClass, total);
                        score += cellContribution(n10, withFeature, notInClass, total);
                        score += cellContribution(n00, withoutFeature, notInClass, total);

                        // Plain '>' comparison: a NaN score never replaces the current best.
                        if (score > bestScore) {
                            bestScore = score;
                        }
                    }
                    featureScores.put(feature, bestScore);
                });
    }

    /**
     * Returns one cell's term of the mutual-information sum, in bits:
     * {@code (nij / n) * log2((n / rowTotal) * (nij / colTotal))}.
     *
     * <p>Returns 0.0 when {@code nij} is not positive. The log is undefined there, and the
     * term's limit as nij tends to 0 is 0.
     *
     * @param nij count in the cell
     * @param rowTotal feature marginal for the cell (N1. or N0.)
     * @param colTotal class marginal for the cell (N.1 or N.0)
     * @param n total number of records
     * @return the cell's contribution to the mutual information
     */
    private static double cellContribution(double nij, double rowTotal, double colTotal, double n) {
        if (nij > 0.0) {
            return nij / n * Math.log(n / rowTotal * (nij / colTotal)) / LOG_2;
        }
        return 0.0;
    }

    /** Training parameters for {@link MutualInformation}; adds nothing to the inherited ones. */
    public static class TrainingParameters
    extends AbstractScoreBasedFeatureSelector.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;
    }

    /** Model parameters for {@link MutualInformation}; adds nothing to the inherited ones. */
    public static class ModelParameters
    extends AbstractScoreBasedFeatureSelector.AbstractModelParameters {
        private static final long serialVersionUID = 1L;

        /**
         * Creates model parameters backed by the given storage engine.
         *
         * @param storageEngine storage engine forwarded unchanged to the superclass
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }
    }
}

