/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear.solvers;

import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.Vec;

/**
 * Iterative solvers for linear systems {@code A x = b} based on the Conjugate Gradient (CG)
 * method.
 * <p>
 * The plain and preconditioned solvers require {@code A} to be symmetric positive definite.
 * Only squareness is actually checked; symmetry and definiteness are assumed. The
 * {@code solveCGNR} variants accept a general (possibly non-square) {@code A} by running CG
 * on the normal equations {@code A^T A x = A^T b}.
 */
public class ConjugateGradient {

    /** Residual-norm tolerance used by the convenience overloads that take no {@code eps}. */
    private static final double DEFAULT_EPS = 1.0E-10;

    /**
     * Solves {@code A x = b} with the Conjugate Gradient method, refining {@code x} in place.
     *
     * @param eps convergence tolerance; iteration stops once the Euclidean norm of the residual
     *            {@code b - A x} is at most {@code eps}
     * @param A   the system matrix; must be square and is assumed to be symmetric positive
     *            definite (not verified)
     * @param x   the initial guess; it is modified in place and also returned
     * @param b   the right-hand side
     * @return {@code x}, now holding the approximate solution
     * @throws ArithmeticException if {@code A} is not square, or if its dimensions do not agree
     *                             with {@code x} and {@code b}
     */
    public static Vec solve(double eps, Matrix A, Vec x, Vec b) {
        if (!A.isSquare()) {
            throw new ArithmeticException("A must be a square (symmetric & positive definite) matrix");
        }
        if (A.rows() != b.length() || A.rows() != x.length()) {
            throw new ArithmeticException("Matrix A dimensions do not agree with x and b");
        }
        // Compare squared norms so no square root is needed per iteration.
        double tolSqrd = eps * eps;
        Vec r_k = b.subtract(A.multiply(x));
        double RdR = r_k.dot(r_k);
        // If the starting guess already solves the system (e.g. b == 0), the first search
        // direction is the zero vector and the step size would be 0/0 = NaN.
        if (RdR <= tolSqrd) {
            return x;
        }
        Vec p_k = r_k.clone();
        // In exact arithmetic CG converges in at most n steps. One extra step is allowed to
        // absorb floating-point error; this matches the historical iteration budget.
        int maxIterations = A.rows() + 1;
        for (int k = 0; k < maxIterations; k++) {
            Vec Apk = A.multiply(p_k);
            double alpha_k = RdR / p_k.dot(Apk);
            x.mutableAdd(alpha_k, p_k);
            r_k.mutableAdd(-alpha_k, Apk);
            double newRdR = r_k.dot(r_k);
            if (newRdR <= tolSqrd) {
                return x;
            }
            double beta_k = newRdR / RdR;
            // p_{k+1} = r_{k+1} + beta_k * p_k, updated in place.
            p_k.mutableMultiply(beta_k);
            p_k.mutableAdd(r_k);
            RdR = newRdR;
        }
        return x;
    }

    /**
     * Solves {@code A x = b} with the Conjugate Gradient method, starting from the zero vector
     * and using a residual tolerance of {@code 1e-10}.
     *
     * @param A the system matrix; must be square and is assumed to be symmetric positive
     *          definite (not verified)
     * @param b the right-hand side
     * @return a new vector holding the approximate solution
     * @throws ArithmeticException if {@code A} is not square or its dimensions do not agree
     *                             with {@code b}
     */
    public static Vec solve(Matrix A, Vec b) {
        DenseVector x = new DenseVector(b.length());
        return ConjugateGradient.solve(DEFAULT_EPS, A, x, b);
    }

