/*
 * Decompiled with CFR 0.152.
 */
package org.rhwlab.DAG.decomp;

import java.io.PrintStream;
import java.util.Arrays;
import org.rhwlab.DAG.Constant;
import org.rhwlab.DAG.Function;
import org.rhwlab.DAG.Node;
import org.rhwlab.DAG.Parameter;
import org.rhwlab.DAG.distributions.Gamma;
import org.rhwlab.DAG.distributions.InverseGamma;
import org.rhwlab.DAG.distributions.LogNormal;
import org.rhwlab.DAG.models.Model;
import org.rhwlab.DAG.parameters.NonNegative;
import org.rhwlab.DAG.parameters.RandomWalkParameter;
import org.rhwlab.DAG.pipeline.Pipeline;
import org.rhwlab.RNASeq.PMatrixKdeFile;

/**
 * DAG {@link Function} node whose value is a matrix of normalized stage weights,
 * with one column per sample time.
 *
 * <p>Column {@code c} is stored as {@code P[c]} and has length {@code N}. For each
 * stage {@code r}, {@code P[c][r]} is proportional to
 * {@code exp(-0.5 * (stageTimes[r] - mean)^2 / variance)}, where {@code mean} is
 * {@link #getMean(int)} and {@code variance} is {@link #getVariance(int)}. Each column
 * is scaled to sum to 1. If every unnormalized entry of a column is exactly 0, the
 * column is set to {@code 1/N} everywhere.
 *
 * <p>The start time, growth rate and variance parents are either {@link Constant}s read
 * from a {@link PMatrixKdeFile} or free parameters with priors.
 */
