/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.LMS.dataframe.coreg;

import htsjdk.samtools.util.Interval;
import htsjdk.samtools.util.IntervalTree;
import htsjdk.samtools.util.IntervalTreeMap;
import java.awt.Font;
import java.awt.geom.Ellipse2D;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.TreeMap;
import java.util.TreeSet;
import javax.swing.BoxLayout;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.JTabbedPane;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.commons.math3.stat.inference.TTest;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.CategoryAxis;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.labels.BoxAndWhiskerToolTipGenerator;
import org.jfree.chart.plot.CategoryPlot;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.category.BoxAndWhiskerRenderer;
import org.jfree.chart.renderer.xy.StandardXYBarPainter;
import org.jfree.chart.renderer.xy.XYBarRenderer;
import org.jfree.chart.renderer.xy.XYItemRenderer;
import org.jfree.data.statistics.BoxAndWhiskerCategoryDataset;
import org.jfree.data.statistics.DefaultBoxAndWhiskerCategoryDataset;
import org.jfree.data.statistics.HistogramDataset;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.rhwlab.LMS.dataframe.DataSet;
import org.rhwlab.LMS.dataframe.investigator.Investigation;
import org.rhwlab.gene.model.Annotation;
import org.rhwlab.gene.model.ModelFromGFF;

/**
 *
 * @author gevirl
 */
public class CoregulationFrame extends JFrame {

    /** Tab-separated operon table read by {@link #distanceMap(String)}. */
    private static final String OPERON_GENES_TSV = "/net/waterston/vol2/home/gevirl/OperonGenes.tsv";
    /** GFF3 gene model read by {@link #distanceMap(String)}. */
    private static final String GFF3_FILE = "/net/waterston/vol9/References/WS245/AllWormBase.withTransposon.gff3";
    /** Investigation XML that supplies the expression data sets. */
    private static final String INVESTIGATION_XML = "/net/waterston/vol2/home/gevirl/facsFuzzy.xml";
    /** CSV written by {@link #main(String[])}. */
    private static final String GENE_PAIR_CSV = "/net/waterston/vol2/home/gevirl/FACS/GenePairDistances.csv";

    /** Number of bins used for every histogram. */
    private static final int HISTOGRAM_BINS = 200;
    /** Number of distinct random same-chromosome pairs to generate. */
    private static final int RANDOM_PAIR_COUNT = 10000;
    /** Category key for random pairs; excluded from the genomic-distance whisker plot. */
    private static final String RANDOM_CATEGORY = "Random";

    /** Title prefix for histograms (the data set name in {@link #main(String[])}). */
    String label;
    /** Gene pairs grouped by category key (e.g. "HeadToTail", "Operons", "Random"). */
    TreeMap<String, TreeSet<GenePair>> distanceMap;

    /**
     * Creates an empty frame; call {@link #asScatter()}, {@link #asWhisker()} or
     * {@link #asHistograms()} to fill its content pane.
     *
     * @param label       text used in histogram titles
     * @param distanceMap gene pairs grouped by category
     */
    public CoregulationFrame(String label, TreeMap<String, TreeSet<GenePair>> distanceMap) {
        this.label = label;
        this.distanceMap = distanceMap;
    }

    /**
     * Fills the content pane with one scatter plot per category, stacked vertically.
     * Each point is (log10 genomic distance, expression distance). Pairs whose genomic
     * distance is not positive are omitted because log10 is undefined for them.
     */
    public void asScatter() {
        JPanel mainPanel = new JPanel();
        mainPanel.setLayout(new BoxLayout(mainPanel, BoxLayout.Y_AXIS));
        for (String category : distanceMap.keySet()) {
            XYSeriesCollection collect = new XYSeriesCollection();
            XYSeries series = new XYSeries(category);
            collect.addSeries(series);
            for (GenePair pair : distanceMap.get(category)) {
                if (pair.genomicDistance > 0) {
                    series.add(Math.log10(pair.genomicDistance), pair.expDistance);
                }
            }

            JFreeChart chart = ChartFactory.createScatterPlot(category, "Log Genomic Distance", "Expression Distance",
                    collect, PlotOrientation.VERTICAL, false, false, false);
            XYItemRenderer rend = chart.getXYPlot().getRenderer();
            // small 2x2 circles so dense clouds stay readable
            rend.setBaseShape(new Ellipse2D.Double(0.0, 0.0, 2.0, 2.0));
            rend.setSeriesShape(0, new Ellipse2D.Double(0.0, 0.0, 2.0, 2.0));
            mainPanel.add(new ChartPanel(chart));
        }
        this.setContentPane(mainPanel);
    }

