/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.LMS.dataframe.viewers;

import java.awt.Shape;
import java.awt.geom.Ellipse2D;
import java.awt.geom.Rectangle2D;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import javax.swing.JPanel;
import javax.swing.JTable;
import org.apache.commons.math3.stat.correlation.SpearmansCorrelation;
import org.apache.commons.math3.stat.regression.SimpleRegression;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.Plot;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.renderer.xy.XYItemRenderer;
import org.jfree.data.statistics.HistogramDataset;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.rhwlab.LMS.dataframe.DataFrame;
import weka.estimators.KernelEstimator;

/**
 *
 * @author gevirl
 */
public class TablePanel extends JTable {

    /** Precision passed to every Weka KernelEstimator and used as the KL grid spacing. */
    private static final double KDE_PRECISION = 0.001;

    /**
     * Creates an empty table with an automatic row sorter, no automatic column
     * resizing and column-only selection.
     */
    public TablePanel() {
        setAutoCreateRowSorter(true);
        setAutoResizeMode(JTable.AUTO_RESIZE_OFF);
        setRowSelectionAllowed(false);
        setColumnSelectionAllowed(true);
    }

    /**
     * Installs the given data frame as this table's model and refreshes the display.
     *
     * @param df the data frame to show
     */
    public void setData(DataFrame df) {
        setModel(df);
        invalidate();
        validate();
        repaint();
    }

    /**
     * Builds a scatter plot of the two selected columns; the left-most selected
     * column goes on the x axis. The chart title is the two column names joined
     * by "|", followed by the root-mean-square error of a no-intercept linear
     * regression of y on x and the Spearman rank correlation.
     *
     * @param xLabel x axis label
     * @param yLabel y axis label
     * @param logit  if true, plot log10(value + 1) instead of the raw values
     * @return the chart panel, or {@code null} unless exactly two columns are selected
     */
    public ChartPanel getScatterPanel(String xLabel, String yLabel, boolean logit) {
        int[] selCols = selectedModelColumns();
        if (selCols.length != 2) {
            return null;
        }
        DataFrame df = dataFrame();
        String title = df.getColumnName(selCols[0]) + "|" + df.getColumnName(selCols[1]);

        double[] x = log10Plus1(df.getColumnAsDouble(selCols[0]), logit);
        double[] y = log10Plus1(df.getColumnAsDouble(selCols[1]), logit);

        SimpleRegression regress = new SimpleRegression(false);
        XYSeries series = new XYSeries(title);
        for (int i = 0; i < x.length; ++i) {
            series.add(x[i], y[i]);
            regress.addData(x[i], y[i]);
        }
        double spear = new SpearmansCorrelation().correlation(x, y);

        title = String.format("%s %.0f %.4f", title, Math.sqrt(regress.getMeanSquareError()), spear);
        XYSeriesCollection collect = new XYSeriesCollection();
        collect.addSeries(series);
        JFreeChart chart = ChartFactory.createScatterPlot(title, xLabel, yLabel, collect,
                PlotOrientation.VERTICAL, false, false, false);
        XYItemRenderer rend = chart.getXYPlot().getRendererForDataset(collect);
        // draw every point as a 1x1 square centred on the data point
        rend.setSeriesShape(0, new Rectangle2D.Double(-.5, -.5, 1.0, 1.0));
        return new ChartPanel(chart);
    }

    /**
     * Plots a kernel density estimate of the Kullback-Leibler divergences
     * D(column i || column j) for every pair of selected columns (i left of j)
     * whose names compare equal under {@code cmp}. Each pair and its divergence
     * is also printed to standard output.
     *
     * @param cmp   decides which column pairs are compared (pairs comparing as 0)
     * @param logit if true, columns are transformed with ln(1 + value) before estimation
     * @return a line chart of the divergence density
     */
    public ChartPanel getKLDivergence(Comparator<String> cmp, boolean logit) {
        DataFrame df = dataFrame();
        int[] cols = selectedModelColumns();
        KernelEstimator[] est = new KernelEstimator[cols.length];

        // Upper integration bound: the smallest of the columns' largest kernel means.
        double maxMu = Double.POSITIVE_INFINITY;
        for (int i = 0; i < cols.length; ++i) {
            est[i] = buildEstimator(df.getColumnAsDouble(cols[i]), logit);
            double top = largestKernelMean(est[i]);
            if (top < maxMu) {
                maxMu = top;
            }
        }

        KernelEstimator klEstimator = new KernelEstimator(KDE_PRECISION);
        for (int i = 0; i < cols.length - 1; ++i) {
            String nameI = df.getColumnName(cols[i]);
            for (int j = i + 1; j < cols.length; ++j) {
                String nameJ = df.getColumnName(cols[j]);
                if (cmp.compare(nameI, nameJ) == 0) {
                    double kl = kullbackLeibler(est[i], est[j], KDE_PRECISION, maxMu);
                    System.out.printf("%s %s %f\n", nameI, nameJ, kl);
                    // NaN/infinite divergences cannot be placed on the density plot
                    if (!Double.isNaN(kl) && !Double.isInfinite(kl)) {
                        klEstimator.addValue(kl, 1.0);
                    }
                }
            }
        }

        XYSeriesCollection collect = new XYSeriesCollection();
        collect.addSeries(densitySeries("KL", klEstimator));
        JFreeChart chart = ChartFactory.createXYLineChart("KL", "KL Divergence", "Density",
                collect, PlotOrientation.VERTICAL, true, false, false);
        return new ChartPanel(chart);
    }

