/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.bayesian.graphicalmodel;

import java.util.Iterator;
import java.util.Set;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.bayesian.graphicalmodel.DiscreteBayesNetwork;
import jsat.math.SpecialMath;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;

/**
 * A {@link DiscreteBayesNetwork} whose structure is learned with the K2 greedy
 * search before the usual parameter training.
 *
 * <p>Variables are visited in a fixed order: the class variable first (node
 * index {@code D.getNumCategoricalVars()}), then the categorical features
 * {@code 0 .. n-1}. A variable may only take parents from variables that
 * precede it in that order. Parents are added one at a time, always picking the
 * candidate that maximizes {@link #f(int, Set, ClassificationDataSet)}, for as
 * long as the score strictly improves and the parent limit is not reached.
 */
public class K2NetworkLearner extends DiscreteBayesNetwork {
    private static final long serialVersionUID = -9681177007308829L;

    /**
     * Number of distinct values each variable can take, indexed by node index
     * (the class variable is the last entry). Only populated while
     * {@link #learnNetwork(ClassificationDataSet)} is running; {@code null}
     * otherwise.
     */
    private int[] ri;

    /** Maximum number of parents per variable; a value {@code <= 0} means no limit. */
    private int maxParents;

    /**
     * Sets the maximum number of parents any variable may receive during
     * structure learning.
     *
     * @param maxParents the parent limit; values {@code <= 0} disable the limit
     */
    public void setMaxParents(int maxParents) {
        this.maxParents = maxParents;
    }

    /**
     * Returns the maximum number of parents per variable.
     *
     * @return the parent limit, or {@code 0} when no limit is set (negative
     *         stored values are reported as {@code 0})
     */
    public int getMaxParents() {
        return Math.max(this.maxParents, 0);
    }

    /**
     * Learns the network structure from {@code D} with the K2 algorithm. Each
     * selected edge is recorded via {@code depends(parent, child)}. Only the
     * categorical features and the class label of {@code D} are used.
     *
     * @param D the training data
     */
    public void learnNetwork(ClassificationDataSet D) {
        final int classIndex = D.getNumCategoricalVars();
        IntList varOrder = new IntList(classIndex + 1);
        varOrder.add(classIndex);
        ListUtils.addRange(varOrder, 0, classIndex, 1);

        this.ri = new int[varOrder.size()];
        try {
            for (int v : varOrder) {
                this.ri[v] = v == classIndex
                        ? D.getClassSize()
                        : D.getCategories()[v].getNumOfCategories();
            }
            // No variable can have ri.length parents, so this acts as "unlimited".
            final int u = this.maxParents > 0 ? this.maxParents : this.ri.length;

            IntSet preceding = new IntSet();
            for (int i : varOrder) {
                IntSet pi = new IntSet();
                double pOld = this.f(i, pi, D);
                IntSet candidates = new IntSet(preceding);
                while (pi.size() < u && !candidates.isEmpty()) {
                    double pNew = Double.NEGATIVE_INFINITY;
                    int z = -1;
                    for (int candidate : candidates) {
                        pi.add(Integer.valueOf(candidate));
                        double tmp = this.f(i, pi, D);
                        if (tmp > pNew) {
                            pNew = tmp;
                            z = candidate;
                        }
                        pi.remove(Integer.valueOf(candidate));
                    }
                    // Written as !(a > b) so that a NaN score also ends the search.
                    if (!(pNew > pOld)) {
                        break;
                    }
                    pOld = pNew;
                    pi.add(Integer.valueOf(z));
                    candidates.remove(Integer.valueOf(z));
                }
                for (int parent : pi) {
                    this.depends(parent, i);
                }
                preceding.add(Integer.valueOf(i));
            }
        } finally {
            // ri is scratch state for learning only; release it even if learning fails.
            this.ri = null;
        }
    }