public class PMatrix
extends Function {
    /** Number of stages, i.e. the length of each column. */
    int N;
    /** Number of columns; set to {@code sampleTimes.length} by {@link #update()}. */
    int T;
    /** Parent node supplying the start time used in {@link #getMean(int)}. */
    Node expStartTime;
    /** Parent node supplying the growth rate used in {@link #getMean(int)}. */
    Node expGrowthRate;
    /** Parent node supplying the variance returned by {@link #getVariance(int)}. */
    Node initVariance;
    /** Never assigned within this class and may be {@code null}; not a registered parent. */
    Node dispRate;
    /** Sample time for each column (retained by reference, not copied). */
    double[] sampleTimes;
    /** Stage time for each of the {@code N} stages. */
    double[] stageTimes;
    /** The matrix, indexed {@code P[column][stage]}; {@code null} until update() runs with canUpdate() true. */
    double[][] P = null;
    /** Not used within this class. */
    double stagesStartTime;

    /**
     * Builds the matrix node for one experiment.
     *
     * <p>Each of start time, growth rate and initial variance is looked up in {@code file}
     * under the {@code "Mode"} key. A value that is present is wrapped in a {@link Constant}.
     * A missing value is replaced by a free parameter in {@code model` with a prior.
     *
     * @param N           number of stages
     * @param sampleTimes sample time for each column (retained by reference)
     * @param experiment  experiment name, used to build the node and variable names
     * @param model       model that this node and any free parameters belong to
     * @param file        source of the stage times and of any fixed parameter values
     */
    public PMatrix(int N, double[] sampleTimes, String experiment, Model model, PMatrixKdeFile file) {
        this(experiment + "_P", model, N, sampleTimes);
        this.stageTimes = file.getStageTimes(N);

        String varName = PMatrix.getStartTimeVariableName(experiment);
        Double v = file.getValue(varName, "Mode");
        if (v == null) {
            // NOTE(review): assumes the last RandomWalkParameter argument is the initial value; confirm.
            double initStartTime = 30.0;
            this.expStartTime = new RandomWalkParameter(varName, model, 10.0, initStartTime);
            ((Parameter)this.expStartTime).setPrior(new Gamma(9.0, 3.3333333333333335));
        } else {
            this.expStartTime = new Constant(varName, v);
        }

        varName = PMatrix.getGrowthRateVariableName(experiment);
        v = file.getValue(varName, "Mode");
        if (v == null) {
            this.expGrowthRate = new RandomWalkParameter(varName, model, 0.01, 1.0);
            ((Parameter)this.expGrowthRate).setPrior(new LogNormal(0.0, 0.25));
        } else {
            this.expGrowthRate = new Constant(varName, v);
        }

        varName = PMatrix.getInitialVarianceVariableName(experiment);
        v = file.getValue(varName, "Mode");
        if (v == null) {
            this.initVariance = new NonNegative(varName, model, 100.0, 225.0);
            ((NonNegative)this.initVariance).setPrior(new InverseGamma(3.0, 3000.0));
        } else {
            this.initVariance = new Constant(varName, v);
        }

        this.addParents();
        this.update();
    }

    /** Initializes the node name, model, stage count and sample times; parents are set by the caller. */
    private PMatrix(String name, Model model, int N, double[] sampleTimes) {
        super(name, model);
        this.N = N;
        this.sampleTimes = sampleTimes;
    }

    /** Registers growth rate, start time and initial variance as parents of this node. */
    private void addParents() {
        this.addParent(this.expGrowthRate);
        this.addParent(this.expStartTime);
        this.addParent(this.initVariance);
    }

    /**
     * Recomputes {@code P} from the current parent values and passes it to {@code setValue}.
     * Does nothing if {@code canUpdate()} returns false.
     *
     * <p>NaN entries are reported on stdout but are not corrected.
     */
    @Override
    public void update() {
        if (!this.canUpdate()) {
            return;
        }
        this.T = this.sampleTimes.length;
        this.P = new double[this.T][];
        for (int c = 0; c < this.T; ++c) {
            this.P[c] = new double[this.N];
        }
        for (int c = 0; c < this.T; ++c) {
            this.fillColumn(c);
        }
        if (Pipeline.tracing) {
            System.out.printf("PMatrix.update: name=%s\n", this.getName());
        }
        if (Pipeline.dump) {
            this.dump(System.out);
        }
        this.setValue(this.P);
    }

    /** Fills {@code P[c]} with normalized Gaussian weights over the stage times. */
    private void fillColumn(int c) {
        double[] column = this.P[c];
        double meanTAF = this.getMean(c);
        double variance = this.getVariance(c);
        double sqrtVar = Math.sqrt(variance);
        double sum = 0.0;
        for (int r = 0; r < this.N; ++r) {
            double d = this.stageTimes[r] - meanTAF;
            // Factors that are constant within a column cancel in the normalization below.
            column[r] = Math.exp(-0.5 * d * d / variance) / sqrtVar;
            if (Double.isNaN(column[r])) {
                System.out.println("PMatrixImpl.update: NaN error 1");
                System.out.printf("d=%f,variance=%f\n", d, variance);
            }
            sum += column[r];
        }
        if (sum == 0.0) {
            // Every weight is zero (e.g. underflow far from the mean): fall back to uniform.
            Arrays.fill(column, 1.0 / (double)this.N);
            return;
        }
        for (int r = 0; r < this.N; ++r) {
            column[r] /= sum;
            if (Double.isNaN(column[r])) {
                System.out.println("PMatrixImpl.update: NaN error 2");
            }
        }
    }

    /**
     * Returns column {@code c}, one weight per stage.
     *
     * @return the internal array (not a copy)
     */
    public double[] getColumn(int c) {
        return this.P[c];
    }

    /**
     * Returns the weight of stage {@code i} in every column.
     *
     * @return a new array of length {@code T}
     */
    public double[] getRow(int i) {
        double[] row = new double[this.T];
        for (int c = 0; c < this.T; ++c) {
            row[c] = this.P[c][i];
        }
        return row;
    }

    /** Returns the number of columns computed by the last {@link #update()}. */
    public int getColumnCount() {
        return this.T;
    }

    /** Returns {@code startTime + growthRate * sampleTimes[c]}, the kernel mean for column {@code c}. */
    public double getMean(int c) {
        return (Double)this.expStartTime.getValue() + (Double)this.expGrowthRate.getValue() * this.sampleTimes[c];
    }

    /** Returns the kernel variance; this implementation ignores {@code c} and uses one value for all columns. */
    public double getVariance(int c) {
        return (Double)this.initVariance.getValue();
    }

    /**
     * Writes the parent values, sample/stage times and matrix columns to {@code stream}.
     * The {@code dispRate} line is printed only when {@code dispRate} is non-null.
     */
    public void dump(PrintStream stream) {
        stream.printf("PMatrix: %s\n", this.getName());
        stream.printf("startTime: %f\n", (Double)this.expStartTime.getValue());
        stream.printf("growthRate: %f\n", (Double)this.expGrowthRate.getValue());
        stream.printf("initVariance: %f\n", (Double)this.initVariance.getValue());
        // dispRate is never assigned in this class; dereferencing it unguarded threw an NPE.
        if (this.dispRate != null) {
            stream.printf("dispRate: %f\n", (Double)this.dispRate.getValue());
        }
        stream.print("SampleTimes: ");
        for (double t : this.sampleTimes) {
            stream.printf("%f,", t);
        }
        stream.println();
        stream.print("StageTimes: ");
        for (double t : this.stageTimes) {
            stream.printf("%f,", t);
        }
        stream.println();
        for (int i = 0; i < this.T; ++i) {
            stream.printf("Column:%d mean=%f,variance=%f\n", i, this.getMean(i), this.getVariance(i));
            for (int r = 0; r < this.N; ++r) {
                stream.printf("%f,", this.P[i][r]);
            }
            stream.println();
        }
    }

    /** Returns the variable name for an experiment's growth rate. */
    public static String getGrowthRateVariableName(String experiment) {
        return String.format("GrowthRate_%s", experiment);
    }

    /** Returns the variable name for an experiment's start time. */
    public static String getStartTimeVariableName(String experiment) {
        return String.format("StartTime_%s", experiment);
    }

    /** Returns the variable name for an experiment's initial variance. */
    public static String getInitialVarianceVariableName(String experiment) {
        return String.format("InitialVariance_%s", experiment);
    }

    /** Returns the variable name for an experiment's dispersion rate. */
    public static String getDispersionRateVariableName(String experiment) {
        return String.format("DispersionRate_%s", experiment);
    }
}

