/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.classification;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.common.utilities.RandomGenerator;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

/**
 * Multi-class Support Vector Machine classifier backed by LIBSVM.
 * <p>
 * During training every distinct class and feature is assigned a consecutive integer id in order of first
 * appearance. Each record is converted to a dense LIBSVM node vector (missing features become 0.0). Probability
 * estimates are always enabled, so the classifier can report per-class scores.
 */
public class SupportVectorMachine
extends AbstractClassifier<SupportVectorMachine.ModelParameters, SupportVectorMachine.TrainingParameters>
implements PredictParallelizable {
    private boolean parallelized;

    /**
     * Creates a new, untrained classifier.
     *
     * @param trainingParameters the training parameters
     * @param configuration the framework configuration
     */
    protected SupportVectorMachine(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
        // LIBSVM uses its own static Random; seed it from the framework's generator for reproducibility.
        svm.rand.setSeed(RandomGenerator.getThreadLocalRandom().nextLong());
        this.parallelized = true;
    }

    /**
     * Loads a previously stored classifier.
     *
     * @param storageName the name under which the model was stored
     * @param configuration the framework configuration
     */
    protected SupportVectorMachine(String storageName, Configuration configuration) {
        super(storageName, configuration);
        // LIBSVM uses its own static Random; seed it from the framework's generator for reproducibility.
        svm.rand.setSeed(RandomGenerator.getThreadLocalRandom().nextLong());
        this.parallelized = true;
    }

    /** {@inheritDoc} */
    @Override
    public boolean isParallelized() {
        return this.parallelized;
    }

    /** {@inheritDoc} */
    @Override
    public void setParallelized(boolean parallelized) {
        this.parallelized = parallelized;
    }

    /** {@inheritDoc} */
    @Override
    protected void _predict(Dataframe newData) {
        this._predictDatasetParallel(newData, this.knowledgeBase.getStorageEngine(), this.knowledgeBase.getConfiguration().getConcurrencyConfiguration());
    }

    /**
     * Predicts a single record.
     * <p>
     * The class is selected from the raw class scores; the scores are then normalized before being returned.
     *
     * @param r the record to predict
     * @return the predicted class together with the normalized class scores
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record r) {
        AssociativeArray predictionScores = this.calculateClassScores(r.getX());
        Object predictedClass = this.getSelectedClassFromClassScores(predictionScores);
        Descriptives.normalize(predictionScores);
        return new PredictParallelizable.Prediction(predictedClass, predictionScores);
    }

    /**
     * Assigns class and feature ids, then trains the LIBSVM model.
     *
     * @param trainingData the training records
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        // Class scores come from LIBSVM's probability estimates, so probability training is always forced on.
        ((TrainingParameters) this.knowledgeBase.getTrainingParameters()).getSvmParameter().probability = 1;

        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();
        Map<Object, Integer> classIds = modelParameters.getClassIds();
        Set<Object> classesSet = modelParameters.getClasses();

        // Consecutive ids in order of first appearance: classes -> [0, c), features -> [0, d).
        int nextClassId = 0;
        int nextFeatureId = 0;
        for (Record r : trainingData) {
            Object theClass = r.getY();
            if (classesSet.add(theClass)) {
                classIds.put(theClass, nextClassId++);
            }
            for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
                if (featureIds.putIfAbsent(entry.getKey(), nextFeatureId) == null) {
                    ++nextFeatureId;
                }
            }
        }
        this.libSVMTrainer(trainingData);
    }

    /**
     * Builds the LIBSVM problem from the training data, validates the parameters and trains the model.
     *
     * @param trainingData the training records; all their classes and features must already have ids
     * @throws IllegalArgumentException if the data is empty or LIBSVM rejects the parameters
     */
    private void libSVMTrainer(Dataframe trainingData) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();
        Map<Object, Integer> classIds = modelParameters.getClassIds();
        int n = trainingData.size();
        int sparseD = featureIds.size();
        if (n == 0) {
            throw new IllegalArgumentException("Cannot train a SupportVectorMachine on an empty Dataframe.");
        }

        svm_problem prob = new svm_problem();
        prob.l = n;
        prob.y = new double[n];
        prob.x = new svm_node[n][];
        int rowId = 0;
        for (Record r : trainingData.values()) {
            int classId = classIds.get(r.getY());
            prob.y[rowId] = classId;
            prob.x[rowId] = toSvmNodes(r.getX(), featureIds);
            ++rowId;
        }

        svm_parameter params = ((TrainingParameters) this.knowledgeBase.getTrainingParameters()).getSvmParameter();
        if (params.gamma == 0.0 && sparseD > 0) {
            // Match LIBSVM's documented default (gamma = 1/num_features). The library itself does not apply it,
            // and gamma 0 turns the RBF/poly/sigmoid kernels into constants. Work on a copy to leave the
            // caller's parameters untouched.
            params = (svm_parameter) params.clone();
            params.gamma = 1.0 / sparseD;
        }

        String parameterError = svm.svm_check_parameter(prob, params);
        if (parameterError != null) {
            throw new IllegalArgumentException("Invalid LIBSVM parameters: " + parameterError);
        }

        // Route LIBSVM's console output to the logger (note: this is a global LIBSVM setting).
        svm.svm_set_print_string_function(s -> {
            if (s != null) {
                this.logger.debug(s.trim());
            }
        });
        svm_model model = svm.svm_train(prob, params);
        modelParameters.setSvmModel(model);
    }

    /**
     * Computes LIBSVM probability estimates for every known class.
     *
     * @param x the feature values of the record; features not seen during training are ignored
     * @return an array mapping each class to its (unnormalized) probability estimate
     */
    private AssociativeArray calculateClassScores(AssociativeArray x) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        Map<Object, Integer> classIds = modelParameters.getClassIds();
        svm_model model = modelParameters.getSvmModel();
        int c = modelParameters.getC();

        svm_node[] xSVM = toSvmNodes(x, modelParameters.getFeatureIds());
        double[] probEstimates = new double[c];
        svm.svm_predict_probability(model, xSVM, probEstimates);

        // LIBSVM orders the estimates like model.label (its internal label order), not by label value,
        // so translate positions back to our class ids.
        int[] labels = new int[c];
        svm.svm_get_labels(model, labels);
        double[] scoreByClassId = new double[c];
        for (int i = 0; i < c; ++i) {
            int label = labels[i];
            if (label >= 0 && label < c) {
                scoreByClassId[label] = probEstimates[i];
            }
        }

        AssociativeArray classScores = new AssociativeArray();
        for (Map.Entry<Object, Integer> entry : classIds.entrySet()) {
            classScores.put(entry.getKey(), (Object) scoreByClassId[entry.getValue()]);
        }
        return classScores;
    }

    /**
     * Converts a record's features into a dense LIBSVM node vector.
     * <p>
     * The vector has one node per known feature; node {@code i} has the 1-based LIBSVM index {@code i + 1}.
     * Features missing from {@code x}, or whose value cannot be converted to a double, are set to 0.0.
     * Features in {@code x} that have no id are ignored.
     *
     * @param x the feature values
     * @param featureIds the feature-to-id mapping; ids are expected to cover {@code [0, featureIds.size())}
     * @return the dense node vector
     */
    private static svm_node[] toSvmNodes(AssociativeArray x, Map<Object, Integer> featureIds) {
        int sparseD = featureIds.size();
        svm_node[] nodes = new svm_node[sparseD];
        for (int i = 0; i < sparseD; ++i) {
            svm_node node = new svm_node();
            node.index = i + 1;
            node.value = 0.0;
            nodes[i] = node;
        }
        for (Map.Entry<Object, Object> entry : x.entrySet()) {
            Integer featureId = featureIds.get(entry.getKey());
            if (featureId == null) {
                continue;
            }
            Double value = TypeInference.toDouble(entry.getValue());
            nodes[featureId].value = (value != null) ? value : 0.0;
        }
        return nodes;
    }

    /**
     * Training parameters: a thin wrapper around a LIBSVM {@link svm_parameter}.
     * <p>
     * Defaults: C-SVC with a linear kernel, C = 1.0, eps = 0.001, shrinking on, probability estimates on.
     * A {@code gamma} of 0.0 is replaced by 1/number-of-features at training time.
     */
    public static class TrainingParameters
    extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;
        private svm_parameter svmParameter = new svm_parameter();

        /** Creates the parameters with the defaults listed in the class documentation. */
        public TrainingParameters() {
            this.svmParameter.svm_type = svm_parameter.C_SVC;
            this.svmParameter.kernel_type = svm_parameter.LINEAR;
            this.svmParameter.degree = 3;
            this.svmParameter.gamma = 0.0;
            this.svmParameter.coef0 = 0.0;
            this.svmParameter.nu = 0.5;
            this.svmParameter.cache_size = 100.0;
            this.svmParameter.C = 1.0;
            this.svmParameter.eps = 0.001;
            this.svmParameter.p = 0.1;
            this.svmParameter.shrinking = 1;
            this.svmParameter.probability = 1;
            this.svmParameter.nr_weight = 0;
            this.svmParameter.weight_label = new int[0];
            this.svmParameter.weight = new double[0];
        }

        /**
         * Returns the live (mutable) LIBSVM parameter object.
         *
         * @return the LIBSVM parameters
         */
        public svm_parameter getSvmParameter() {
            return this.svmParameter;
        }

        /**
         * Replaces the LIBSVM parameter object.
         *
         * @param svmParameter the new LIBSVM parameters
         */
        public void setSvmParameter(svm_parameter svmParameter) {
            this.svmParameter = svmParameter;
        }
    }

    /**
     * Model parameters: the feature and class id mappings and the trained LIBSVM model.
     */
    public static class ModelParameters
    extends AbstractClassifier.AbstractModelParameters {
        private static final long serialVersionUID = 1L;
        /** Feature -> consecutive 0-based id; the LIBSVM node index is {@code id + 1}. */
        @BigMap(keyClass=Object.class, valueClass=Integer.class, mapType=StorageEngine.MapType.HASHMAP, storageHint=StorageEngine.StorageHint.IN_MEMORY, concurrent=false)
        private Map<Object, Integer> featureIds;
        /** Class -> consecutive 0-based id, used as the LIBSVM label. */
        private Map<Object, Integer> classIds = new HashMap<Object, Integer>();
        /** The trained LIBSVM model, or null before training. */
        private svm_model svmModel;

        /**
         * @param storageEngine the storage engine backing the model's big maps
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        /**
         * @return the feature-to-id mapping
         */
        public Map<Object, Integer> getFeatureIds() {
            return this.featureIds;
        }

        /**
         * @param featureIds the feature-to-id mapping
         */
        protected void setFeatureIds(Map<Object, Integer> featureIds) {
            this.featureIds = featureIds;
        }

        /**
         * @return the trained LIBSVM model, or null before training
         */
        public svm_model getSvmModel() {
            return this.svmModel;
        }

        /**
         * @param svmModel the trained LIBSVM model
         */
        protected void setSvmModel(svm_model svmModel) {
            this.svmModel = svmModel;
        }

        /**
         * @return the class-to-id mapping
         */
        public Map<Object, Integer> getClassIds() {
            return this.classIds;
        }

        /**
         * @param classIds the class-to-id mapping
         */
        protected void setClassIds(Map<Object, Integer> classIds) {
            this.classIds = classIds;
        }
    }
}

