/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.common.abstracts.algorithms;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.common.utilities.MapMethods;
import com.datumbox.framework.core.common.utilities.PHPMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClusterer;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.core.statistics.sampling.SimpleRandomSampling;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

public abstract class AbstractDPMM<CL extends AbstractCluster, MP extends AbstractModelParameters, TP extends AbstractTrainingParameters>
extends AbstractClusterer<CL, MP, TP>
implements PredictParallelizable {
    private boolean parallelized = true;

    /**
     * Creates a trainer from explicit training parameters.
     *
     * <p>Both arguments are forwarded unchanged to the superclass constructor.
     *
     * @param trainingParameters the training parameters to use
     * @param configuration the framework configuration
     */
    protected AbstractDPMM(TP trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Creates a trainer identified by a storage name.
     *
     * <p>Both arguments are forwarded unchanged to the superclass constructor.
     * NOTE(review): presumably this variant loads a previously stored model —
     * confirm against the superclass.
     *
     * @param storageName the name under which the model is stored
     * @param configuration the framework configuration
     */
    protected AbstractDPMM(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    /**
     * Reports the current value of the parallelized flag.
     *
     * @return {@code true} unless the flag was switched off through
     *         {@link #setParallelized(boolean)}; the field is initialised to {@code true}
     */
    @Override
    public boolean isParallelized() {
        return parallelized;
    }

    /**
     * Stores the parallelized flag that {@link #isParallelized()} reports.
     *
     * @param parallelized the new value of the flag
     */
    @Override
    public void setParallelized(boolean parallelized) {
        this.parallelized = parallelized;
    }

    /**
     * Predicts the given data by delegating to {@code _predictDatasetParallel},
     * passing the knowledge base's storage engine and concurrency configuration.
     *
     * @param newData the records to predict
     */
    @Override
    protected void _predict(Dataframe newData) {
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();
        this._predictDatasetParallel(
                newData,
                storageEngine,
                this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Scores a single record against every cluster currently in the model.
     *
     * <p>For each cluster id the cluster's {@code posteriorLogPdf} of the record is
     * stored in an {@link AssociativeArray}; the scores are then passed through
     * {@code Descriptives.normalizeExp} and the key with the largest value is
     * selected as the prediction.
     *
     * @param r the record to score
     * @return a prediction holding the selected cluster id and the per-cluster scores
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record r) {
        AbstractModelParameters modelParameters = (AbstractModelParameters)this.knowledgeBase.getModelParameters();
        // getClusterMap() is invoked on the raw parameter type, so the generic view of the
        // map has to be restored here; iterating a raw keySet() as Integer does not compile.
        @SuppressWarnings("unchecked")
        Map<Integer, CL> clusterMap = modelParameters.getClusterMap();

        AssociativeArray clusterScores = new AssociativeArray();
        for (Integer clusterId : clusterMap.keySet()) {
            CL cluster = this.getFromClusterMap(clusterId, clusterMap);
            double logScore = cluster.posteriorLogPdf(r);
            clusterScores.put(clusterId, logScore);
        }

        Descriptives.normalizeExp(clusterScores);
        return new PredictParallelizable.Prediction(this.getSelectedClusterFromScores(clusterScores), clusterScores);
    }

    /**
     * Trains the model on the given data.
     *
     * <p>Records the number of X columns, collects every non-null Y value as a
     * gold-standard class, assigns a sequential integer id to each feature in
     * order of first appearance, runs the collapsed Gibbs sampler, stores the
     * number of iterations it reports and finally calls {@code clearClusters()}.
     *
     * @param trainingData the records to train on
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        AbstractModelParameters params = (AbstractModelParameters)this.knowledgeBase.getModelParameters();
        params.setD(trainingData.xColumnSize());

        Set<Object> knownClasses = params.getGoldStandardClasses();
        Map<Object, Integer> featureIndex = params.getFeatureIds();
        int nextFeatureId = 0;
        for (Record record : trainingData) {
            Object label = record.getY();
            if (label != null) {
                knownClasses.add(label);
            }
            for (Map.Entry entry : record.getX().entrySet()) {
                // putIfAbsent returns null when the feature had no id yet, i.e. the id was consumed.
                if (featureIndex.putIfAbsent(entry.getKey(), nextFeatureId) == null) {
                    nextFeatureId++;
                }
            }
        }

        params.setTotalIterations(this.collapsedGibbsSampling(trainingData));
        this.clearClusters();
    }

    /**
     * Looks up a cluster by id and makes sure it carries the feature-id mapping.
     *
     * <p>When the cluster reports no feature ids, the mapping held by the model
     * parameters is injected before the cluster is returned.
     *
     * @param clusterId the id of the cluster to fetch; it must be present in the map
     * @param clusterMap the map to read the cluster from
     * @return the cluster stored under {@code clusterId}
     */
    private CL getFromClusterMap(int clusterId, Map<Integer, CL> clusterMap) {
        CL cluster = clusterMap.get(clusterId);
        if (cluster.getFeatureIds() == null) {
            AbstractModelParameters params = (AbstractModelParameters)this.knowledgeBase.getModelParameters();
            cluster.setFeatureIds(params.getFeatureIds());
        }
        return cluster;
    }

    /*
     * WARNING - void declaration
     */
    private int collapsedGibbsSampling(Dataframe dataset) {
        void var11_20;
        AbstractModelParameters modelParameters = (AbstractModelParameters)this.knowledgeBase.getModelParameters();
        Map clusterMap = modelParameters.getClusterMap();
        AbstractTrainingParameters trainingParameters = (AbstractTrainingParameters)this.knowledgeBase.getTrainingParameters();
        double alpha = trainingParameters.getAlpha();
        Integer newClusterId = clusterMap.size();
        if (trainingParameters.getInitializationMethod() == AbstractTrainingParameters.Initialization.ONE_CLUSTER_PER_RECORD) {
            for (Map.Entry<Integer, Record> e : dataset.entries()) {
                Integer rId = e.getKey();
                Record record = e.getValue();
                CL cluster = this.createNewCluster(newClusterId);
                ((AbstractCluster)cluster).add(record);
                clusterMap.put(newClusterId, cluster);
                Record record2 = new Record(record.getX(), record.getY(), newClusterId, record.getYPredictedProbabilities());
                dataset._unsafe_set(rId, record2);
                newClusterId = newClusterId + 1;
            }
        } else {
            int numberOfNewClusters = (int)(alpha * Math.log(dataset.size()));
            if (numberOfNewClusters <= 0) {
                numberOfNewClusters = 1;
            }
            for (int i = 0; i < numberOfNewClusters; ++i) {
                CL cluster = this.createNewCluster(newClusterId);
                clusterMap.put(newClusterId, cluster);
                newClusterId = newClusterId + 1;
            }
            int clusterMapSize = newClusterId;
            for (Map.Entry entry : dataset.entries()) {
                Integer rId = (Integer)entry.getKey();
                Record r = (Record)entry.getValue();
                Integer assignedClusterId = PHPMethods.mt_rand(0, clusterMapSize - 1);
                r = new Record(r.getX(), r.getY(), assignedClusterId, r.getYPredictedProbabilities());
                dataset._unsafe_set(rId, r);
                Object c = this.getFromClusterMap(assignedClusterId, clusterMap);
                ((AbstractCluster)c).add(r);
                clusterMap.put(assignedClusterId, c);
            }
        }
        int n = clusterMap.size();
        int maxIterations = trainingParameters.getMaxIterations();
        boolean noChangeMade = false;
        boolean bl = false;
        while (var11_20 < maxIterations && !noChangeMade) {
            this.logger.debug("Iteration {}", (Object)((int)var11_20));
            noChangeMade = true;
            for (Map.Entry<Integer, Record> e : dataset.entries()) {
                Integer rId = e.getKey();
                Record r = e.getValue();
                Integer pointClusterId = (Integer)r.getYPredicted();
                Object ci = this.getFromClusterMap(pointClusterId, clusterMap);
                ((AbstractCluster)ci).remove(r);
                if (((AbstractClusterer.AbstractCluster)ci).size() == 0) {
                    clusterMap.remove(pointClusterId);
                } else {
                    clusterMap.put(pointClusterId, ci);
                }
                AssociativeArray condProbCiGivenXiAndOtherCi = this.clusterProbabilities(r, n, clusterMap);
                CL cNew = this.createNewCluster(newClusterId);
                double priorLogPredictive = ((AbstractCluster)cNew).posteriorLogPdf(r);
                double probNewCluster = alpha / (alpha + (double)n - 1.0);
                condProbCiGivenXiAndOtherCi.put((Object)newClusterId, (Object)(priorLogPredictive + Math.log(probNewCluster)));
                Descriptives.normalizeExp(condProbCiGivenXiAndOtherCi);
                Integer sampledClusterId = (Integer)SimpleRandomSampling.weightedSampling(condProbCiGivenXiAndOtherCi, 1, true).iterator().next();
                if (Objects.equals(sampledClusterId, newClusterId)) {
                    r = new Record(r.getX(), r.getY(), newClusterId, r.getYPredictedProbabilities());
                    dataset._unsafe_set(rId, r);
                    ((AbstractCluster)cNew).add(r);
                    clusterMap.put(newClusterId, cNew);
                    noChangeMade = false;
                    newClusterId = newClusterId + 1;
                    continue;
                }
                if (!Objects.equals(pointClusterId, sampledClusterId)) {
                    r = new Record(r.getX(), r.getY(), sampledClusterId, r.getYPredictedProbabilities());
                    dataset._unsafe_set(rId, r);
                    noChangeMade = false;
                }
                Object c = this.getFromClusterMap(sampledClusterId, clusterMap);
                ((AbstractCluster)c).add(r);
                clusterMap.put(sampledClusterId, c);
            }
            ++var11_20;
        }
        return (int)var11_20;
    }

    /**
     * Computes an unnormalised log score of the record for every existing cluster.
     *
     * <p>Each score is the cluster's {@code posteriorLogPdf} of the record plus the
     * logarithm of {@code size / (alpha + n - 1)}. A cluster whose size is zero
     * therefore receives a {@code -Infinity} mixing term.
     *
     * @param r the record to score
     * @param n the count used in the {@code (alpha + n - 1)} denominator
     * @param clusterMap the clusters to score against
     * @return the scores keyed by cluster id
     */
    private AssociativeArray clusterProbabilities(Record r, int n, Map<Integer, CL> clusterMap) {
        // Object-typed to match how AssociativeArray is used elsewhere in this class;
        // still a HashMap so the iteration order is unchanged.
        Map<Object, Object> condProbCiGivenXiAndOtherCi = new HashMap<>();
        double alpha = ((AbstractTrainingParameters)this.knowledgeBase.getTrainingParameters()).getAlpha();
        double denominator = alpha + (double)n - 1.0;
        for (Integer clusterId : clusterMap.keySet()) {
            CL ck = this.getFromClusterMap(clusterId, clusterMap);
            double marginalLogLikelihoodXi = ck.posteriorLogPdf(r);
            double mixingXi = (double)ck.size() / denominator;
            condProbCiGivenXiAndOtherCi.put(clusterId, marginalLogLikelihoodXi + Math.log(mixingXi));
        }
        return new AssociativeArray(condProbCiGivenXiAndOtherCi);
    }

    /**
     * Picks the key of the entry that {@code MapMethods.selectMaxKeyValue} returns
     * for the given scores.
     *
     * @param clusterScores the per-cluster scores
     * @return the key of the selected entry
     */
    private Object getSelectedClusterFromScores(AssociativeArray clusterScores) {
        return MapMethods.selectMaxKeyValue(clusterScores).getKey();
    }

    /**
     * Creates a new, empty cluster of the concrete type used by the subclass.
     *
     * <p>Called during initialisation and once per record per sweep of the
     * sampler to build the candidate "new cluster".
     *
     * @param var1 the id to give to the new cluster
     * @return the newly created cluster
     */
    protected abstract CL createNewCluster(Integer var1);

    /**
     * Training-time settings shared by all DPMM implementations: the alpha
     * parameter, the iteration cap of the sampler and the way clusters are seeded.
     */
    public static abstract class AbstractTrainingParameters
            extends AbstractTrainer.AbstractTrainingParameters {

        /** Alpha parameter; has no explicit initialiser, so it is {@code 0.0} until set. */
        private double alpha;

        /** Upper bound on the number of sampler sweeps. */
        private int maxIterations = 1000;

        /** Seeding strategy applied before the first sweep. */
        private Initialization initializationMethod = Initialization.ONE_CLUSTER_PER_RECORD;

        /**
         * @return the alpha parameter
         */
        public double getAlpha() {
            return alpha;
        }

        /**
         * @param alpha the alpha parameter; the sampler weights a new cluster by
         *              {@code alpha / (alpha + n - 1)}
         */
        public void setAlpha(double alpha) {
            this.alpha = alpha;
        }

        /**
         * @return the maximum number of sampler sweeps
         */
        public int getMaxIterations() {
            return maxIterations;
        }

        /**
         * @param maxIterations the maximum number of sampler sweeps
         */
        public void setMaxIterations(int maxIterations) {
            this.maxIterations = maxIterations;
        }

        /**
         * @return the seeding strategy
         */
        public Initialization getInitializationMethod() {
            return initializationMethod;
        }

        /**
         * @param initializationMethod the seeding strategy
         */
        public void setInitializationMethod(Initialization initializationMethod) {
            this.initializationMethod = initializationMethod;
        }

        /** Ways of seeding the clusters before sampling starts. */
        public enum Initialization {
            /** Every record starts in its own newly created cluster. */
            ONE_CLUSTER_PER_RECORD,
            /** A batch of clusters is created and each record is assigned to a randomly drawn one. */
            RANDOM_ASSIGNMENT
        }
    }

    /**
     * Model state learned by a DPMM: the number of X columns, the number of
     * sampler sweeps that were run and the feature-to-id mapping.
     *
     * @param <CL> the concrete cluster type
     */
    public static abstract class AbstractModelParameters<CL extends AbstractCluster>
            extends AbstractClusterer.AbstractModelParameters<CL> {

        /** Number of X columns of the training data. */
        private Integer d = 0;

        /** Number of sweeps reported by the sampler. */
        private int totalIterations;

        /**
         * Sequential integer id of every feature seen during training.
         * NOTE(review): not initialised here; presumably populated by the storage
         * engine based on the annotation — confirm.
         */
        @BigMap(keyClass=Object.class, valueClass=Integer.class, mapType=StorageEngine.MapType.HASHMAP, storageHint=StorageEngine.StorageHint.IN_MEMORY, concurrent=false)
        private Map<Object, Integer> featureIds;

        /**
         * Forwards the storage engine to the superclass constructor.
         *
         * @param storageEngine the storage engine backing the parameters
         */
        protected AbstractModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        /**
         * @return the number of X columns
         */
        public Integer getD() {
            return d;
        }

        /**
         * @param d the number of X columns
         */
        protected void setD(Integer d) {
            this.d = d;
        }

        /**
         * @return the number of sampler sweeps that were performed
         */
        public int getTotalIterations() {
            return totalIterations;
        }

        /**
         * @param totalIterations the number of sampler sweeps that were performed
         */
        protected void setTotalIterations(int totalIterations) {
            this.totalIterations = totalIterations;
        }

        /**
         * @return the feature-to-id mapping
         */
        public Map<Object, Integer> getFeatureIds() {
            return featureIds;
        }

        /**
         * @param featureIds the feature-to-id mapping
         */
        protected void setFeatureIds(Map<Object, Integer> featureIds) {
            this.featureIds = featureIds;
        }
    }

    /**
     * Base class of the clusters handled by {@link AbstractDPMM}.
     *
     * <p>Adds a feature-id mapping and a log-density contract on top of the
     * generic clusterer cluster.
     */
    public static abstract class AbstractCluster
    extends AbstractClusterer.AbstractCluster {
        /** Feature-to-id mapping; declared {@code transient}. */
        protected transient Map<Object, Integer> featureIds;

        /**
         * Forwards the cluster id to the superclass constructor.
         *
         * @param clusterId the id of this cluster
         */
        protected AbstractCluster(Integer clusterId) {
            super(clusterId);
        }

        /**
         * Returns the feature-to-id mapping of this cluster.
         *
         * @return the mapping, or {@code null} when none has been set; in that
         *         case {@code AbstractDPMM} injects one via {@link #setFeatureIds(Map)}
         */
        protected abstract Map<Object, Integer> getFeatureIds();

        /**
         * Sets the feature-to-id mapping of this cluster.
         *
         * @param var1 the feature-to-id mapping
         */
        protected abstract void setFeatureIds(Map<Object, Integer> var1);

        /**
         * Updates the parameters of the cluster.
         * NOTE(review): not invoked from this file; presumably called by the
         * subclasses' add/remove implementations — confirm.
         */
        protected abstract void updateClusterParameters();

        /**
         * Returns a log-scale score of the record under this cluster.
         *
         * <p>{@code AbstractDPMM} adds the logarithm of a mixing weight to this
         * value and normalises the results with {@code Descriptives.normalizeExp}.
         *
         * @param var1 the record to score
         * @return the log-scale score
         */
        protected abstract double posteriorLogPdf(Record var1);

        /**
         * Adds a record to this cluster.
         *
         * @param var1 the record to add
         */
        @Override
        protected abstract void add(Record var1);

        /**
         * Removes a record from this cluster.
         *
         * @param var1 the record to remove
         */
        @Override
        protected abstract void remove(Record var1);
    }
}

