/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.neuralnetwork.regularizers;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.neuralnetwork.regularizers.WeightRegularizer;
import jsat.linear.Matrix;
import jsat.linear.Vec;

/**
 * A weight regularizer that caps the L2 norm of every row of a weight matrix.
 * <p>
 * For each row {@code W_i} whose L2 norm is greater than or equal to
 * {@link #getMaxNorm()}, the row is multiplied in place by
 * {@code maxNorm / ||W_i||}. This brings its norm down to {@code maxNorm}. The
 * matching bias entry {@code b_i} is rescaled by the same factor. Rows whose
 * norm is below the limit (or whose norm is NaN) are left untouched.
 */
public class Max2NormRegularizer
implements WeightRegularizer {
    private static final long serialVersionUID = 1989826758516880355L;

    /** Upper bound on each row's L2 norm; always finite and strictly positive. */
    private double maxNorm;

    /**
     * Creates a regularizer that limits each weight row to the given L2 norm.
     *
     * @param maxNorm the maximum allowed L2 norm of a weight row
     * @throws IllegalArgumentException if {@code maxNorm} is NaN, infinite, or not positive
     */
    public Max2NormRegularizer(double maxNorm) {
        // Validate directly instead of calling the overridable setter from the constructor.
        this.maxNorm = checkMaxNorm(maxNorm);
    }

    /**
     * Sets the maximum allowed L2 norm of a weight row.
     *
     * @param maxNorm the new limit
     * @throws IllegalArgumentException if {@code maxNorm} is NaN, infinite, or not positive
     */
    public void setMaxNorm(double maxNorm) {
        this.maxNorm = checkMaxNorm(maxNorm);
    }

    /**
     * Returns the maximum allowed L2 norm of a weight row.
     *
     * @return the current limit, always finite and strictly positive
     */
    public double getMaxNorm() {
        return this.maxNorm;
    }

    /**
     * Rescales, in place, every row of {@code W} whose L2 norm is at least the
     * limit, together with the bias entry at the same index.
     *
     * @param W the weight matrix; rows are modified through row views
     * @param b the bias vector; entry {@code i} is rescaled when row {@code i} is
     */
    @Override
    public void applyRegularization(Matrix W, Vec b) {
        for (int i = 0; i < W.rows(); ++i) {
            constrainRow(W, b, i);
        }
    }

    /**
     * Parallel version of {@link #applyRegularization(Matrix, Vec)}: one task per
     * row is submitted to {@code ex}, and this method waits for all of them.
     * <p>
     * Each task writes only its own row of {@code W} and its own entry of
     * {@code b}. NOTE(review): this assumes {@code Matrix}/{@code Vec}
     * implementations tolerate concurrent writes to distinct rows/indices.
     * Confirm for sparse implementations.
     * <p>
     * If waiting is interrupted, the thread's interrupt status is restored and
     * the exception is logged. If a task fails, its exception is logged. In
     * both cases the matrix may be only partially regularized.
     *
     * @param W  the weight matrix to regularize in place
     * @param b  the bias vector to rescale in place
     * @param ex the executor used to process rows concurrently
     */
    @Override
    public void applyRegularization(final Matrix W, final Vec b, ExecutorService ex) {
        final int rows = W.rows();
        List<Future<?>> futures = new ArrayList<Future<?>>(rows);
        for (int i = 0; i < rows; ++i) {
            final int row = i;
            futures.add(ex.submit(new Runnable() {

                @Override
                public void run() {
                    constrainRow(W, b, row);
                }
            }));
        }
        try {
            for (Future<?> future : futures) {
                future.get();
            }
        }
        catch (InterruptedException ex1) {
            // Preserve the interrupt so callers further up the stack can observe it.
            Thread.currentThread().interrupt();
            Logger.getLogger(Max2NormRegularizer.class.getName()).log(Level.SEVERE, null, ex1);
        }
        catch (ExecutionException ex1) {
            Logger.getLogger(Max2NormRegularizer.class.getName()).log(Level.SEVERE, null, ex1);
        }
    }

    /**
     * Applies the norm limit to a single row vector and its bias.
     *
     * @param w the row to rescale in place if its L2 norm is at least the limit
     * @param b the bias associated with {@code w}
     * @return {@code b} scaled by the same factor as {@code w}, or {@code b}
     *         unchanged if no rescaling was needed
     */
    @Override
    public double applyRegularizationToRow(Vec w, double b) {
        double norm = w.pNorm(2.0);
        if (norm >= this.maxNorm) {
            w.mutableMultiply(this.maxNorm / norm);
            return b * this.maxNorm / norm;
        }
        return b;
    }

    /**
     * Returns an independent copy with the same norm limit.
     *
     * @return a new regularizer with the same {@code maxNorm}
     */
    @Override
    public Max2NormRegularizer clone() {
        return new Max2NormRegularizer(this.maxNorm);
    }

    /**
     * Shared per-row logic for both {@code applyRegularization} overloads.
     * It rescales row {@code i} of {@code W} and {@code b[i]} when the row's L2
     * norm is at least the limit. The bias vector is read and written only when
     * a rescale actually happens. A NaN norm fails the {@code >=} test, so such
     * rows are skipped.
     */
    private void constrainRow(Matrix W, Vec b, int i) {
        Vec row = W.getRowView(i);
        double norm = row.pNorm(2.0);
        if (norm >= this.maxNorm) {
            row.mutableMultiply(this.maxNorm / norm);
            // Same operation order as applyRegularizationToRow: (b_i * maxNorm) / norm.
            b.set(i, b.get(i) * this.maxNorm / norm);
        }
    }

    /**
     * Validates a candidate norm limit and returns it unchanged.
     *
     * @throws IllegalArgumentException if {@code maxNorm} is NaN, infinite, or not positive
     */
    private static double checkMaxNorm(double maxNorm) {
        if (Double.isNaN(maxNorm) || Double.isInfinite(maxNorm) || maxNorm <= 0.0) {
            throw new IllegalArgumentException("The maximum norm must be a positive constant, not " + maxNorm);
        }
        return maxNorm;
    }
}

