/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.common.abstracts.algorithms;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.DataTable2D;
import com.datumbox.framework.common.dataobjects.FlatDataList;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.common.utilities.MapMethods;
import com.datumbox.framework.core.machinelearning.MLBuilder;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.dataobjects.TrainableBundle;
import com.datumbox.framework.core.machinelearning.common.interfaces.Trainable;
import com.datumbox.framework.core.machinelearning.ensemblelearning.FixedCombinationRules;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.core.statistics.sampling.SimpleRandomSampling;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Base class for ensemble classifiers that train a sequence of weak classifiers on weighted
 * resamples of the training data and combine their predicted class probabilities with a
 * weighted average.
 *
 * <p>Each accepted weak classifier is kept in a {@link TrainableBundle} under the key
 * {@code "Cmp" + i}. Its voting weight is the {@code i}-th entry of
 * {@link AbstractModelParameters#getWeakClassifierWeights()}. Subclasses decide how observation
 * weights and classifier weights evolve by implementing
 * {@link #updateObservationAndClassifierWeights(Dataframe, AssociativeArray)}.
 *
 * @param <MP> model parameters type
 * @param <TP> training parameters type
 */
public abstract class AbstractBoostingBagging<MP extends AbstractModelParameters, TP extends AbstractTrainingParameters>
extends AbstractClassifier<MP, TP> {
    /** Key prefix of the weak classifiers inside the bundle ("Cmp0", "Cmp1", ...). */
    private static final String STORAGE_INDICATOR = "Cmp";

    /**
     * Maximum number of consecutive weak learners that may be discarded
     * ({@link Status#IGNORE}) before training is abandoned.
     */
    private static final int MAX_NUM_OF_RETRIES = 2;

    /** Name of the temporary big map that collects the per-record decisions during prediction. */
    private static final String TMP_RECORD_DECISIONS = "tmp_recordDecisions";

    /** The weak classifiers of the ensemble, keyed by {@code STORAGE_INDICATOR + index}. */
    private final TrainableBundle bundle;

    protected AbstractBoostingBagging(TP trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
        this.bundle = new TrainableBundle(configuration.getStorageConfiguration().getStorageNameSeparator());
    }

    protected AbstractBoostingBagging(String storageName, Configuration configuration) {
        super(storageName, configuration);
        this.bundle = new TrainableBundle(configuration.getStorageConfiguration().getStorageNameSeparator());
    }

    /**
     * Predicts every record of {@code newData} by letting each weak classifier vote with its
     * class probabilities.
     *
     * <p>The votes are combined with {@link FixedCombinationRules#weightedAverage} using the stored
     * classifier weights. They are then normalised, and each record is replaced with one carrying
     * the arg-max class and the combined probabilities.
     *
     * <p>The temporary decisions map is dropped even if one of the weak classifiers fails.
     *
     * @param newData the records to predict; they are overwritten in place
     */
    @Override
    protected void _predict(Dataframe newData) {
        this.initBundle();

        List<Double> weakClassifierWeights = ((AbstractModelParameters) this.knowledgeBase.getModelParameters()).getWeakClassifierWeights();
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();

        Map<Object, DataTable2D> recordDecisions = storageEngine.getBigMap(TMP_RECORD_DECISIONS, Object.class, DataTable2D.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_DISK, false, true);
        try {
            for (Integer rId : newData.index()) {
                recordDecisions.put(rId, new DataTable2D());
            }

            // Collect, for every record, the class probabilities emitted by each weak classifier.
            AssociativeArray classifierWeightsArray = new AssociativeArray();
            int totalWeakClassifiers = weakClassifierWeights.size();
            for (int i = 0; i < totalWeakClassifiers; ++i) {
                AbstractClassifier<?, ?> mlclassifier = (AbstractClassifier<?, ?>) this.bundle.get(STORAGE_INDICATOR + i);
                mlclassifier.predict(newData);
                classifierWeightsArray.put(i, weakClassifierWeights.get(i));

                for (Map.Entry<Integer, Record> e : newData.entries()) {
                    Integer rId = e.getKey();
                    DataTable2D rDecisions = recordDecisions.get(rId);
                    rDecisions.put(i, e.getValue().getYPredictedProbabilities());
                    // Re-put the mutated table. The map was requested with the IN_DISK hint, so an
                    // in-place mutation is presumably not persisted on its own.
                    recordDecisions.put(rId, rDecisions);
                }
            }

            // Combine the votes and write the final prediction of every record.
            for (Map.Entry<Integer, Record> e : newData.entries()) {
                Integer rId = e.getKey();
                Record r = e.getValue();
                AssociativeArray combinedClassVotes = FixedCombinationRules.weightedAverage(recordDecisions.get(rId), classifierWeightsArray);
                Descriptives.normalize(combinedClassVotes);
                newData._unsafe_set(rId, new Record(r.getX(), r.getY(), MapMethods.selectMaxKeyValue(combinedClassVotes).getKey(), combinedClassVotes));
            }
        } finally {
            storageEngine.dropBigMap(TMP_RECORD_DECISIONS, recordDecisions);
        }
    }

    /**
     * Trains up to {@code maxWeakClassifiers} weak classifiers, one per iteration.
     *
     * <p>Every observation starts with weight {@code 1/n}. Each iteration trains a new weak
     * classifier on a weighted resample and lets the subclass update the weights. The
     * {@link Status} it returns decides what happens next:
     * <ul>
     *   <li>{@code NEXT}: the learner is kept and the retry counter is reset.</li>
     *   <li>{@code STOP}: the learner is kept and training ends.</li>
     *   <li>{@code IGNORE}: the learner is discarded and the same slot is retried, up to
     *       {@link #MAX_NUM_OF_RETRIES} consecutive times.</li>
     * </ul>
     *
     * @param trainingData the training records; their predictions are overwritten by every weak
     *                     classifier
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        Configuration configuration = this.knowledgeBase.getConfiguration();
        AbstractTrainingParameters trainingParameters = (AbstractTrainingParameters) this.knowledgeBase.getTrainingParameters();
        AbstractModelParameters modelParameters = (AbstractModelParameters) this.knowledgeBase.getModelParameters();

        this.resetBundle();

        int n = trainingData.size();

        Set<Object> classesSet = modelParameters.getClasses();
        for (Record r : trainingData) {
            classesSet.add(r.getY());
        }

        // Uniform initial observation weights.
        AssociativeArray observationWeights = new AssociativeArray();
        double initialWeight = 1.0 / n;
        for (Integer rId : trainingData.index()) {
            observationWeights.put(rId, initialWeight);
        }

        AbstractTrainer.AbstractTrainingParameters weakClassifierTrainingParameters = trainingParameters.getWeakClassifierTrainingParameters();
        int totalWeakClassifiers = trainingParameters.getMaxWeakClassifiers();

        int i = 0;
        int retryCounter = 0;
        while (i < totalWeakClassifiers) {
            this.logger.debug("Training Weak learner {}", i);
            Status status = this.trainWeakClassifier(i, trainingData, n, observationWeights, weakClassifierTrainingParameters, configuration);

            if (status == Status.STOP) {
                this.logger.debug("Skipping further training due to low error");
                break;
            }
            if (status == Status.IGNORE) {
                if (retryCounter < MAX_NUM_OF_RETRIES) {
                    this.logger.debug("Ignoring last weak learner due to high error");
                    ++retryCounter;
                    continue; // retry the same slot i
                }
                this.logger.debug("Too many retries, skipping further training");
                break;
            }
            if (status == Status.NEXT) {
                retryCounter = 0;
            }
            ++i;
        }
    }

    /**
     * Trains, evaluates and (unless rejected) stores the weak classifier for slot {@code i}.
     *
     * <p>The sampled dataset is always closed. The weak classifier is closed unless it was
     * handed over to the bundle, which covers both the {@code IGNORE} outcome and any exception.
     *
     * @return the status reported by {@link #updateObservationAndClassifierWeights}
     */
    private Status trainWeakClassifier(int i, Dataframe trainingData, int n, AssociativeArray observationWeights,
                                       AbstractTrainer.AbstractTrainingParameters weakClassifierTrainingParameters,
                                       Configuration configuration) {
        FlatDataList sampledIDs = SimpleRandomSampling.weightedSampling(observationWeights, n, true).toFlatDataList();
        Dataframe sampledTrainingDataset = trainingData.getSubset(sampledIDs);

        AbstractClassifier<?, ?> mlclassifier = null;
        boolean storedInBundle = false;
        try {
            try {
                mlclassifier = (AbstractClassifier<?, ?>) MLBuilder.create(weakClassifierTrainingParameters, configuration);
                mlclassifier.fit(sampledTrainingDataset);
            } finally {
                sampledTrainingDataset.close();
            }
            mlclassifier.predict(trainingData);

            Status status = this.updateObservationAndClassifierWeights(trainingData, observationWeights);
            if (status != Status.IGNORE) {
                this.bundle.put(STORAGE_INDICATOR + i, mlclassifier);
                storedInBundle = true; // ownership transferred to the bundle
            }
            return status;
        } finally {
            if (!storedInBundle && mlclassifier != null) {
                mlclassifier.close();
            }
        }
    }

    /**
     * Evaluates the weak classifier that just predicted {@code trainingData} and updates the
     * ensemble state.
     *
     * <p>The same {@code observationWeights} instance is used for sampling in the following
     * iterations, so any reweighting must be done in place. During prediction, the classifier
     * weight at index {@code i} of {@link AbstractModelParameters#getWeakClassifierWeights()} is
     * paired with the classifier stored as {@code "Cmp" + i}. Implementations are therefore
     * expected to record one weight for every learner they do not {@code IGNORE}.
     *
     * @param trainingData       the training data, holding the weak classifier's predictions
     * @param observationWeights the per-record weights, keyed by record id
     * @return how training should proceed
     */
    protected abstract Status updateObservationAndClassifierWeights(Dataframe trainingData, AssociativeArray observationWeights);

    /**
     * Saves the model and all of its weak classifiers. The bundle is saved under the knowledge
     * base name derived from {@code storageName}.
     */
    @Override
    public void save(String storageName) {
        this.initBundle();
        super.save(storageName);
        String separator = this.knowledgeBase.getConfiguration().getStorageConfiguration().getStorageNameSeparator();
        this.bundle.save(this.createKnowledgeBaseName(storageName, separator));
    }

    /** Deletes the weak classifiers and then the model itself, even if the bundle step fails. */
    @Override
    public void delete() {
        try {
            this.initBundle();
            this.bundle.delete();
        } finally {
            super.delete();
        }
    }

    /** Closes the weak classifiers and then the model itself, even if the bundle step fails. */
    @Override
    public void close() {
        try {
            this.initBundle();
            this.bundle.close();
        } finally {
            super.close();
        }
    }

    /** Deletes every weak classifier currently held by the bundle. */
    private void resetBundle() {
        this.bundle.delete();
    }

    /**
     * Loads any weak classifier missing from the bundle.
     *
     * <p>The number of slots is the smaller of the stored classifier-weight count and
     * {@code maxWeakClassifiers}. Each missing slot is loaded from the storage name
     * {@code <this storage name><separator>Cmp<i>}.
     */
    private void initBundle() {
        Configuration configuration = this.knowledgeBase.getConfiguration();
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();
        AbstractModelParameters modelParameters = (AbstractModelParameters) this.knowledgeBase.getModelParameters();
        AbstractTrainingParameters trainingParameters = (AbstractTrainingParameters) this.knowledgeBase.getTrainingParameters();

        String separator = configuration.getStorageConfiguration().getStorageNameSeparator();
        Class weakClassifierClass = trainingParameters.getWeakClassifierTrainingParameters().getTClass();
        int totalWeakClassifiers = Math.min(modelParameters.getWeakClassifierWeights().size(), trainingParameters.getMaxWeakClassifiers());

        for (int i = 0; i < totalWeakClassifiers; ++i) {
            String key = STORAGE_INDICATOR + i;
            if (this.bundle.containsKey(key)) {
                continue; // already in memory
            }
            this.bundle.put(key, (Trainable) MLBuilder.load(weakClassifierClass, storageEngine.getStorageName() + separator + key, configuration));
        }
    }

    /** Outcome of evaluating a freshly trained weak classifier. */
    protected enum Status {
        /** Keep the learner and continue with the next one. */
        NEXT,
        /** Keep the learner and stop training (logged as "low error"). */
        STOP,
        /** Discard the learner and retry the same slot (logged as "high error"). */
        IGNORE
    }

    /** Training parameters shared by boosting/bagging ensembles. */
    public static abstract class AbstractTrainingParameters
    extends AbstractTrainer.AbstractTrainingParameters {
        /** Maximum number of weak classifiers to train (default 5). */
        private int maxWeakClassifiers = 5;

        /** Training parameters used to create every weak classifier. */
        private AbstractTrainer.AbstractTrainingParameters weakClassifierTrainingParameters;

        /** @return the maximum number of weak classifiers */
        public int getMaxWeakClassifiers() {
            return this.maxWeakClassifiers;
        }

        /** @param maxWeakClassifiers the maximum number of weak classifiers */
        public void setMaxWeakClassifiers(int maxWeakClassifiers) {
            this.maxWeakClassifiers = maxWeakClassifiers;
        }

        /** @return the training parameters of the weak classifiers */
        public AbstractTrainer.AbstractTrainingParameters getWeakClassifierTrainingParameters() {
            return this.weakClassifierTrainingParameters;
        }

        /** @param weakClassifierTrainingParameters the training parameters of the weak classifiers */
        public void setWeakClassifierTrainingParameters(AbstractTrainer.AbstractTrainingParameters weakClassifierTrainingParameters) {
            this.weakClassifierTrainingParameters = weakClassifierTrainingParameters;
        }
    }

    /** Model parameters shared by boosting/bagging ensembles. */
    public static abstract class AbstractModelParameters
    extends AbstractClassifier.AbstractModelParameters {
        /** Voting weight of each stored weak classifier; index i matches bundle key "Cmp" + i. */
        private List<Double> weakClassifierWeights = new ArrayList<Double>();

        protected AbstractModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        /** @return the (mutable) list of weak classifier weights */
        public List<Double> getWeakClassifierWeights() {
            return this.weakClassifierWeights;
        }

        /** @param weakClassifierWeights the weak classifier weights */
        protected void setWeakClassifierWeights(List<Double> weakClassifierWeights) {
            this.weakClassifierWeights = weakClassifierWeights;
        }
    }
}