    /**
     * Solves {@code A x = b} with the preconditioned Conjugate Gradient method, refining
     * {@code x} in place.
     *
     * @param eps  convergence tolerance; iteration stops once the Euclidean norm of the
     *             residual {@code b - A x} is at most {@code eps}
     * @param A    the system matrix; must be square and is assumed to be symmetric positive
     *             definite (not verified)
     * @param x    the initial guess; it is modified in place and also returned
     * @param b    the right-hand side
     * @param Minv the inverse of the preconditioner; applied to each residual by
     *             {@code Minv.multiply(r)}. Must be square with the same dimensions as
     *             {@code A}, and is presumably symmetric positive definite (not verified).
     * @return {@code x}, now holding the approximate solution
     * @throws ArithmeticException if {@code A} or {@code Minv} is not square, or if the
     *                             dimensions of {@code A}, {@code Minv}, {@code x} and
     *                             {@code b} do not agree
     */
    public static Vec solve(double eps, Matrix A, Vec x, Vec b, Matrix Minv) {
        if (!A.isSquare() || !Minv.isSquare()) {
            throw new ArithmeticException("A and Minv must be square (symmetric & positive definite) matrix");
        }
        if (A.rows() != b.length() || A.rows() != x.length()) {
            throw new ArithmeticException("Matrix A dimensions do not agree with x and b");
        }
        if (A.rows() != Minv.rows() || A.cols() != Minv.cols()) {
            throw new ArithmeticException("Matrix A and Minv do not have the same dimensions");
        }
        double tolSqrd = eps * eps;
        Vec r_k = b.subtract(A.multiply(x));
        // A zero initial residual gives a zero search direction and a 0/0 = NaN step size.
        if (r_k.dot(r_k) <= tolSqrd) {
            return x;
        }
        Vec z_k = Minv.multiply(r_k);
        Vec p_k = z_k.clone();
        double rkzk = r_k.dot(z_k);
        // Same iteration budget as the unpreconditioned solver: n steps plus one.
        int maxIterations = A.rows() + 1;
        for (int k = 0; k < maxIterations; k++) {
            Vec Apk = A.multiply(p_k);
            double alpha = rkzk / p_k.dot(Apk);
            x.mutableAdd(alpha, p_k);
            r_k.mutableSubtract(alpha, Apk);
            // Convergence is judged on the true residual, not the preconditioned one.
            if (r_k.dot(r_k) <= tolSqrd) {
                return x;
            }
            z_k = Minv.multiply(r_k);
            double newRkZk = r_k.dot(z_k);
            double beta = newRkZk / rkzk;
            rkzk = newRkZk;
            // p_{k+1} = z_{k+1} + beta * p_k, updated in place.
            p_k.mutableMultiply(beta);
            p_k.mutableAdd(z_k);
        }
        return x;
    }

    /**
     * Solves {@code A x = b} in the least-squares sense with CGNR. This applies Conjugate
     * Gradient to the normal equations {@code A^T A x = A^T b}.
     * <p>
     * {@code A^T A} is formed explicitly. This squares the condition number of {@code A},
     * so convergence can be slow for ill-conditioned systems.
     *
     * @param eps convergence tolerance, applied to the residual of the normal equations
     *            {@code A^T b - A^T A x}
     * @param A   the system matrix; may be non-square
     * @param x   the initial guess of length {@code A.cols()}; modified in place and returned
     * @param b   the right-hand side of length {@code A.rows()}
     * @return {@code x}, now holding the approximate solution
     * @throws ArithmeticException if the dimensions of {@code A}, {@code x} and {@code b} do
     *                             not agree
     */
    public static Vec solveCGNR(double eps, Matrix A, Vec x, Vec b) {
        if (A.rows() != b.length()) {
            throw new ArithmeticException("Dimensions do not agree for Matrix A and Vector b");
        }
        if (A.cols() != x.length()) {
            throw new ArithmeticException("Dimensions do not agree for Matrix A and Vector x");
        }
        Matrix At = A.transpose();
        Matrix AtA = At.multiply(A);
        Vec AtB = At.multiply(b);
        return ConjugateGradient.solve(eps, AtA, x, AtB);
    }

    /**
     * Solves {@code A x = b} in the least-squares sense with CGNR. It starts from the zero
     * vector and uses a tolerance of {@code 1e-10}.
     *
     * @param A the system matrix; may be non-square
     * @param b the right-hand side of length {@code A.rows()}
     * @return a new vector of length {@code A.cols()} holding the approximate solution
     * @throws ArithmeticException if the dimensions of {@code A} and {@code b} do not agree
     */
    public static Vec solveCGNR(Matrix A, Vec b) {
        DenseVector x = new DenseVector(A.cols());
        return ConjugateGradient.solveCGNR(DEFAULT_EPS, A, x, b);
    }
}

