/*
 * Decompiled with CFR 0.152.
 */
package jsat.distributions.multivariate;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import jsat.classifiers.DataPoint;
import jsat.distributions.Gamma;
import jsat.distributions.multivariate.MultivariateDistributionSkeleton;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.SpecialMath;
import jsat.math.optimization.NelderMead;

/**
 * The Symmetric Dirichlet distribution: a Dirichlet distribution over vectors
 * of {@code dim} components in which every concentration parameter equals the
 * single value {@code alpha}. Its support is the probability simplex, i.e.
 * vectors of positive components that sum to one.
 */
public class SymmetricDirichlet
extends MultivariateDistributionSkeleton {
    private static final long serialVersionUID = -1206894014440494142L;

    /**
     * Allowed deviation, per dimension, of an input's component sum from 1 in
     * {@link #logPdf(Vec)}. Scaling by the dimension accounts for rounding
     * error accumulated when a vector with many components is normalized (as
     * {@link #sample(int, Random)} does). Without it, legitimately normalized
     * high-dimensional points could be rejected as lying off the simplex.
     */
    private static final double SIMPLEX_SUM_TOLERANCE = 1.0E-14;

    /** Common concentration parameter; kept finite and positive. */
    private double alpha;
    /** Number of components of the vectors modeled; kept positive. */
    private int dim;

    /**
     * Creates a new Symmetric Dirichlet distribution.
     *
     * @param alpha the common concentration parameter, must be finite and positive
     * @param dim the number of components, must be positive
     * @throws ArithmeticException if either argument is invalid
     */
    public SymmetricDirichlet(double alpha, int dim) {
        this.setAlpha(alpha);
        this.setDimension(dim);
    }

    /**
     * Sets the number of components of the vectors this distribution models.
     *
     * @param dim the number of components, must be positive
     * @throws ArithmeticException if {@code dim} is not positive
     */
    public void setDimension(int dim) {
        if (dim <= 0) {
            throw new ArithmeticException("A positive number of dimensions must be given");
        }
        this.dim = dim;
    }

    /**
     * @return the number of components of the vectors this distribution models
     */
    public int getDimension() {
        return this.dim;
    }

    /**
     * Sets the common concentration parameter.
     *
     * @param alpha the new value, must be finite and positive
     * @throws ArithmeticException if {@code alpha} is non-positive, NaN, or infinite
     */
    public void setAlpha(double alpha) throws ArithmeticException {
        if (alpha <= 0.0 || Double.isNaN(alpha) || Double.isInfinite(alpha)) {
            throw new ArithmeticException("Symmetric Dirichlet Distribution parameters must be positive, " + alpha + " is invalid");
        }
        this.alpha = alpha;
    }

    /**
     * @return the common concentration parameter
     */
    public double getAlpha() {
        return this.alpha;
    }

    @Override
    public SymmetricDirichlet clone() {
        return new SymmetricDirichlet(this.alpha, this.dim);
    }

    /**
     * Computes the log density at {@code x}, i.e.
     * {@code lnGamma(alpha*K) - K*lnGamma(alpha) + (alpha-1) * sum_i ln(x_i)}
     * where {@code K} is the dimension.
     *
     * @param x the point to evaluate; must have exactly {@link #getDimension()} components
     * @return the log density, or {@code -Double.MAX_VALUE} if the value is not
     *         finite (e.g. a component is zero or negative) or the components of
     *         {@code x} do not sum to 1 within {@code K * 1e-14}
     * @throws ArithmeticException if {@code x} has the wrong number of components
     */
    @Override
    public double logPdf(Vec x) {
        if (x.length() != this.dim) {
            throw new ArithmeticException(this.dim + " variable distribution can not awnser a " + x.length() + " dimension variable");
        }
        int K = x.length();
        double logVal = 0.0;
        for (int i = 0; i < K; ++i) {
            logVal += Math.log(x.get(i)) * (this.alpha - 1.0);
        }
        logVal = logVal + SpecialMath.lnGamma(this.alpha * (double) K) - SpecialMath.lnGamma(this.alpha) * (double) K;
        if (Double.isInfinite(logVal) || Double.isNaN(logVal)
                || Math.abs(x.sum() - 1.0) > SIMPLEX_SUM_TOLERANCE * K) {
            return -Double.MAX_VALUE;
        }
        return logVal;
    }

    @Override
    public double pdf(Vec x) {
        return Math.exp(this.logPdf(x));
    }

    /**
     * Fits {@code alpha} by maximum likelihood to the given points, keeping the
     * current dimension.
     *
     * @param dataSet the observed points, each expected to lie on the simplex
     * @return {@code true} if {@code alpha} was updated. Returns {@code false},
     *         leaving {@code alpha} unchanged, if the data set is empty, if any
     *         point's length differs from the dimension, if any component is not
     *         strictly positive (the log-likelihood is then undefined), or if
     *         the optimizer does not produce a valid {@code alpha}
     */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> dataSet) {
        if (dataSet.isEmpty()) {
            return false;
        }
        double logSum = 0.0;
        for (V s : dataSet) {
            if (s.length() != this.dim) {
                return false;
            }
            for (int j = 0; j < s.length(); ++j) {
                logSum += Math.log(s.get(j));
            }
        }
        return this.fitAlpha(logSum, dataSet.size());
    }

    /**
     * Fits {@code alpha} by weighted maximum likelihood to the numerical values
     * of the given data points, keeping the current dimension. Points with zero
     * weight do not contribute.
     *
     * @param dataPoint the weighted observations
     * @return {@code true} if {@code alpha} was updated. Returns {@code false},
     *         leaving {@code alpha} unchanged, if the total weight is not
     *         positive, if any point's length differs from the dimension, if a
     *         contributing component is not strictly positive, or if the
     *         optimizer does not produce a valid {@code alpha}
     */
    @Override
    public boolean setUsingDataList(List<DataPoint> dataPoint) {
        double weightSum = 0.0;
        double weightedLogSum = 0.0;
        for (DataPoint dp : dataPoint) {
            Vec s = dp.getNumericalValues();
            if (s.length() != this.dim) {
                return false;
            }
            double w = dp.getWeight();
            if (w == 0.0) {
                // Zero-weight points add nothing to the likelihood. Skipping
                // them avoids 0 * -Infinity = NaN when such a point has a
                // zero component.
                continue;
            }
            double logSum = 0.0;
            for (int j = 0; j < s.length(); ++j) {
                logSum += Math.log(s.get(j));
            }
            weightSum += w;
            weightedLogSum += w * logSum;
        }
        if (!(weightSum > 0.0)) {
            // No data (or no weight) means a constant objective; nothing to fit.
            return false;
        }
        return this.fitAlpha(weightedLogSum, weightSum);
    }

    /**
     * Maximizes the (weighted) log-likelihood over {@code alpha} using its
     * sufficient statistics. The log-likelihood is
     * {@code (alpha-1)*logSum + weightSum*(lnGamma(alpha*dim) - dim*lnGamma(alpha))}.
     * The statistics are computed once by the caller instead of being
     * re-derived from the data on every objective evaluation.
     *
     * @param logSum (weighted) sum of the natural logs of every component of every point
     * @param weightSum number of points, or total weight for weighted data
     * @return {@code true} and updates {@code alpha} if a finite positive optimum
     *         was found; {@code false} otherwise, leaving {@code alpha} unchanged
     */
    private boolean fitAlpha(final double logSum, final double weightSum) {
        if (Double.isNaN(logSum) || Double.isInfinite(logSum)) {
            // A zero or negative component makes the likelihood undefined.
            return false;
        }
        final int K = this.dim;
        Function negLogLike = new Function() {
            private static final long serialVersionUID = -3591420776536183583L;

            @Override
            public double f(double... x) {
                return this.f(DenseVector.toDenseVec(x));
            }

            @Override
            public double f(Vec x) {
                double a = x.get(0);
                if (!(a > 0.0)) {
                    // Penalize the invalid region so the simplex moves back
                    // towards positive alpha instead of evaluating lnGamma there.
                    return Double.POSITIVE_INFINITY;
                }
                double constantTerm = SpecialMath.lnGamma(a * (double) K)
                        - SpecialMath.lnGamma(a) * (double) K;
                return -((a - 1.0) * logSum + constantTerm * weightSum);
            }
        };
        DenseVector guess = new DenseVector(1);
        ArrayList<Vec> guesses = new ArrayList<Vec>();
        guesses.add(guess.add(1.0));
        guesses.add(guess.add(0.1));
        guesses.add(guess.add(10.0));
        double fitted = new NelderMead().optimize(1.0E-10, 100, negLogLike, guesses).get(0);
        if (!(fitted > 0.0) || Double.isInfinite(fitted)) {
            return false;
        }
        this.alpha = fitted;
        return true;
    }

    /**
     * Draws samples by generating {@code dim} independent Gamma(alpha, 1)
     * variates per sample and normalizing them to sum to one.
     *
     * @param count the number of samples to draw
     * @param rand the source of randomness
     * @return a list of {@code count} new vectors, each of length {@code dim}
     */
    @Override
    public List<Vec> sample(int count, Random rand) {
        ArrayList<Vec> samples = new ArrayList<Vec>(count);
        double[] gammaSamples = new Gamma(this.alpha, 1.0).sample(count * this.dim, rand);
        int samplePos = 0;
        for (int i = 0; i < count; ++i) {
            Vec sample = new DenseVector(this.dim);
            for (int j = 0; j < this.dim; ++j) {
                sample.set(j, gammaSamples[samplePos++]);
            }
            // Normalizing independent Gamma variates yields a Dirichlet draw.
            sample.mutableDivide(sample.sum());
            samples.add(sample);
        }
        return samples;
    }
}