    /**
     * Trains the network on {@code dataSet}. The structure is first learned
     * with {@link #learnNetwork(ClassificationDataSet)} when the DAG has no
     * nodes or the class node has no parents.
     *
     * <p>NOTE(review): {@code learnNetwork} visits the class variable first, so
     * it never assigns parents to the class node. The condition below therefore
     * remains true after a K2 run, and the structure is relearned on every call.
     * Confirm whether that is intended.
     *
     * @param dataSet the training data
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        if (this.dag.getNodes().isEmpty() || this.dag.getParents(dataSet.getNumCategoricalVars()).isEmpty()) {
            this.learnNetwork(dataSet);
        }
        super.train(dataSet);
    }

    /**
     * Returns {@code 1.0} plus the total weight of the data points for which
     * every variable {@code classes[j]} takes the value {@code values[j]}. A
     * variable index equal to {@code D.getNumCategoricalVars()} refers to the
     * class label.
     */
    private double query(int[] classes, int[] values, ClassificationDataSet D) {
        final int classIndex = D.getNumCategoricalVars();
        double count = 1.0;
        for (int i = 0; i < D.getSampleSize(); ++i) {
            DataPoint dp = D.getDataPoint(i);
            boolean matches = true;
            for (int j = 0; j < classes.length && matches; ++j) {
                // Look up the variable named by classes[j], not the j-th feature.
                int observed = classes[j] == classIndex
                        ? D.getDataPointCategory(i)
                        : dp.getCategoricalValue(classes[j]);
                matches = observed == values[j];
            }
            if (matches) {
                count += dp.getWeight();
            }
        }
        return count;
    }

    /**
     * Computes the log of the K2 score for variable {@code i} given the parent
     * set {@code pi}. Over every joint configuration j of the parents, it sums
     * {@code lnGamma(r_i) - lnGamma(N_ij + r_i) + sum_k lnGamma(N_ijk + 1)}.
     * Here {@code r_i} is the number of values of {@code i}, and
     * {@code N_ijk} is the count returned by {@link #query}, which already
     * includes an initial {@code 1.0}. {@code N_ij} is the sum of those
     * counts over k. An empty parent set has exactly one configuration.
     *
     * @param i  node index of the child variable
     * @param pi node indices of the candidate parents
     * @param D  the training data
     * @return the log score
     * @throws IllegalStateException if called outside
     *         {@link #learnNetwork(ClassificationDataSet)}, where the
     *         per-variable value counts are not available
     */
    public double f(int i, Set<Integer> pi, ClassificationDataSet D) {
        if (this.ri == null) {
            throw new IllegalStateException("f can only be evaluated while learnNetwork is running");
        }
        final int numParents = pi.size();
        // Layout: parent variables first, the child variable i in the last slot.
        int[] classes = new int[numParents + 1];
        int[] values = new int[numParents + 1];
        int c = 0;
        for (int parent : pi) {
            classes[c++] = parent;
        }
        classes[numParents] = i;

        double fullProduct = 0.0;
        do {
            fullProduct += this.k2ConfigurationScore(i, classes, values, D);
        } while (this.advanceParentConfiguration(classes, values, numParents));
        return fullProduct;
    }

    /**
     * Scores the parent configuration currently held in
     * {@code values[0 .. classes.length-2]} by enumerating every value of the
     * child (last slot).
     */
    private double k2ConfigurationScore(int i, int[] classes, int[] values, ClassificationDataSet D) {
        final int childSlot = classes.length - 1;
        double Nij = 0.0;
        double term2 = 0.0;
        for (int k = 0; k < this.ri[i]; ++k) {
            values[childSlot] = k;
            double count = this.query(classes, values, D);
            Nij += count;
            term2 += SpecialMath.lnGamma(count + 1.0);
        }
        return SpecialMath.lnGamma(this.ri[i]) - SpecialMath.lnGamma(Nij + (double) this.ri[i]) + term2;
    }

    /**
     * Advances the parent values like an odometer (position 0 changes fastest).
     *
     * @return {@code false} once every parent configuration has been visited
     */
    private boolean advanceParentConfiguration(int[] classes, int[] values, int numParents) {
        for (int pos = 0; pos < numParents; ++pos) {
            if (++values[pos] < this.ri[classes[pos]]) {
                return true;
            }
            values[pos] = 0;
        }
        return false;
    }
}

