/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.topicmodeling;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.AssociativeArray2D;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.common.utilities.MapMethods;
import com.datumbox.framework.core.common.utilities.PHPMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractTopicModeler;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.core.statistics.sampling.SimpleRandomSampling;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * Latent Dirichlet Allocation (LDA) topic modeler fitted and applied with Gibbs sampling.
 *
 * <p>Every record of a {@link Dataframe} is treated as one document: the keys of its X vector are
 * word positions and the values are the words. Each word occurrence carries a topic id in
 * {@code [0, k)}. A sampling sweep visits every occurrence, removes it from the count tables,
 * scores every topic for it and draws a new topic with
 * {@link SimpleRandomSampling#weightedSampling} using those scores as weights. After each sweep a
 * record's predicted value becomes its most frequent topic and its prediction probabilities hold
 * the fraction of its words assigned to each topic.
 */
public class LatentDirichletAllocation
extends AbstractTopicModeler<LatentDirichletAllocation.ModelParameters, LatentDirichletAllocation.TrainingParameters> {

    /**
     * Creates a new, untrained modeler.
     *
     * @param trainingParameters sampler settings (number of topics, priors, iteration cap)
     * @param configuration the framework configuration
     */
    protected LatentDirichletAllocation(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Creates a modeler bound to the storage named {@code storageName}.
     * NOTE(review): presumably used to load a previously saved model; confirm in AbstractTopicModeler.
     *
     * @param storageName name of the storage holding the model
     * @param configuration the framework configuration
     */
    protected LatentDirichletAllocation(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    /**
     * Estimates, for every topic, the probability of each word seen during training.
     *
     * <p>For topic {@code j} and word {@code w} the estimate is
     * {@code (n(j,w) + beta) / (n(j) + beta * d)}, where {@code n(j,w)} counts occurrences of
     * {@code w} assigned to {@code j} and {@code n(j)} counts all words assigned to {@code j}.
     * Within each topic the words are sorted by descending probability.
     *
     * @return a 2D array keyed by topic id ({@code 0..k-1}); each row maps word to probability
     */
    public AssociativeArray2D getWordProbabilitiesPerTopic() {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        int k = trainingParameters.getK();
        int d = modelParameters.getD();
        double beta = trainingParameters.getBeta();
        Map<List<Object>, Integer> topicWordCounts = modelParameters.getTopicWordCounts();
        Map<Integer, Integer> topicCounts = modelParameters.getTopicCounts();

        AssociativeArray2D ptw = new AssociativeArray2D();
        for (int topicId = 0; topicId < k; ++topicId) {
            ptw.put(topicId, new AssociativeArray());
        }
        for (Map.Entry<List<Object>, Integer> entry : topicWordCounts.entrySet()) {
            // Keys of topicWordCounts are [topicId, word] pairs.
            List<Object> topicWordKey = entry.getKey();
            Integer topicId = (Integer) topicWordKey.get(0);
            Object word = topicWordKey.get(1);
            int njw = entry.getValue();
            int nj = topicCounts.getOrDefault(topicId, 0);
            ptw.get(topicId).put(word, (njw + beta) / (nj + beta * d));
        }
        for (int topicId = 0; topicId < k; ++topicId) {
            ptw.put(topicId, MapMethods.sortAssociativeArrayByValueDescending(ptw.get(topicId)));
        }
        return ptw;
    }

    /**
     * Fits the model on {@code trainingData} with Gibbs sampling.
     *
     * <p>Every word occurrence first gets a random topic in {@code [0, k-1]} drawn with
     * {@link PHPMethods#mt_rand(int, int)}. Then up to {@code maxIterations} sweeps resample each
     * occurrence's topic. Sampling stops early after a sweep in which no record's predicted topic
     * changed. The number of sweeps actually run is stored via
     * {@link ModelParameters#setTotalIterations(int)}.
     *
     * @param trainingData the documents to fit; each record is replaced in place with its
     *     predicted topic and per-topic word fractions
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();

        // d multiplies beta in the word-given-topic denominator, i.e. the role the vocabulary size
        // plays in standard LDA. NOTE(review): confirm xColumnSize() equals the vocabulary size for
        // this data layout.
        modelParameters.setD(trainingData.xColumnSize());
        int d = modelParameters.getD();
        int k = trainingParameters.getK();
        double alpha = trainingParameters.getAlpha();
        double beta = trainingParameters.getBeta();
        int maxIterations = trainingParameters.getMaxIterations();

        Map<List<Object>, Integer> topicAssignmentOfDocumentWord = modelParameters.getTopicAssignmentOfDocumentWord();
        Map<List<Integer>, Integer> documentTopicCounts = modelParameters.getDocumentTopicCounts();
        Map<List<Object>, Integer> topicWordCounts = modelParameters.getTopicWordCounts();
        Map<Integer, Integer> documentWordCounts = modelParameters.getDocumentWordCounts();
        Map<Integer, Integer> topicCounts = modelParameters.getTopicCounts();

        // Random initialisation: give every word occurrence a topic and fill all count tables.
        for (Map.Entry<Integer, Record> e : trainingData.entries()) {
            Integer documentId = e.getKey();
            Record r = e.getValue();
            documentWordCounts.put(documentId, r.getX().size());
            for (Map.Entry<?, ?> entry : r.getX().entrySet()) {
                Object wordPosition = entry.getKey();
                Object word = entry.getValue();
                Integer topic = PHPMethods.mt_rand(0, k - 1);
                increase(topicCounts, topic);
                topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
                increase(documentTopicCounts, Arrays.asList(documentId, topic));
                increase(topicWordCounts, Arrays.asList(topic, word));
            }
        }

        int iteration = 0;
        while (iteration < maxIterations) {
            this.logger.debug("Iteration {}", iteration);
            int changedCounter = 0;
            for (Map.Entry<Integer, Record> e : trainingData.entries()) {
                Integer documentId = e.getKey();
                Record r = e.getValue();
                AssociativeArray topicAssignments = new AssociativeArray();
                for (int j = 0; j < k; ++j) {
                    topicAssignments.put(j, 0.0);
                }
                int totalWords = r.getX().size();
                for (Map.Entry<?, ?> entry : r.getX().entrySet()) {
                    Object wordPosition = entry.getKey();
                    Object word = entry.getValue();
                    List<Object> documentWordKey = Arrays.asList(documentId, wordPosition);

                    // Remove the current assignment so the counts describe every other word.
                    Integer topic = topicAssignmentOfDocumentWord.get(documentWordKey);
                    decrease(topicCounts, topic);
                    decrease(documentTopicCounts, Arrays.asList(documentId, topic));
                    decrease(topicWordCounts, Arrays.asList(topic, word));

                    // Unnormalised score (n(j,w) + beta) * (n(doc,j) + alpha) / (n(j) + beta * d).
                    AssociativeArray topicProbabilities = new AssociativeArray();
                    for (int j = 0; j < k; ++j) {
                        int njw = topicWordCounts.getOrDefault(Arrays.asList(j, word), 0);
                        int njd = documentTopicCounts.getOrDefault(Arrays.asList(documentId, j), 0);
                        // A topic that random initialisation never picked has no entry: count it as 0.
                        int nj = topicCounts.getOrDefault(j, 0);
                        topicProbabilities.put(j, (njw + beta) * (njd + alpha) / (nj + beta * d));
                    }

                    Integer newTopic = (Integer) SimpleRandomSampling.weightedSampling(topicProbabilities, 1, true).iterator().next();
                    topicAssignmentOfDocumentWord.put(documentWordKey, newTopic);
                    increase(topicCounts, newTopic);
                    increase(documentTopicCounts, Arrays.asList(documentId, newTopic));
                    increase(topicWordCounts, Arrays.asList(newTopic, word));
                    topicAssignments.put(newTopic, TypeInference.toDouble(topicAssignments.get(newTopic)) + 1.0 / totalWords);
                }
                Object mainTopic = MapMethods.selectMaxKeyValue(topicAssignments).getKey();
                if (!mainTopic.equals(r.getYPredicted())) {
                    ++changedCounter;
                }
                trainingData._unsafe_set(documentId, new Record(r.getX(), r.getY(), mainTopic, topicAssignments));
            }
            // Count the sweep before the convergence check so totalIterations is the number of
            // sweeps actually performed whether we converge or hit maxIterations.
            ++iteration;
            this.logger.debug("Reassigned Records {}", changedCounter);
            if (changedCounter == 0) {
                break;
            }
        }
        modelParameters.setTotalIterations(iteration);
    }

    /** Increments the counter stored under {@code key}, treating a missing key as 0. */
    private static <K> void increase(Map<K, Integer> map, K key) {
        map.put(key, map.getOrDefault(key, 0) + 1);
    }

    /** Decrements the counter stored under {@code key}, treating a missing key as 0. */
    private static <K> void decrease(Map<K, Integer> map, K key) {
        map.put(key, map.getOrDefault(key, 0) - 1);
    }

    /**
     * Views a map obtained from the storage engine with the key/value types the sampler uses.
     * The storage engine is asked for {@code Integer} values; the key type is the caller's claim.
     */
    @SuppressWarnings("unchecked")
    private static <K> Map<K, Integer> castIntegerMap(Map<?, ?> map) {
        return (Map<K, Integer>) map;
    }

    /**
     * Assigns topics to the documents of {@code newData} with Gibbs sampling.
     *
     * <p>The trained counts ({@code topicWordCounts}, {@code topicCounts}) are only read. The
     * sampling state of the new documents lives in temporary storage-engine maps that are always
     * dropped before returning, including when sampling throws. Each sweep's perplexity is
     * computed for debug logging only.
     *
     * @param newData the documents to label; each record is replaced in place with its predicted
     *     topic and per-topic word fractions
     */
    @Override
    protected void _predict(Dataframe newData) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();
        int d = modelParameters.getD();
        int k = trainingParameters.getK();
        double alpha = trainingParameters.getAlpha();
        double beta = trainingParameters.getBeta();
        int maxIterations = trainingParameters.getMaxIterations();
        Map<List<Object>, Integer> topicWordCounts = modelParameters.getTopicWordCounts();
        Map<Integer, Integer> topicCounts = modelParameters.getTopicCounts();

        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();
        Map<List<Object>, Integer> tmpTopicAssignmentOfDocumentWord = castIntegerMap(storageEngine.getBigMap("tmp_topicAssignmentOfDocumentWord", List.class, Integer.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_CACHE, false, true));
        Map<List<Integer>, Integer> tmpDocumentTopicCounts = castIntegerMap(storageEngine.getBigMap("tmp_documentTopicCounts", List.class, Integer.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, false, true));
        Map<List<Object>, Integer> tmpTopicWordCounts = castIntegerMap(storageEngine.getBigMap("tmp_topicWordCounts", List.class, Integer.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_CACHE, false, true));
        Map<Integer, Integer> tmpTopicCounts = castIntegerMap(storageEngine.getBigMap("tmp_topicCounts", Integer.class, Integer.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_MEMORY, false, true));
        try {
            // Random initialisation of the temporary count tables for the new documents.
            for (Map.Entry<Integer, Record> e : newData.entries()) {
                Integer documentId = e.getKey();
                for (Map.Entry<?, ?> entry : e.getValue().getX().entrySet()) {
                    Object wordPosition = entry.getKey();
                    Object word = entry.getValue();
                    Integer topic = PHPMethods.mt_rand(0, k - 1);
                    increase(tmpTopicCounts, topic);
                    tmpTopicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
                    increase(tmpDocumentTopicCounts, Arrays.asList(documentId, topic));
                    increase(tmpTopicWordCounts, Arrays.asList(topic, word));
                }
            }

            for (int iteration = 0; iteration < maxIterations; ++iteration) {
                this.logger.debug("Iteration {}", iteration);
                int changedCounter = 0;
                double perplexity = 0.0;
                double totalDatasetWords = 0.0;
                for (Map.Entry<Integer, Record> e : newData.entries()) {
                    Integer documentId = e.getKey();
                    Record r = e.getValue();
                    AssociativeArray topicAssignments = new AssociativeArray();
                    for (int j = 0; j < k; ++j) {
                        topicAssignments.put(j, 0.0);
                    }
                    int totalDocumentWords = r.getX().size();
                    totalDatasetWords += totalDocumentWords;
                    // Document-length normaliser; identical for every topic and word of this document.
                    double documentNormalizer = (totalDocumentWords - 1) + alpha * k;

                    for (Map.Entry<?, ?> entry : r.getX().entrySet()) {
                        Object wordPosition = entry.getKey();
                        Object word = entry.getValue();
                        List<Object> documentWordKey = Arrays.asList(documentId, wordPosition);

                        // Remove the current assignment so the counts describe every other word.
                        Integer topic = tmpTopicAssignmentOfDocumentWord.get(documentWordKey);
                        decrease(tmpTopicCounts, topic);
                        decrease(tmpDocumentTopicCounts, Arrays.asList(documentId, topic));
                        decrease(tmpTopicWordCounts, Arrays.asList(topic, word));

                        AssociativeArray topicProbabilities = new AssociativeArray();
                        for (int j = 0; j < k; ++j) {
                            List<Object> topicWordKey = Arrays.asList(j, word);
                            // Word-topic evidence combines the new documents with the trained model.
                            double enumerator = tmpTopicWordCounts.getOrDefault(topicWordKey, 0) + beta;
                            enumerator += topicWordCounts.getOrDefault(topicWordKey, 0);
                            enumerator *= tmpDocumentTopicCounts.getOrDefault(Arrays.asList(documentId, j), 0) + alpha;
                            // A topic may have no words in the new data or in training: count it as 0.
                            double denominator = tmpTopicCounts.getOrDefault(j, 0) + beta * d - 1.0;
                            denominator += topicCounts.getOrDefault(j, 0);
                            denominator *= documentNormalizer;
                            topicProbabilities.put(j, enumerator / denominator);
                        }
                        perplexity += Math.log(Descriptives.sum(topicProbabilities.toFlatDataCollection()));

                        Integer newTopic = (Integer) SimpleRandomSampling.weightedSampling(topicProbabilities, 1, true).iterator().next();
                        tmpTopicAssignmentOfDocumentWord.put(documentWordKey, newTopic);
                        increase(tmpTopicCounts, newTopic);
                        increase(tmpDocumentTopicCounts, Arrays.asList(documentId, newTopic));
                        increase(tmpTopicWordCounts, Arrays.asList(newTopic, word));
                        topicAssignments.put(newTopic, TypeInference.toDouble(topicAssignments.get(newTopic)) + 1.0 / totalDocumentWords);
                    }
                    Object mainTopic = MapMethods.selectMaxKeyValue(topicAssignments).getKey();
                    if (!mainTopic.equals(r.getYPredicted())) {
                        ++changedCounter;
                    }
                    newData._unsafe_set(documentId, new Record(r.getX(), r.getY(), mainTopic, topicAssignments));
                }
                perplexity = Math.exp(-perplexity / totalDatasetWords);
                this.logger.debug("Reassigned Records {} - Perplexity: {}", changedCounter, perplexity);
                if (changedCounter == 0) {
                    break;
                }
            }
        } finally {
            // Always release the temporary maps, even if sampling failed part-way.
            storageEngine.dropBigMap("tmp_topicAssignmentOfDocumentWord", tmpTopicAssignmentOfDocumentWord);
            storageEngine.dropBigMap("tmp_documentTopicCounts", tmpDocumentTopicCounts);
            storageEngine.dropBigMap("tmp_topicWordCounts", tmpTopicWordCounts);
            storageEngine.dropBigMap("tmp_topicCounts", tmpTopicCounts);
        }
    }

    /** Settings of the LDA sampler. */
    public static class TrainingParameters
    extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;
        /** Number of topics; topic ids range over {@code 0..k-1}. */
        private int k = 2;
        /** Upper bound on the number of Gibbs sampling sweeps. */
        private int maxIterations = 50;
        /** Smoothing added to document-topic counts. */
        private double alpha = 1.0;
        /** Smoothing added to topic-word counts. */
        private double beta = 1.0;

        /** @return the number of topics */
        public int getK() {
            return this.k;
        }

        /** @param k the number of topics */
        public void setK(int k) {
            this.k = k;
        }

        /** @return the maximum number of sampling sweeps */
        public int getMaxIterations() {
            return this.maxIterations;
        }

        /** @param maxIterations the maximum number of sampling sweeps */
        public void setMaxIterations(int maxIterations) {
            this.maxIterations = maxIterations;
        }

        /** @return the smoothing added to document-topic counts */
        public double getAlpha() {
            return this.alpha;
        }

        /** @param alpha the smoothing added to document-topic counts */
        public void setAlpha(double alpha) {
            this.alpha = alpha;
        }

        /** @return the smoothing added to topic-word counts */
        public double getBeta() {
            return this.beta;
        }

        /** @param beta the smoothing added to topic-word counts */
        public void setBeta(double beta) {
            this.beta = beta;
        }
    }

    /** Learned state of the LDA model: the topic assignments and count tables built by fitting. */
    public static class ModelParameters
    extends AbstractTrainer.AbstractModelParameters {
        private static final long serialVersionUID = 1L;
        /** Value set from the training data's xColumnSize(); multiplies beta in denominators. */
        private Integer d = 0;
        /** Number of sampling sweeps performed during fitting. */
        private int totalIterations;
        /** [documentId, wordPosition] -> assigned topic. */
        @BigMap(keyClass=List.class, valueClass=Integer.class, mapType=StorageEngine.MapType.HASHMAP, storageHint=StorageEngine.StorageHint.IN_CACHE, concurrent=false)
        private Map<List<Object>, Integer> topicAssignmentOfDocumentWord;
        /** [documentId, topic] -> number of the document's words assigned to the topic. */
        @BigMap(keyClass=List.class, valueClass=Integer.class, mapType=StorageEngine.MapType.HASHMAP, storageHint=StorageEngine.StorageHint.IN_MEMORY, concurrent=false)
        private Map<List<Integer>, Integer> documentTopicCounts;
        /** [topic, word] -> number of occurrences of the word assigned to the topic. */
        @BigMap(keyClass=List.class, valueClass=Integer.class, mapType=StorageEngine.MapType.HASHMAP, storageHint=StorageEngine.StorageHint.IN_CACHE, concurrent=false)
        private Map<List<Object>, Integer> topicWordCounts;
        /** documentId -> number of words in the document. */
        @BigMap(keyClass=Integer.class, valueClass=Integer.class, mapType=StorageEngine.MapType.HASHMAP, storageHint=StorageEngine.StorageHint.IN_MEMORY, concurrent=false)
        private Map<Integer, Integer> documentWordCounts;
        /** topic -> total number of words assigned to the topic. */
        @BigMap(keyClass=Integer.class, valueClass=Integer.class, mapType=StorageEngine.MapType.HASHMAP, storageHint=StorageEngine.StorageHint.IN_MEMORY, concurrent=false)
        private Map<Integer, Integer> topicCounts;

        /** @param storageEngine the engine backing the {@code @BigMap} fields */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        /** @return the beta multiplier set during fitting */
        public Integer getD() {
            return this.d;
        }

        /** @param d the beta multiplier */
        protected void setD(Integer d) {
            this.d = d;
        }

        /** @return the number of sampling sweeps performed during fitting */
        public int getTotalIterations() {
            return this.totalIterations;
        }

        /** @param totalIterations the number of sampling sweeps performed during fitting */
        protected void setTotalIterations(int totalIterations) {
            this.totalIterations = totalIterations;
        }

        /** @return the map [documentId, wordPosition] -> topic */
        public Map<List<Object>, Integer> getTopicAssignmentOfDocumentWord() {
            return this.topicAssignmentOfDocumentWord;
        }

        /** @param topicAssignmentOfDocumentWord the map [documentId, wordPosition] -> topic */
        protected void setTopicAssignmentOfDocumentWord(Map<List<Object>, Integer> topicAssignmentOfDocumentWord) {
            this.topicAssignmentOfDocumentWord = topicAssignmentOfDocumentWord;
        }

        /** @return the map [documentId, topic] -> count */
        public Map<List<Integer>, Integer> getDocumentTopicCounts() {
            return this.documentTopicCounts;
        }

        /** @param documentTopicCounts the map [documentId, topic] -> count */
        protected void setDocumentTopicCounts(Map<List<Integer>, Integer> documentTopicCounts) {
            this.documentTopicCounts = documentTopicCounts;
        }

        /** @return the map [topic, word] -> count */
        public Map<List<Object>, Integer> getTopicWordCounts() {
            return this.topicWordCounts;
        }

        /** @param topicWordCounts the map [topic, word] -> count */
        protected void setTopicWordCounts(Map<List<Object>, Integer> topicWordCounts) {
            this.topicWordCounts = topicWordCounts;
        }

        /** @return the map documentId -> number of words */
        public Map<Integer, Integer> getDocumentWordCounts() {
            return this.documentWordCounts;
        }

        /** @param documentWordCounts the map documentId -> number of words */
        protected void setDocumentWordCounts(Map<Integer, Integer> documentWordCounts) {
            this.documentWordCounts = documentWordCounts;
        }

        /** @return the map topic -> total assigned words */
        public Map<Integer, Integer> getTopicCounts() {
            return this.topicCounts;
        }

        /** @param topicCounts the map topic -> total assigned words */
        protected void setTopicCounts(Map<Integer, Integer> topicCounts) {
            this.topicCounts = topicCounts;
        }
    }
}

