/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.dissimilarity;

import java.util.List;
import java.util.Set;
import jsat.classifiers.DataPoint;
import jsat.clustering.dissimilarity.LanceWilliamsDissimilarity;
import jsat.clustering.dissimilarity.UpdatableClusterDissimilarity;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;

/**
 * Centroid-linkage cluster dissimilarity expressed through the Lance–Williams
 * recurrence.
 * <p>
 * The Lance–Williams coefficients supplied by this class are
 * {@code a_i = n_i / (n_i + n_j)}, {@code a_j = n_j / (n_i + n_j)},
 * {@code b = -n_i * n_j / (n_i + n_j)^2} and {@code c = 0}.
 * <p>
 * NOTE(review): the list/set based {@code dissimilarity} overloads return the
 * mean of all pairwise point distances (the average-linkage value), not the
 * distance between the two cluster centroids. The centroid Lance–Williams
 * update is exact only for squared Euclidean distances, so the two code paths
 * will generally disagree — confirm which distance the configured
 * {@link DistanceMetric} produces and which behavior callers expect.
 */
public class CentroidDissimilarity extends LanceWilliamsDissimilarity
        implements UpdatableClusterDissimilarity {

    /**
     * Creates a centroid dissimilarity backed by a new {@link EuclideanDistance}.
     */
    public CentroidDissimilarity() {
        this(new EuclideanDistance());
    }

    /**
     * Creates a centroid dissimilarity backed by the given distance metric.
     *
     * @param dm the metric used to measure the distance between two points
     */
    public CentroidDissimilarity(DistanceMetric dm) {
        super(dm);
    }

    /**
     * Returns a new instance backed by a clone of this object's metric.
     *
     * @return an independent copy of this dissimilarity
     */
    @Override
    public CentroidDissimilarity clone() {
        return new CentroidDissimilarity(this.dm.clone());
    }

    /**
     * Returns the mean distance over every pair {@code (x, y)} with {@code x}
     * from {@code a} and {@code y} from {@code b}.
     *
     * @param a the points of the first cluster
     * @param b the points of the second cluster
     * @return the mean pairwise distance; NaN (0 / 0) if either list is empty
     */
    @Override
    public double dissimilarity(List<DataPoint> a, List<DataPoint> b) {
        double sumDiss = 0.0;
        for (DataPoint ai : a) {
            for (DataPoint bi : b) {
                sumDiss += this.distance(ai, bi);
            }
        }
        // Multiply in double so the pair count cannot overflow int for huge clusters.
        return sumDiss / ((double) a.size() * b.size());
    }

    /**
     * Returns the mean precomputed distance over every index pair
     * {@code (x, y)} with {@code x} from {@code a} and {@code y} from {@code b}.
     *
     * @param a the point indices of the first cluster
     * @param b the point indices of the second cluster
     * @param distanceMatrix the precomputed pairwise distances, read through
     *        {@code getDistance}
     * @return the mean pairwise distance; NaN (0 / 0) if either set is empty
     */
    @Override
    public double dissimilarity(Set<Integer> a, Set<Integer> b, double[][] distanceMatrix) {
        double sumDiss = 0.0;
        for (int ai : a) {
            for (int bi : b) {
                sumDiss += getDistance(distanceMatrix, ai, bi);
            }
        }
        // Multiply in double so the pair count cannot overflow int for huge clusters.
        return sumDiss / ((double) a.size() * b.size());
    }

    /**
     * Returns the stored distance between clusters {@code i} and {@code j};
     * the cluster sizes are not used.
     */
    @Override
    public double dissimilarity(int i, int ni, int j, int nj, double[][] distanceMatrix) {
        return getDistance(distanceMatrix, i, j);
    }

    /**
     * Lance–Williams update: the distance from cluster {@code k} to the cluster
     * formed by merging {@code i} (size {@code ni}) and {@code j} (size {@code nj}):
     * {@code a_i*d(i,k) + a_j*d(j,k) + b*d(i,j) + c*|d(i,k) - d(j,k)|}.
     * <p>
     * The coefficients come from {@link #aConst}, {@link #bConst} and
     * {@link #cConst}, so this path always agrees with them. Previously beta was
     * computed as {@code -ni*nj / iPj * iPj}, which (left-associative
     * {@code /} and {@code *}) reduced to {@code -ni*nj} instead of
     * {@code -ni*nj / (ni+nj)^2}.
     */
    @Override
    public double dissimilarity(int i, int ni, int j, int nj, int k, int nk, double[][] distanceMatrix) {
        double dik = getDistance(distanceMatrix, i, k);
        double djk = getDistance(distanceMatrix, j, k);
        double dij = getDistance(distanceMatrix, i, j);
        return aConst(true, ni, nj, nk) * dik
                + aConst(false, ni, nj, nk) * djk
                + bConst(ni, nj, nk) * dij
                + cConst(ni, nj, nk) * Math.abs(dik - djk);
    }

    /**
     * Returns the size-weighted coefficient {@code n_i / (n_i + n_j)} when
     * {@code iFlag} is true, otherwise {@code n_j / (n_i + n_j)}.
     */
    @Override
    protected double aConst(boolean iFlag, int ni, int nj, int nk) {
        double denom = (double) ni + nj;
        if (iFlag) {
            return ni / denom;
        }
        return nj / denom;
    }

    /**
     * Returns {@code -n_i * n_j / (n_i + n_j)^2}, computed in double to avoid
     * int overflow.
     */
    @Override
    protected double bConst(int ni, int nj, int nk) {
        double nipj = (double) ni + nj;
        return (double) (-ni) * (double) nj / (nipj * nipj);
    }

    /**
     * Returns 0: the centroid recurrence has no {@code |d(i,k) - d(j,k)|} term.
     */
    @Override
    protected double cConst(int ni, int nj, int nk) {
        return 0.0;
    }
}