    /**
     * Plots a kernel density estimate for each selected column. When at least two
     * columns are selected, the title reports the average pairwise
     * Kullback-Leibler divergence D(column i || column j) over all pairs i left of j.
     *
     * @param logit if true, columns are transformed with ln(1 + value) before estimation
     * @return a line chart with one series per selected column
     */
    public ChartPanel getKDE(boolean logit) {
        DataFrame df = dataFrame();
        int[] cols = selectedModelColumns();
        List<KernelEstimator> estList = new ArrayList<>();

        double maxMu = Double.POSITIVE_INFINITY;
        XYSeriesCollection collect = new XYSeriesCollection();
        for (int col : cols) {
            KernelEstimator est = buildEstimator(df.getColumnAsDouble(col), logit);
            estList.add(est);
            double top = largestKernelMean(est);
            if (top < maxMu) {
                maxMu = top;
            }
            collect.addSeries(densitySeries(df.getColumnName(col), est));
        }

        double kld = 0.0;
        int count = 0;
        for (int i = 0; i < estList.size() - 1; ++i) {
            for (int j = i + 1; j < estList.size(); ++j) {
                kld += kullbackLeibler(estList.get(i), estList.get(j), KDE_PRECISION, maxMu);
                ++count;
            }
        }
        // with fewer than two columns there is no pair to average (0/0 would print NaN)
        String title = count == 0
                ? "Kernel Density Estimate"
                : String.format("Average Kullback-Leibler divergence=%.3f", kld / count);
        JFreeChart chart = ChartFactory.createXYLineChart(title, "log(TPM)", "Gene Density",
                collect, PlotOrientation.VERTICAL, true, false, false);
        return new ChartPanel(chart);
    }

    /**
     * Approximates the Kullback-Leibler divergence D(est1 || est2) by summing
     * p1 * log(p1 / p2) over the grid 0, precision, 2*precision, ... strictly
     * below {@code max}, where p1 and p2 are the estimators' getProbability values
     * at each grid point. Grid points with p1 == 0 contribute nothing (the
     * 0 * log 0 = 0 convention); a point with p1 > 0 and p2 == 0 makes the
     * result +infinity.
     *
     * @param est1      the reference distribution P
     * @param est2      the approximating distribution Q
     * @param precision grid spacing (must be positive)
     * @param max       exclusive upper end of the grid
     * @return the divergence, 0 if {@code max <= 0}, or NaN if {@code max} is not finite
     */
    private static double kullbackLeibler(KernelEstimator est1, KernelEstimator est2,
            double precision, double max) {
        if (Double.isNaN(max) || Double.isInfinite(max)) {
            return Double.NaN;
        }
        double e = 0.0;
        // integer step counter avoids accumulating rounding error in the grid position
        long steps = (long) Math.ceil(max / precision);
        for (long k = 0; k < steps; ++k) {
            double x = k * precision;
            double p1 = est1.getProbability(x);
            if (p1 > 0.0) {
                double p2 = est2.getProbability(x);
                e = e - p1 * Math.log(p2 / p1);
            }
        }
        return e;
    }

    /**
     * Builds a histogram of the (first) selected column.
     *
     * @param xLabel x axis label
     * @param nBins  number of histogram bins
     * @param logit  if true, histogram log10(value + 1) instead of the raw values
     * @return the histogram chart panel
     */
    public ChartPanel getHistogramPanel(String xLabel, int nBins, boolean logit) {
        DataFrame df = dataFrame();
        // JTable passes a negative index (no selection) through unchanged
        int selCol = convertColumnIndexToModel(getSelectedColumn());
        String title = df.getColumnName(selCol);
        double[] x = log10Plus1(df.getColumnAsDouble(selCol), logit);
        HistogramDataset ds = new HistogramDataset();
        ds.addSeries(title, x, nBins);
        JFreeChart chart = ChartFactory.createHistogram(title, xLabel, "Count", ds,
                PlotOrientation.VERTICAL, false, false, false);
        return new ChartPanel(chart);
    }