    /**
     * Fills the content pane with two side-by-side box-and-whisker charts.
     * The left chart shows expression distance for every category. The right chart
     * shows genomic distance for every category except {@value #RANDOM_CATEGORY}.
     */
    public void asWhisker() {
        DefaultBoxAndWhiskerCategoryDataset expdataset = new DefaultBoxAndWhiskerCategoryDataset();
        DefaultBoxAndWhiskerCategoryDataset gendataset = new DefaultBoxAndWhiskerCategoryDataset();
        for (String category : distanceMap.keySet()) {
            List<Double> expValues = new ArrayList<>();
            List<Integer> genomicValues = new ArrayList<>();
            for (GenePair pair : distanceMap.get(category)) {
                expValues.add(pair.expDistance);
                genomicValues.add(pair.genomicDistance);
            }
            expdataset.add(expValues, "Expression", category);
            if (!category.equals(RANDOM_CATEGORY)) {
                gendataset.add(genomicValues, "Genomic", category);
            }
        }
        JFreeChart expchart = makeWhiskerChart("Expression Distance", expdataset, "Gene Orientation", "Normalized TPM");
        JFreeChart genchart = makeWhiskerChart("Genomic Distance", gendataset, "Gene Orientation", "Bases");

        JPanel main = new JPanel();
        main.setLayout(new BoxLayout(main, BoxLayout.X_AXIS));
        main.add(new ChartPanel(expchart));
        main.add(new ChartPanel(genchart));
        this.setContentPane(main);
    }

    /**
     * Builds an unfilled box-and-whisker chart with tooltips, no legend, and a y axis
     * that is not forced to include zero.
     */
    private JFreeChart makeWhiskerChart(String title, BoxAndWhiskerCategoryDataset dataset, String xAxisLabel,
            String yAxisLabel) {
        CategoryAxis xAxis = new CategoryAxis(xAxisLabel);
        NumberAxis yAxis = new NumberAxis(yAxisLabel);
        yAxis.setAutoRangeIncludesZero(false);
        BoxAndWhiskerRenderer renderer = new BoxAndWhiskerRenderer();
        renderer.setFillBox(false);
        renderer.setToolTipGenerator(new BoxAndWhiskerToolTipGenerator());
        CategoryPlot plot = new CategoryPlot(dataset, xAxis, yAxis, renderer);
        return new JFreeChart(title, new Font("SansSerif", Font.BOLD, 14), plot, false);
    }

    /**
     * Fills the content pane with a tabbed pane containing two histograms per category:
     * expression distance (tab {@code key}) and log10 genomic distance for pairs with a
     * positive genomic distance (tab {@code key + "Genomic"}). A tab is skipped when it
     * would have no values, because JFreeChart's HistogramDataset rejects empty series.
     */
    public void asHistograms() {
        JTabbedPane tabbed = new JTabbedPane();
        for (String key : distanceMap.keySet()) {
            TreeSet<GenePair> pairs = distanceMap.get(key);

            double[] expValues = new double[pairs.size()];
            int n = 0;
            for (GenePair pair : pairs) {
                expValues[n++] = pair.expDistance;
            }
            if (expValues.length > 0) {
                tabbed.add(key, new ChartPanel(makeHistogramChart(key, expValues)));
            }

            List<Double> logGenomic = new ArrayList<>();
            for (GenePair pair : pairs) {
                if (pair.genomicDistance > 0) {
                    logGenomic.add(Math.log10(pair.genomicDistance));
                }
            }
            if (!logGenomic.isEmpty()) {
                tabbed.add(key + "Genomic", new ChartPanel(makeHistogramChart(key, toDouble(logGenomic))));
            }
        }
        setContentPane(tabbed);
    }

