/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.clustering;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.common.utilities.MapMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClusterer;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.mathematics.distances.Distance;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * Hierarchical agglomerative (bottom-up) clusterer.
 * <p>
 * Training starts with one cluster per training record and repeatedly merges the closest pair of
 * active clusters. After each merge, the distances to the merged cluster are updated according to the
 * configured {@link TrainingParameters.Linkage}. Merging stops when either:
 * <ul>
 *   <li>the number of active clusters is at most {@link TrainingParameters#getMinClustersThreshold()}, or</li>
 *   <li>the closest remaining pair is at least {@link TrainingParameters#getMaxDistanceThreshold()} apart.</li>
 * </ul>
 * Prediction assigns a record to the cluster whose centroid is nearest under the configured
 * {@link TrainingParameters.Distance}.
 */
public class HierarchicalAgglomerative
extends AbstractClusterer<HierarchicalAgglomerative.Cluster, HierarchicalAgglomerative.ModelParameters, HierarchicalAgglomerative.TrainingParameters>
implements PredictParallelizable,
TrainParallelizable {

    /** Whether the loops run through {@link #streamExecutor} use parallel streams. Defaults to {@code true}. */
    private boolean parallelized = true;

    /** Executor for the (optionally) parallel loops, built from the knowledge base's concurrency configuration. */
    protected final ForkJoinStream streamExecutor = new ForkJoinStream(this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());

    /**
     * Creates a clusterer with the given training parameters.
     *
     * @param trainingParameters the training parameters
     * @param configuration the framework configuration
     */
    protected HierarchicalAgglomerative(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Creates a clusterer bound to the given storage name (presumably to load a previously saved model;
     * see the superclass).
     *
     * @param storageName the storage name
     * @param configuration the framework configuration
     */
    protected HierarchicalAgglomerative(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    /** {@inheritDoc} */
    @Override
    public boolean isParallelized() {
        return this.parallelized;
    }

    /** {@inheritDoc} */
    @Override
    public void setParallelized(boolean parallelized) {
        this.parallelized = parallelized;
    }

    /**
     * Predicts cluster assignments for {@code newData} by delegating to
     * {@code _predictDatasetParallel} (provided by {@link PredictParallelizable}).
     */
    @Override
    protected void _predict(Dataframe newData) {
        this._predictDatasetParallel(newData, this.knowledgeBase.getStorageEngine(), this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Computes the distance from {@code r} to every cluster centroid, normalizes those distances with
     * {@link Descriptives#normalize}, and selects the cluster with the smallest value.
     *
     * @param r the record to assign
     * @return the selected cluster id together with the normalized per-cluster distances
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record r) {
        ModelParameters modelParameters = (ModelParameters)this.knowledgeBase.getModelParameters();
        Map<Integer, Cluster> clusterMap = modelParameters.getClusterMap();

        AssociativeArray clusterDistances = new AssociativeArray();
        for (Map.Entry<Integer, Cluster> e : clusterMap.entrySet()) {
            Integer clusterId = e.getKey();
            Cluster c = e.getValue();
            double distance = this.calculateDistance(r, c.getCentroid());
            clusterDistances.put(clusterId, distance);
        }

        Descriptives.normalize(clusterDistances);

        return new PredictParallelizable.Prediction(this.getSelectedClusterFromDistances(clusterDistances), clusterDistances);
    }

    /**
     * Trains the model.
     * <p>
     * First records every non-null response value as a gold-standard class. Then builds the cluster
     * hierarchy, and finally calls {@code clearClusters()} (defined in the superclass).
     *
     * @param trainingData the training records
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        ModelParameters modelParameters = (ModelParameters)this.knowledgeBase.getModelParameters();
        Set<Object> goldStandardClasses = modelParameters.getGoldStandardClasses();
        for (Record r : trainingData) {
            Object theClass = r.getY();
            if (theClass != null) {
                goldStandardClasses.add(theClass);
            }
        }

        this.calculateClusters(trainingData);

        this.clearClusters();
    }

    /**
     * Distance between the feature vectors of two records, using the configured distance method.
     *
     * @throws IllegalArgumentException if the configured distance method is not supported (e.g. null)
     */
    private double calculateDistance(Record r1, Record r2) {
        TrainingParameters trainingParameters = (TrainingParameters)this.knowledgeBase.getTrainingParameters();
        TrainingParameters.Distance distanceMethod = trainingParameters.getDistanceMethod();
        if (distanceMethod == TrainingParameters.Distance.EUCLIDIAN) {
            return Distance.euclidean(r1.getX(), r2.getX());
        }
        else if (distanceMethod == TrainingParameters.Distance.MANHATTAN) {
            return Distance.manhattan(r1.getX(), r2.getX());
        }
        else if (distanceMethod == TrainingParameters.Distance.MAXIMUM) {
            return Distance.maximum(r1.getX(), r2.getX());
        }
        throw new IllegalArgumentException("Unsupported Distance method.");
    }

    /** Returns the key (cluster id) with the minimum value in {@code clusterDistances}. */
    private Object getSelectedClusterFromDistances(AssociativeArray clusterDistances) {
        Map.Entry<Object, Object> minEntry = MapMethods.selectMinKeyValue(clusterDistances);
        return minEntry.getKey();
    }

    /**
     * Builds the cluster hierarchy in the model's cluster map.
     * <p>
     * Only active clusters survive, each with refreshed parameters. The temporary distance maps are
     * always dropped, even if training fails part-way.
     */
    private void calculateClusters(Dataframe trainingData) {
        ModelParameters modelParameters = (ModelParameters)this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters)this.knowledgeBase.getTrainingParameters();
        Map<Integer, Cluster> clusterMap = modelParameters.getClusterMap();
        StorageEngine storageEngine = this.knowledgeBase.getStorageEngine();

        // Pairwise cluster distances keyed by [id1, id2]; both orderings are stored.
        @SuppressWarnings("unchecked")
        Map<List<Object>, Double> distanceArray = storageEngine.getBigMap("tmp_distanceArray", (Class<List<Object>>)(Class<?>)List.class, Double.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_CACHE, true, true);
        // For each cluster id, the id of its current nearest cluster.
        Map<Integer, Integer> minClusterDistanceId = storageEngine.getBigMap("tmp_minClusterDistanceId", Integer.class, Integer.class, StorageEngine.MapType.HASHMAP, StorageEngine.StorageHint.IN_CACHE, true, true);

        try {
            this.initializeSingletonClusters(trainingData, clusterMap);
            this.initializeDistances(clusterMap, distanceArray, minClusterDistanceId);

            // Check the threshold BEFORE merging, so a dataset that already has at most
            // minClustersThreshold clusters is never merged below it.
            double minClustersThreshold = trainingParameters.getMinClustersThreshold();
            while (countActiveClusters(clusterMap) > minClustersThreshold) {
                if (!this.mergeClosest(minClusterDistanceId, distanceArray)) {
                    break;
                }
            }

            finalizeClusters(clusterMap);
        }
        finally {
            storageEngine.dropBigMap("tmp_distanceArray", distanceArray);
            storageEngine.dropBigMap("tmp_minClusterDistanceId", minClusterDistanceId);
        }
    }

    /** Puts one single-record cluster per training record into {@code clusterMap}, with ids 0, 1, 2, ... */
    private void initializeSingletonClusters(Dataframe trainingData, Map<Integer, Cluster> clusterMap) {
        int clusterId = 0;
        for (Record r : trainingData.values()) {
            Cluster c = new Cluster(clusterId);
            c.add(r);
            c.updateClusterParameters();
            clusterMap.put(clusterId, c);
            ++clusterId;
        }
    }

    /**
     * Fills {@code distanceArray} with all pairwise centroid distances and records each cluster's
     * nearest neighbour in {@code minClusterDistanceId}.
     * <p>
     * A cluster's distance to itself is stored as {@link Double#MAX_VALUE}, so it is never chosen as
     * its own neighbour.
     */
    private void initializeDistances(Map<Integer, Cluster> clusterMap, Map<List<Object>, Double> distanceArray, Map<Integer, Integer> minClusterDistanceId) {
        this.streamExecutor.forEach(StreamMethods.stream(clusterMap.entrySet().stream(), this.isParallelized()), entry1 -> {
            Integer clusterId1 = entry1.getKey();
            Cluster c1 = entry1.getValue();

            Integer closestId = null;
            double closestDistance = Double.MAX_VALUE;
            for (Map.Entry<Integer, Cluster> entry2 : clusterMap.entrySet()) {
                Integer clusterId2 = entry2.getKey();
                double distance = Objects.equals(clusterId1, clusterId2)
                        ? Double.MAX_VALUE
                        : this.calculateDistance(c1.getCentroid(), entry2.getValue().getCentroid());

                distanceArray.put(Arrays.asList(clusterId1, clusterId2), distance);
                distanceArray.put(Arrays.asList(clusterId2, clusterId1), distance);

                // Keep the first strictly-smallest distance seen, as the original loop did.
                if (closestId == null || distance < closestDistance) {
                    closestId = clusterId2;
                    closestDistance = distance;
                }
            }

            if (closestId != null) {
                minClusterDistanceId.put(clusterId1, closestId);
            }
        });
    }

    /** Counts the clusters in {@code clusterMap} that are still active. */
    private static int countActiveClusters(Map<Integer, Cluster> clusterMap) {
        int activeClusters = 0;
        for (Cluster c : clusterMap.values()) {
            if (c.isActive()) {
                ++activeClusters;
            }
        }
        return activeClusters;
    }

    /** Removes inactive clusters from {@code clusterMap} and refreshes the parameters (centroid) of active ones. */
    private static void finalizeClusters(Map<Integer, Cluster> clusterMap) {
        Iterator<Map.Entry<Integer, Cluster>> it = clusterMap.entrySet().iterator();
        while (it.hasNext()) {
            Map.Entry<Integer, Cluster> entry = it.next();
            Cluster cluster = entry.getValue();
            if (cluster.isActive()) {
                cluster.updateClusterParameters();
                // Re-put the mutated cluster; presumably needed so storage-backed maps persist the change.
                clusterMap.put(entry.getKey(), cluster);
            }
            else {
                it.remove();
            }
        }
    }

    /**
     * Merges the closest pair of active clusters and updates the distance and nearest-neighbour tables.
     * <p>
     * If a merge happens, the second cluster of the pair is folded into the first and marked inactive.
     *
     * @param minClusterDistanceId for each cluster id, the id of its current nearest cluster (updated in place)
     * @param distanceArray pairwise cluster distances keyed by [id1, id2] (updated in place)
     * @return {@code true} if a merge was performed; {@code false} if no active cluster has a neighbour
     *         closer than the max distance threshold
     * @throws IllegalArgumentException if the configured linkage method is not supported (e.g. null)
     */
    private boolean mergeClosest(Map<Integer, Integer> minClusterDistanceId, Map<List<Object>, Double> distanceArray) {
        ModelParameters modelParameters = (ModelParameters)this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters)this.knowledgeBase.getTrainingParameters();
        Map<Integer, Cluster> clusterMap = modelParameters.getClusterMap();

        // Find the active cluster whose nearest neighbour is closest overall.
        Integer minClusterId = null;
        double minDistance = Double.MAX_VALUE;
        for (Map.Entry<Integer, Cluster> entry : clusterMap.entrySet()) {
            if (!entry.getValue().isActive()) {
                continue;
            }
            Integer clusterId = entry.getKey();
            double distance = distanceArray.get(Arrays.asList(clusterId, minClusterDistanceId.get(clusterId)));
            if (distance < minDistance) {
                minClusterId = clusterId;
                minDistance = distance;
            }
        }

        // minClusterId stays null when no pair is closer than MAX_VALUE (e.g. a single active cluster).
        // Guard it explicitly, because a threshold above MAX_VALUE (e.g. +Infinity) would otherwise let us through.
        if (minClusterId == null || minDistance >= trainingParameters.getMaxDistanceThreshold()) {
            return false;
        }

        TrainingParameters.Linkage linkageMethod = trainingParameters.getLinkageMethod();
        if (linkageMethod == null) {
            // Reject before mutating any cluster.
            throw new IllegalArgumentException("Unsupported Linkage method.");
        }

        Integer clusterThatMergesId = minClusterId;
        Integer clusterToBeMergedId = minClusterDistanceId.get(clusterThatMergesId);

        Cluster c1 = clusterMap.get(clusterThatMergesId);
        Cluster c2 = clusterMap.get(clusterToBeMergedId);

        // Sizes before the merge are needed by the AVERAGE linkage update.
        double c1Size = c1.size();
        double c2Size = c2.size();

        c1.merge(c2);
        clusterMap.put(clusterThatMergesId, c1);

        c2.setActive(false);
        clusterMap.put(clusterToBeMergedId, c2);

        // Update distances between the merged cluster and every other active cluster.
        // Each task writes only the keys [merged, ci] and [ci, merged] for its own ci.
        this.streamExecutor.forEach(StreamMethods.stream(clusterMap.entrySet().stream(), this.isParallelized()), entry -> {
            Cluster ci = entry.getValue();
            if (!ci.isActive()) {
                return;
            }
            Integer clusterId = entry.getKey();

            double distance;
            if (Objects.equals(clusterThatMergesId, clusterId)) {
                distance = Double.MAX_VALUE;
            }
            else {
                double c1ciDistance = distanceArray.get(Arrays.asList(clusterThatMergesId, clusterId));
                double c2ciDistance = distanceArray.get(Arrays.asList(clusterToBeMergedId, clusterId));
                distance = linkageDistance(linkageMethod, c1ciDistance, c2ciDistance, c1Size, c2Size);
            }

            distanceArray.put(Arrays.asList(clusterThatMergesId, clusterId), distance);
            distanceArray.put(Arrays.asList(clusterId, clusterThatMergesId), distance);
        });

        // Recompute the nearest neighbour of every active cluster whose neighbour was one of the merged pair.
        this.streamExecutor.forEach(StreamMethods.stream(clusterMap.entrySet().stream(), this.isParallelized()), entry1 -> {
            if (!entry1.getValue().isActive()) {
                return;
            }
            Integer id1 = entry1.getKey();
            Integer minDistanceId = minClusterDistanceId.get(id1);
            if (!Objects.equals(minDistanceId, clusterThatMergesId) && !Objects.equals(minDistanceId, clusterToBeMergedId)) {
                return;
            }

            // Start from self, whose stored distance is MAX_VALUE.
            Integer newMinDistanceId = id1;
            double newMinDistance = distanceArray.get(Arrays.asList(id1, id1));
            for (Map.Entry<Integer, Cluster> entry2 : clusterMap.entrySet()) {
                if (!entry2.getValue().isActive()) {
                    continue;
                }
                Integer id2 = entry2.getKey();
                double distance = distanceArray.get(Arrays.asList(id1, id2));
                if (distance < newMinDistance) {
                    newMinDistanceId = id2;
                    newMinDistance = distance;
                }
            }

            minClusterDistanceId.put(id1, newMinDistanceId);
        });

        return true;
    }

    /**
     * Distance from a newly merged cluster (c1 + c2) to another cluster ci.
     *
     * @param linkageMethod the linkage rule
     * @param c1ciDistance the distance c1–ci before the merge
     * @param c2ciDistance the distance c2–ci before the merge
     * @param c1Size the size of c1 before the merge
     * @param c2Size the size of c2 before the merge
     * @return the minimum (SINGLE), the maximum (COMPLETE), or the size-weighted mean (AVERAGE) of the two distances
     * @throws IllegalArgumentException if the linkage method is not supported
     */
    private static double linkageDistance(TrainingParameters.Linkage linkageMethod, double c1ciDistance, double c2ciDistance, double c1Size, double c2Size) {
        if (linkageMethod == TrainingParameters.Linkage.SINGLE) {
            return Math.min(c1ciDistance, c2ciDistance);
        }
        else if (linkageMethod == TrainingParameters.Linkage.COMPLETE) {
            return Math.max(c1ciDistance, c2ciDistance);
        }
        else if (linkageMethod == TrainingParameters.Linkage.AVERAGE) {
            return (c1ciDistance * c1Size + c2ciDistance * c2Size) / (c1Size + c2Size);
        }
        throw new IllegalArgumentException("Unsupported Linkage method.");
    }

    /** Training parameters for {@link HierarchicalAgglomerative}. */
    public static class TrainingParameters
    extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;

        /** Rule for updating distances after a merge; defaults to {@link Linkage#COMPLETE}. */
        private Linkage linkageMethod = Linkage.COMPLETE;

        /** Distance between record feature vectors; defaults to {@link Distance#EUCLIDIAN}. */
        private Distance distanceMethod = Distance.EUCLIDIAN;

        /** Merging stops once the closest pair is at least this far apart; defaults to {@link Double#MAX_VALUE}. */
        private double maxDistanceThreshold = Double.MAX_VALUE;

        /** Merging stops once the number of active clusters is at most this value; defaults to 2. */
        private double minClustersThreshold = 2.0;

        /** @return the linkage method */
        public Linkage getLinkageMethod() {
            return this.linkageMethod;
        }

        /** @param linkageMethod the linkage method */
        public void setLinkageMethod(Linkage linkageMethod) {
            this.linkageMethod = linkageMethod;
        }

        /** @return the distance method */
        public Distance getDistanceMethod() {
            return this.distanceMethod;
        }

        /** @param distanceMethod the distance method */
        public void setDistanceMethod(Distance distanceMethod) {
            this.distanceMethod = distanceMethod;
        }

        /** @return the distance at or above which clusters are no longer merged */
        public double getMaxDistanceThreshold() {
            return this.maxDistanceThreshold;
        }

        /** @param maxDistanceThreshold the distance at or above which clusters are no longer merged */
        public void setMaxDistanceThreshold(double maxDistanceThreshold) {
            this.maxDistanceThreshold = maxDistanceThreshold;
        }

        /** @return the active-cluster count at or below which merging stops */
        public double getMinClustersThreshold() {
            return this.minClustersThreshold;
        }

        /** @param minClustersThreshold the active-cluster count at or below which merging stops */
        public void setMinClustersThreshold(double minClustersThreshold) {
            this.minClustersThreshold = minClustersThreshold;
        }

        /** Supported distance functions between record feature vectors (spelling of EUCLIDIAN kept for compatibility). */
        public static enum Distance {
            /** Euclidean distance. */
            EUCLIDIAN,
            /** Manhattan (L1) distance. */
            MANHATTAN,
            /** Maximum (Chebyshev) distance. */
            MAXIMUM;

        }

        /** Supported linkage rules for the distance to a merged cluster. */
        public static enum Linkage {
            /** Size-weighted mean of the two pre-merge distances. */
            AVERAGE,
            /** Minimum of the two pre-merge distances. */
            SINGLE,
            /** Maximum of the two pre-merge distances. */
            COMPLETE;

        }
    }

    /** Model parameters for {@link HierarchicalAgglomerative}; the cluster map is provided by the superclass. */
    public static class ModelParameters
    extends AbstractClusterer.AbstractModelParameters<Cluster> {
        private static final long serialVersionUID = 1L;

        /**
         * @param storageEngine the storage engine backing the parameters
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }
    }

    /**
     * A cluster that keeps the running sum of its members' feature vectors.
     * <p>
     * The centroid is only recomputed by {@link #updateClusterParameters()}. In particular,
     * {@link #add(Record)} and {@link #merge(Cluster)} leave it stale.
     */
    public static class Cluster
    extends AbstractClusterer.AbstractCluster {
        private static final long serialVersionUID = 1L;

        /** Mean of the members' feature vectors as of the last {@link #updateClusterParameters()} call. */
        private Record centroid = new Record(new AssociativeArray(), null);

        /** False once this cluster has been merged into another one. */
        private boolean active = true;

        /** Running sum of the members' feature vectors. */
        private final AssociativeArray xi_sum = new AssociativeArray();

        /**
         * @param clusterId the cluster id
         */
        protected Cluster(int clusterId) {
            super(clusterId);
        }

        /** @return the centroid as of the last {@link #updateClusterParameters()} call */
        public Record getCentroid() {
            return this.centroid;
        }

        /** Absorbs {@code c}'s feature sum and size into this cluster; does not refresh the centroid. */
        protected void merge(Cluster c) {
            this.xi_sum.addValues(c.xi_sum);
            this.size += c.size;
        }

        /**
         * Recomputes the centroid as {@code xi_sum / size}. The sum is used unscaled when size is 0.
         *
         * @return {@code true} if the centroid's feature vector changed
         */
        protected boolean updateClusterParameters() {
            AssociativeArray centroidValues = this.xi_sum.copy();
            if (this.size > 0) {
                centroidValues.multiplyValues(1.0 / (double)this.size);
            }

            if (this.centroid.getX().equals(centroidValues)) {
                return false;
            }
            this.centroid = new Record(centroidValues, this.centroid.getY());
            return true;
        }

        /** @return whether this cluster is still active (not merged away) */
        protected boolean isActive() {
            return this.active;
        }

        /** @param active whether this cluster is active */
        protected void setActive(boolean active) {
            this.active = active;
        }

        /** Adds {@code r} to the running sum and increments the size; does not refresh the centroid. */
        @Override
        protected void add(Record r) {
            ++this.size;
            this.xi_sum.addValues(r.getX());
        }

        /** Not supported: clusters only grow. */
        @Override
        protected void remove(Record r) {
            throw new UnsupportedOperationException("Remove operation is not supported.");
        }

        /** Clears the running feature sum; the size and the centroid are left as they are. */
        @Override
        protected void clear() {
            this.xi_sum.clear();
        }
    }
}