    /**
     * Builds a cumulative-count line plot. For each row returned by
     * {@code DataFrame.getAsDouble(selectedColumns)} (presumably one row per gene,
     * given the "Number of Genes" axis label — confirm), the maximum value across
     * the selected columns is taken. These maxima are counted into {@code nBins}
     * equal-width bins spanning the overall minimum..maximum of all values. The
     * running total is plotted at each bin centre that lies below {@code xAxisMax}.
     *
     * @param xLabel   x axis label and series name
     * @param nBins    number of bins (at least 1)
     * @param logit    if true, use log10(value + 1) instead of the raw values
     * @param xAxisMax only bin centres strictly below this value are plotted
     * @return the chart panel
     * @throws IllegalStateException    if no column is selected
     * @throws IllegalArgumentException if {@code nBins < 1}
     */
    public ChartPanel getCummDistPanel(String xLabel, int nBins, boolean logit, double xAxisMax) {
        int[] selCols = selectedModelColumns();
        if (selCols.length == 0) {
            throw new IllegalStateException("No columns selected");
        }
        if (nBins < 1) {
            throw new IllegalArgumentException("nBins must be at least 1: " + nBins);
        }
        double[][] rows = dataFrame().getAsDouble(selCols);

        double xMin = Double.POSITIVE_INFINITY;
        double xMax = Double.NEGATIVE_INFINITY;
        double[] rowMax = new double[rows.length];
        for (int i = 0; i < rows.length; ++i) {
            rowMax[i] = Double.NEGATIVE_INFINITY;
            for (double raw : rows[i]) {
                // transform locally so the DataFrame's arrays are never modified
                double v = logit ? Math.log10(raw + 1.0) : raw;
                if (v > rowMax[i]) {
                    rowMax[i] = v;
                }
                if (v > xMax) {
                    xMax = v;
                }
                if (v < xMin) {
                    xMin = v;
                }
            }
        }

        double width = (xMax - xMin) / nBins;
        int[] binCounts = new int[nBins];
        for (double m : rowMax) {
            // all values identical (width 0) -> everything falls in the first bin
            int bin = width > 0 ? (int) ((m - xMin) / width) : 0;
            bin = Math.max(0, Math.min(bin, nBins - 1));
            binCounts[bin]++;
        }

        // turn per-bin counts into running totals
        for (int b = 1; b < nBins; ++b) {
            binCounts[b] += binCounts[b - 1];
        }

        XYSeries series = new XYSeries(xLabel);
        for (int b = 0; b < nBins; ++b) {
            double centre = xMin + (b + 0.5) * width;
            if (centre < xAxisMax) {
                series.add(centre, binCounts[b]);
            }
        }
        XYSeriesCollection collect = new XYSeriesCollection(series);

        JFreeChart chart = ChartFactory.createXYLineChart("All Individual Samples", xLabel,
                "Number of Genes", collect, PlotOrientation.VERTICAL, false, false, false);
        return new ChartPanel(chart);
    }

    /** Returns the table model cast to a DataFrame. */
    private DataFrame dataFrame() {
        return (DataFrame) getModel();
    }

    /**
     * Returns the selected columns converted from view to model indices, in
     * left-to-right view order, so charts read the right data even after the
     * user has reordered columns.
     */
    private int[] selectedModelColumns() {
        int[] view = getSelectedColumns();
        int[] model = new int[view.length];
        for (int i = 0; i < view.length; ++i) {
            model[i] = convertColumnIndexToModel(view[i]);
        }
        return model;
    }

    /**
     * Returns {@code values} unchanged when {@code logit} is false; otherwise
     * returns a new array of log10(v + 1). A new array is used so the source
     * array is untouched even if it is the DataFrame's backing storage.
     */
    private static double[] log10Plus1(double[] values, boolean logit) {
        if (!logit) {
            return values;
        }
        double[] out = new double[values.length];
        for (int i = 0; i < values.length; ++i) {
            out[i] = Math.log10(values[i] + 1.0);
        }
        return out;
    }

    /**
     * Builds a Weka kernel estimator from {@code values}, each with weight 1,
     * using ln(1 + v) instead of v when {@code logit} is true.
     */
    private static KernelEstimator buildEstimator(double[] values, boolean logit) {
        KernelEstimator est = new KernelEstimator(KDE_PRECISION);
        for (double v : values) {
            est.addValue(logit ? Math.log(1.0 + v) : v, 1.0);
        }
        return est;
    }

    /**
     * Returns the estimator's last kernel mean, or +infinity when it has no
     * kernels. Assumed to be the largest mean (Weka is expected to keep kernel
     * means sorted — confirm).
     */
    private static double largestKernelMean(KernelEstimator est) {
        int nkern = est.getNumKernels();
        return nkern == 0 ? Double.POSITIVE_INFINITY : est.getMeans()[nkern - 1];
    }

    /** Returns a series of (kernel mean, getProbability(kernel mean)) points for {@code est}. */
    private static XYSeries densitySeries(String key, KernelEstimator est) {
        XYSeries series = new XYSeries(key);
        int nkern = est.getNumKernels();
        double[] mu = est.getMeans();
        for (int i = 0; i < nkern; ++i) {
            series.add(mu[i], est.getProbability(mu[i]));
        }
        return series;
    }
}