    /**
     * Builds a pannable, flat-bar histogram of {@code values} as series {@code key},
     * titled with this frame's label and the mean of the values.
     *
     * @param values non-empty sample values
     */
    private JFreeChart makeHistogramChart(String key, double[] values) {
        double sum = 0.0;
        for (double value : values) {
            sum += value;
        }
        double mean = sum / values.length;

        HistogramDataset histDS = new HistogramDataset();
        histDS.addSeries(key, values, HISTOGRAM_BINS);

        JFreeChart chart = ChartFactory.createHistogram(
                String.format("%s - mean=%.2f", label, mean),
                null,
                null,
                histDS,
                PlotOrientation.VERTICAL,
                true,
                true,
                false);
        XYPlot plot = (XYPlot) chart.getPlot();
        plot.setDomainPannable(true);
        plot.setRangePannable(true);
        plot.setForegroundAlpha(0.85f);
        NumberAxis yAxis = (NumberAxis) plot.getRangeAxis();
        yAxis.setStandardTickUnits(NumberAxis.createIntegerTickUnits());
        XYBarRenderer renderer = (XYBarRenderer) plot.getRenderer();
        renderer.setDrawBarOutline(false);
        // flat bars look best...
        renderer.setBarPainter(new StandardXYBarPainter());
        renderer.setShadowVisible(false);
        return chart;
    }

    /**
     * Builds gene pairs for the named data set and groups them by category.
     *
     * <p>For each protein-coding data row whose gene is not strictly nested in another
     * gene, the nearest neighbouring gene on each side (found in an interval tree from
     * which nested genes were removed) is paired with it. The pair is classified by
     * {@link #classifyGenePair}, and its expression (Euclidean) and genomic distances
     * are recorded. {@value #RANDOM_PAIR_COUNT} random same-chromosome pairs are added
     * under {@value #RANDOM_CATEGORY}.
     *
     * @param datasetName name of the data set inside the investigation XML
     * @return categories "HeadToTail", "HeadToHead", "TailToTail", "Operons" and "Random"
     * @throws Exception if an input file cannot be read or parsed
     */
    public static TreeMap<String, TreeSet<GenePair>> distanceMap(String datasetName) throws Exception {
        TreeMap<String, String> operonWBGenes = readOperonGenes(OPERON_GENES_TSV);

        ModelFromGFF gff3 = new ModelFromGFF(new File(GFF3_FILE));
        Investigation invest = new Investigation(new File(INVESTIGATION_XML));
        DataSet ds = invest.getDataSet(datasetName);

        TreeMap<String, IntervalTree> geneTrees = buildGeneTrees(gff3);

        // rows to analyse, keyed by WBGene id; genes nested inside a larger gene are skipped
        String[] rowNames = invest.getDataRowNames();
        TreeMap<String, Integer> wbGenes = new TreeMap<>();
        for (int row = 0; row < rowNames.length; ++row) {
            String wbGene = gff3.getWBGene(rowNames[row]);
            Annotation annot = gff3.getGeneAnnotation(wbGene);
            if (!containedGene(annot, geneTrees.get(annot.getChromosome()))) {
                wbGenes.put(wbGene, row);
            }
        }

        // nested genes must not be picked as neighbours
        removeContainedGenes(geneTrees);

        EuclideanDistance dist = new EuclideanDistance();
        TreeMap<String, TreeSet<GenePair>> distanceMap = new TreeMap<>();
        distanceMap.put("HeadToTail", new TreeSet<GenePair>());
        distanceMap.put("HeadToHead", new TreeSet<GenePair>());
        distanceMap.put("TailToTail", new TreeSet<GenePair>());
        distanceMap.put("Operons", new TreeSet<GenePair>());

        for (String wbGene : wbGenes.keySet()) {
            Annotation annot = gff3.getGeneAnnotation(wbGene);
            IntervalTree geneTree = geneTrees.get(annot.getChromosome());
            int focalRow = wbGenes.get(wbGene);

            Annotation next = nextGene(annot, geneTree);
            if (next != null) {
                recordAdjacentPair(annot, next, focalRow, wbGenes.get(next.getGeneID()), ds, operonWBGenes, dist,
                        distanceMap);
            }
            Annotation previous = previousGene(annot, geneTree);
            if (previous != null) {
                recordAdjacentPair(previous, annot, focalRow, wbGenes.get(previous.getGeneID()), ds, operonWBGenes,
                        dist, distanceMap);
            }
        }

        distanceMap.put(RANDOM_CATEGORY, randomGenePairs(rowNames, gff3, ds, dist, RANDOM_PAIR_COUNT));
        return distanceMap;
    }

    /**
     * Reads a tab-separated file into a map from column 0 to column 2.
     * The keys are looked up by gene ID in {@link #classifyGenePair}; two genes with
     * equal values are treated as sharing an operon. Lines with fewer than three
     * columns (e.g. a trailing blank line) are skipped.
     */
    private static TreeMap<String, String> readOperonGenes(String path) throws IOException {
        TreeMap<String, String> operonWBGenes = new TreeMap<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] tokens = line.split("\t");
                if (tokens.length >= 3) {
                    operonWBGenes.put(tokens[0], tokens[2]);
                }
            }
        }
        return operonWBGenes;
    }

    /**
     * Builds one interval tree per chromosome of {@code gff3}, holding the annotations
     * returned by {@code asTree("gene", chromo, "biotype=protein_coding")}.
     * Declared {@code throws Exception} because the GFF model calls it makes are
     * project APIs whose checked exceptions are not visible here.
     */
    private static TreeMap<String, IntervalTree> buildGeneTrees(ModelFromGFF gff3) throws Exception {
        TreeMap<String, IntervalTree> geneTrees = new TreeMap<>();
        for (String chromo : gff3.getChromosomes()) {
            IntervalTreeMap geneIntervalMap = gff3.asTree("gene", chromo, "biotype=protein_coding");
            IntervalTree geneIntervalTree = new IntervalTree();
            for (Object obj : geneIntervalMap.keySet()) {
                Interval geneInterval = (Interval) obj;
                geneIntervalTree.put(geneInterval.getStart(), geneInterval.getEnd(), geneIntervalMap.get(geneInterval));
            }
            geneTrees.put(chromo, geneIntervalTree);
        }
        return geneTrees;
    }

    /**
     * Removes from every tree each gene that {@link #containedGene} reports as strictly
     * enclosed by another gene in that tree. Candidates are collected first and removed
     * afterwards, so the tree is not modified while it is being iterated.
     */
    private static void removeContainedGenes(TreeMap<String, IntervalTree> geneTrees) {
        for (IntervalTree tree : geneTrees.values()) {
            List<Annotation> containedGenes = new ArrayList<>();
            Iterator<IntervalTree.Node> iter = tree.iterator();
            while (iter.hasNext()) {
                Annotation annot = (Annotation) iter.next().getValue();
                if (containedGene(annot, tree)) {
                    containedGenes.add(annot);
                }
            }
            for (Annotation annot : containedGenes) {
                tree.remove(annot.getStart(), annot.getEnd());
            }
        }
    }

    /**
     * Scans {@code geneTree} forward from position {@code annot.getStart() + 1} and
     * returns the first gene that ends after {@code annot} ends, or null if none does.
     * Nodes that fail the test are never returned.
     */
    private static Annotation nextGene(Annotation annot, IntervalTree geneTree) {
        Iterator<IntervalTree.Node> iter = geneTree.iterator(annot.getStart() + 1, annot.getStart() + 1);
        while (iter.hasNext()) {
            IntervalTree.Node node = iter.next();
            if (node.getEnd() > annot.getEnd()) {
                return (Annotation) node.getValue();
            }
        }
        return null;
    }

    /**
     * Scans {@code geneTree} backward from position {@code annot.getStart() - 1} and
     * returns the first gene that ends before {@code annot} ends, or null if none does.
     * Nodes that fail the test are never returned.
     */
    private static Annotation previousGene(Annotation annot, IntervalTree geneTree) {
        Iterator<IntervalTree.Node> iter = geneTree.reverseIterator(annot.getStart() - 1, annot.getStart() - 1);
        while (iter.hasNext()) {
            IntervalTree.Node node = iter.next();
            if (node.getEnd() < annot.getEnd()) {
                return (Annotation) node.getValue();
            }
        }
        return null;
    }

    /**
     * Classifies the adjacent pair (low, high) and adds it to its category.
     * Nothing is recorded when the neighbour has no analysed row or the pair cannot be
     * classified. The genomic distance is {@code high.getStart() - low.getEnd()}, which
     * is negative for overlapping genes.
     *
     * @param focalRow    data-set row of the gene whose neighbours are being scanned
     * @param neighborRow data-set row of the other gene, or null if it is not analysed
     */
    private static void recordAdjacentPair(Annotation low, Annotation high, int focalRow, Integer neighborRow,
            DataSet ds, TreeMap<String, String> operonWBGenes, EuclideanDistance dist,
            TreeMap<String, TreeSet<GenePair>> distanceMap) throws Exception {
        if (neighborRow == null) {
            return;
        }
        String category = classifyGenePair(low, high, operonWBGenes);
        if (category == null) {
            return;
        }
        double expDist = dist.compute(ds.getRow(focalRow), ds.getRow(neighborRow));
        int genomicDist = high.getStart() - low.getEnd();
        distanceMap.get(category).add(new GenePair(low, high, category, expDist, genomicDist));
    }

    /**
     * Draws random pairs of distinct rows on the same chromosome until {@code count}
     * distinct pairs (per GenePair ordering) are collected.
     * The first row is drawn uniformly. The second is drawn uniformly among the other
     * rows on that chromosome. A first row with no partner on its chromosome is redrawn
     * instead of being retried forever. Genomic distance is the gap between the genes,
     * or 0 when they overlap.
     * NOTE(review): the generator is unseeded, so results differ between runs.
     */
    private static TreeSet<GenePair> randomGenePairs(String[] rowNames, ModelFromGFF gff3, DataSet ds,
            EuclideanDistance dist, int count) throws Exception {
        // look up each row's annotation once and group row indices by chromosome
        Annotation[] rowAnnots = new Annotation[rowNames.length];
        TreeMap<String, List<Integer>> rowsByChromosome = new TreeMap<>();
        for (int row = 0; row < rowNames.length; ++row) {
            rowAnnots[row] = gff3.getGeneAnnotation(gff3.getWBGene(rowNames[row]));
            String chr = rowAnnots[row].getChromosome();
            List<Integer> rows = rowsByChromosome.get(chr);
            if (rows == null) {
                rows = new ArrayList<>();
                rowsByChromosome.put(chr, rows);
            }
            rows.add(row);
        }

        Random rnd = new Random();
        TreeSet<GenePair> randomGenes = new TreeSet<>();
        while (randomGenes.size() < count) {
            int i = rnd.nextInt(rowNames.length);
            List<Integer> sameChromosome = rowsByChromosome.get(rowAnnots[i].getChromosome());
            if (sameChromosome.size() < 2) {
                continue; // no partner row on this chromosome
            }
            int j;
            do {
                j = sameChromosome.get(rnd.nextInt(sameChromosome.size()));
            } while (j == i);

            Annotation annot1 = rowAnnots[i];
            Annotation annot2 = rowAnnots[j];
            double expDist = dist.compute(ds.getRow(i), ds.getRow(j));
            int gd = 0;
            if (annot1.getStart() > annot2.getEnd()) {
                gd = annot1.getStart() - annot2.getEnd();
            } else if (annot2.getStart() > annot1.getEnd()) {
                gd = annot2.getStart() - annot1.getEnd();
            }
            randomGenes.add(new GenePair(annot1, annot2, RANDOM_CATEGORY, expDist, gd));
        }
        return randomGenes;
    }

    /**
     * Reports whether some interval in {@code tree} that overlaps {@code annot} strictly
     * encloses it, i.e. starts before it starts and ends after it ends.
     */
    static boolean containedGene(Annotation annot, IntervalTree tree) {
        Iterator<IntervalTree.Node> overlapping = tree.overlappers(annot.getStart(), annot.getEnd());
        while (overlapping.hasNext()) {
            Annotation test = (Annotation) overlapping.next().getValue();
            if (test.getStart() < annot.getStart() && test.getEnd() > annot.getEnd()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Classifies a pair of neighbouring genes. Callers pass the lower-coordinate gene
     * as {@code down} and the higher-coordinate gene as {@code up}.
     *
     * @return "Operons" if both genes map to the same value in {@code operonMap};
     *         otherwise "HeadToTail" if they share a strand; otherwise "HeadToHead" when
     *         {@code down} is on the "-" strand and "TailToTail" when it is not;
     *         null if either gene is null
     */
    static String classifyGenePair(Annotation down, Annotation up, TreeMap<String, String> operonMap) {
        if (down == null || up == null) {
            return null;
        }
        String downOp = operonMap.get(down.getGeneID());
        String upOp = operonMap.get(up.getGeneID());
        if (downOp != null && downOp.equals(upOp)) {
            return "Operons";
        }
        if (down.getStrand().equals(up.getStrand())) {
            return "HeadToTail";
        }
        return down.getStrand().equals("-") ? "HeadToHead" : "TailToTail";
    }

    /** Unboxes a list of doubles into a new primitive array of the same length and order. */
    static double[] toDouble(List<Double> list) {
        double[] ret = new double[list.size()];
        for (int i = 0; i < list.size(); ++i) {
            ret[i] = list.get(i);
        }
        return ret;
    }

    /**
     * Writes one CSV line per pair with positive genomic distance:
     * low gene name, high gene name, category, expression distance (3 decimals), genomic distance.
     */
    static void reportDistanceMap(TreeMap<String, TreeSet<GenePair>> map, PrintStream stream) {
        for (String category : map.keySet()) {
            for (GenePair pair : map.get(category)) {
                if (pair.genomicDistance > 0) {
                    stream.printf("%s,%s,%s,%.3f,%d\n", pair.lowGene.getGeneName(), pair.highGene.getGeneName(),
                            pair.category, pair.expDistance, (int) pair.genomicDistance);
                }
            }
        }
    }

    /**
     * Builds the gene-pair map, writes it to CSV, prints pairwise t-test p-values of
     * expression distance between categories, and shows histogram, whisker and scatter frames.
     *
     * @param args optional; {@code args[0]} overrides the default data set name
     */
    public static void main(String[] args) throws Exception {
        String[] dss = {"Max1AvgRepGeneTPM", "BHTsneMax1AvgRepGeneTPM-1_200_1000", "GeneTPM", "AvgRepGeneTPM"};
        String ds = (args != null && args.length > 0) ? args[0] : dss[0];
        TreeMap<String, TreeSet<GenePair>> map = distanceMap(ds);

        try (PrintStream stream = new PrintStream(GENE_PAIR_CSV)) {
            reportDistanceMap(map, stream);
        }

        CoregulationFrame frame = new CoregulationFrame(ds, map);
        frame.asHistograms();

        CoregulationFrame whiskerframe = new CoregulationFrame(ds, map);
        whiskerframe.asWhisker();

        TTest ttest = new TTest();
        String[] category = map.keySet().toArray(new String[0]);
        for (int i = 0; i < category.length - 1; ++i) {
            for (int j = i + 1; j < category.length; ++j) {
                double[] xi = GenePair.getExpDistance(map.get(category[i]));
                double[] xj = GenePair.getExpDistance(map.get(category[j]));
                // TTest requires at least two observations per sample
                if (xi.length < 2 || xj.length < 2) {
                    System.out.printf("%s  -  %s vs %s : too few pairs for t-test  nPairs=%d,%d\n", ds,
                            category[i], category[j], xi.length, xj.length);
                    continue;
                }
                double pValue = ttest.tTest(xi, xj);
                System.out.printf("%s  -  %s vs %s : pvalue=%g  nPairs=%d,%d\n", ds, category[i], category[j],
                        pValue, xi.length, xj.length);
            }
        }

        frame.pack();
        frame.setSize(1200, 900);
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.setVisible(true);

        whiskerframe.pack();
        whiskerframe.setSize(300, 300);
        whiskerframe.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        whiskerframe.setVisible(true);

        CoregulationFrame scatterFrame = new CoregulationFrame(ds, map);
        scatterFrame.asScatter();
        scatterFrame.pack();
        scatterFrame.setSize(500, 500);
        scatterFrame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        scatterFrame.setVisible(true);
    }
}
