/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.LMS.dataframe;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseEvent;
import java.io.File;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.swing.BorderFactory;
import javax.swing.Box;
import javax.swing.BoxLayout;
import javax.swing.DefaultListModel;
import javax.swing.JButton;
import javax.swing.JFileChooser;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JList;
import javax.swing.JMenu;
import javax.swing.JMenuBar;
import javax.swing.JMenuItem;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JTextField;
import javax.swing.event.ListSelectionEvent;
import javax.swing.event.ListSelectionListener;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.math.plot.Plot2DPanel;
import org.math.plot.plots.ColorBrewer;
import org.math.plot.plots.Plot;

/**
 * Interactive viewer that scatter-plots the rows of a {@link FileDataFrame}.
 *
 * <p>Layout:
 * <ul>
 *   <li>West: a list choosing the coordinate columns to plot ("T-sne" uses TsneX/TsneY,
 *       "T-sne(avg)" uses AvgTsneX/AvgTsneY, "PCA" uses PCA1/PCA2), and below it the
 *       list of clusters currently drawn ("label - rowCount").</li>
 *   <li>Center: a {@code Plot2DPanel} with one scatter plot per cluster.</li>
 *   <li>East: a gene display panel that listens to the scatter plots.</li>
 *   <li>North: a gene-name search box and the total number of plotted rows.</li>
 *   <li>South: a {@code ClusteringPanel} bound to the current data frame.</li>
 * </ul>
 * The File menu can open, save, and save-as the data frame. It can also show a
 * density-peak decision plot built from the TsneX/TsneY columns.
 *
 * @author gevirl
 */
public class TsneFrame extends JFrame {
    /** File most recently opened or saved through the File menu; null until then. */
    File dataFile;
    /** Data frame being displayed; replaced by File &gt; Open. */
    FileDataFrame df;

    // NOTE(review): episilon, minPts, nDensityClusters and nKMeansClusters are not read
    // within this class; presumably used by package collaborators — verify before removing.
    JTextField episilon = new JTextField("1.8");
    JTextField minPts = new JTextField("50");
    JLabel nDensityClusters = new JLabel("");
    /** Shows the total number of rows plotted across all clusters. */
    JLabel nGenes = new JLabel("");
    JTextField searchField = new JTextField("");
    JTextField nKMeansClusters = new JTextField("");
    /** Chooses which pair of coordinate columns is plotted (see {@link #axes}). */
    JList axisList;
    /** One entry per plotted cluster, formatted "label - rowCount". */
    JList clusterList;
    JScrollPane clusterScroll;
    JButton searchButton = new JButton("Search");
    MouseNotifyingCanvas canvas;
    ListDisplayPanel geneDisplayPanel;
    /** Column index (in {@link #df}) of the gene-name column named {@link #geneColName}. */
    int geneCol;
    String geneColName;
    /** Column whose unique values define the clusters currently drawn; null means one "All" group. */
    String clusterColumnName;
    /** Column indices of the plotted x and y coordinates. */
    int xCol;
    int yCol;
    int densityCol;
    int kmeansCol;
    /** Column indices handed to the gene display panel. */
    int rawData[];

    String[] axes = {"T-sne", "T-sne(avg)", "PCA"};

    /** Created lazily and reused so the dialog remembers its last directory. */
    JFileChooser fileChooser;
    ClusteringPanel clusteringPanel;

    /**
     * Builds the frame and draws the initial plot with no cluster column selected.
     *
     * @param dframe        data frame to display
     * @param rawData       column indices passed to the gene display panel
     * @param geneHead      name of the column holding gene names
     * @param tsneCol       name of the column whose index is stored in {@code densityCol}
     * @param kmeansColName name of the column whose index is stored in {@code kmeansCol}
     */
    public TsneFrame(FileDataFrame dframe, int[] rawData, String geneHead, String tsneCol, String kmeansColName) {
        this.df = dframe;
        this.rawData = rawData;
        this.geneColName = geneHead;
        geneCol = df.column(geneHead);
        this.xCol = df.column("TsneX");
        this.yCol = df.column("TsneY");
        this.densityCol = df.column(tsneCol);
        this.kmeansCol = df.column(kmeansColName);
        this.getContentPane().setLayout(new BorderLayout());

        buildMenu();

        // West panel: axis chooser above the cluster list.
        JPanel west = new JPanel();
        west.setLayout(new BoxLayout(west, BoxLayout.Y_AXIS));
        DefaultListModel<String> axisModel = new DefaultListModel<>();
        for (String axis : axes) {
            axisModel.addElement(axis);
        }
        axisList = new JList<String>(axisModel);
        // Selected before the listener is attached; matches the TsneX/TsneY columns chosen above.
        axisList.setSelectedIndex(0);
        axisList.setBorder(BorderFactory.createLineBorder(Color.BLACK, 4));
        axisList.addListSelectionListener(e -> {
            // Skip intermediate events and an emptied selection (e.g. a ctrl-click deselect).
            if (!e.getValueIsAdjusting() && applyAxisSelection()) {
                updateCanvas(clusterColumnName);
            }
        });

        clusterList = new JList<String>();
        clusterScroll = new JScrollPane(clusterList);
        clusterList.addListSelectionListener(e -> {
            if (!e.getValueIsAdjusting()) {
                clusterSelected();
            }
        });

        west.add(Box.createVerticalStrut(25));
        west.add(axisList);
        west.add(Box.createVerticalStrut(25));
        west.add(clusterScroll);
        west.add(Box.createVerticalStrut(25));
        this.add(west, BorderLayout.WEST);

        clusteringPanel = new ClusteringPanel(df, this);
        clusteringPanel.setDataFrame(df);
        this.add(clusteringPanel, BorderLayout.SOUTH);

        // Center panel: the scatter-plot canvas.
        canvas = new MouseNotifyingCanvas();
        Plot2DPanel plot = new Plot2DPanel(canvas);
        plot.plotCanvas.setNotable(true);
        plot.plotCanvas.setNoteCoords(false);
        plot.addLegend("SOUTH");
        this.add(plot, BorderLayout.CENTER);

        // East panel: gene display. Built before updateCanvas, which registers it on each plot.
        String[] genes = df.getColumn(geneCol);
        geneDisplayPanel = new ListDisplayPanel(genes, df, geneColName, rawData);
        this.add(geneDisplayPanel, BorderLayout.EAST);

        // North panel: gene search plus the plotted-row total.
        JPanel search = new JPanel();
        search.setLayout(new BoxLayout(search, BoxLayout.X_AXIS));
        search.add(Box.createHorizontalStrut(200));
        search.add(new JLabel("Enter a gene name:  "));
        search.add(Box.createHorizontalGlue());
        search.add(searchField);
        search.add(Box.createHorizontalGlue());
        search.add(searchButton);
        search.add(Box.createHorizontalStrut(200));
        search.add(nGenes);
        search.add(Box.createHorizontalStrut(20));
        this.add(search, BorderLayout.NORTH);

        searchButton.addActionListener(e -> searchGene());

        updateCanvas(clusterColumnName);
    }

    /** Builds the File menu: Open, Save, SaveAs and Decision Plot. */
    final void buildMenu() {
        JMenuBar menuBar = new JMenuBar();
        this.setJMenuBar(menuBar);

        JMenu fileMenu = new JMenu("File");
        menuBar.add(fileMenu);

        JMenuItem openItem = new JMenuItem("Open");
        fileMenu.add(openItem);
        openItem.addActionListener(e -> openDataFile());

        JMenuItem saveItem = new JMenuItem("Save");
        fileMenu.add(saveItem);
        // Save in place when a file is known; otherwise ask for one.
        saveItem.addActionListener(e -> {
            if (dataFile == null) {
                saveDataFileAs();
            } else {
                saveDataFile(dataFile);
            }
        });

        JMenuItem saveAsItem = new JMenuItem("SaveAs");
        fileMenu.add(saveAsItem);
        saveAsItem.addActionListener(e -> saveDataFileAs());

        JMenuItem dpcPlot = new JMenuItem("Decision Plot");
        fileMenu.add(dpcPlot);
        dpcPlot.addActionListener(e -> {
            int[] columns = {df.column("TsneX"), df.column("TsneY")};
            double[][] x = df.getAsDouble(columns);
            DensityPeakCluster dcp = new DensityPeakCluster(df);
            dcp.calcRhoDelta(x, new EuclideanDistance());
            dcp.decisionPlot();
        });
    }

    /** Returns the shared file chooser, creating it on first use (starting in dataFile's folder if known). */
    private JFileChooser chooser() {
        if (fileChooser == null) {
            fileChooser = (dataFile == null) ? new JFileChooser() : new JFileChooser(dataFile.getParentFile());
        }
        return fileChooser;
    }

    /**
     * Prompts for a file, loads it as the new data frame and re-resolves the column indices
     * this frame keeps. The canvas is not redrawn here.
     */
    private void openDataFile() {
        JFileChooser chooser = chooser();
        if (chooser.showOpenDialog(this) != JFileChooser.APPROVE_OPTION) {
            return;
        }
        File selected = chooser.getSelectedFile();
        try {
            df = new FileDataFrame(selected);
            dataFile = selected;
            // Indices into the previous frame are meaningless for the new one.
            geneCol = df.column(geneColName);
            applyAxisSelection();
            clusteringPanel.setDataFrame(df);
            revalidate();
            repaint();
        } catch (Exception exc) {
            reportError("open " + selected, exc);
        }
    }

    /** Prompts for a destination and saves the data frame there. */
    private void saveDataFileAs() {
        JFileChooser chooser = chooser();
        if (chooser.showSaveDialog(this) == JFileChooser.APPROVE_OPTION) {
            saveDataFile(chooser.getSelectedFile());
        }
    }

    /** Saves the data frame to {@code target} and remembers it for later Save commands. */
    private void saveDataFile(File target) {
        try {
            df.save(target);
            dataFile = target;
        } catch (Exception exc) {
            reportError("save " + target, exc);
        }
    }

    /** Logs a failed file action and shows the user an error dialog. */
    private void reportError(String action, Exception exc) {
        String message = "Failed to " + action;
        Logger.getLogger(TsneFrame.class.getName()).log(Level.SEVERE, message, exc);
        JOptionPane.showMessageDialog(this, message + ":\n" + exc, "Error", JOptionPane.ERROR_MESSAGE);
    }

    /**
     * Points {@link #xCol}/{@link #yCol} at the columns for the axis type selected in
     * {@link #axisList}.
     *
     * @return false (leaving the columns unchanged) when nothing is selected
     */
    private boolean applyAxisSelection() {
        Object selected = axisList.getSelectedValue();
        if (selected == null) {
            return false;
        }
        if ("T-sne".equals(selected)) {
            xCol = df.column("TsneX");
            yCol = df.column("TsneY");
        } else if ("PCA".equals(selected)) {
            xCol = df.column("PCA1");
            yCol = df.column("PCA2");
        } else {
            xCol = df.column("AvgTsneX");
            yCol = df.column("AvgTsneY");
        }
        return true;
    }

    /**
     * Looks up the typed gene. When found, places the crosshair on its point and simulates a
     * left click there.
     */
    private void searchGene() {
        Integer row = df.find(geneColName, searchField.getText().trim());
        if (row == null) {
            return;
        }
        List<String> rowData = df.getRow(row);
        double[] xy = {Double.parseDouble(rowData.get(xCol)), Double.parseDouble(rowData.get(yCol))};
        canvas.setCrossHair(xy);
        simulateClick(canvas.project(xy));
    }

    /**
     * Simulates a left click on a point of the plot for the cluster chosen in the cluster list.
     * Prefers a point where only that plot is hit.
     */
    private void clusterSelected() {
        String entry = (String) clusterList.getSelectedValue();
        if (entry == null) {
            return;
        }
        // Entries are "label - count" (see updateCanvas); split on the LAST separator so
        // labels containing spaces or dashes survive intact.
        int sep = entry.lastIndexOf(" - ");
        String label = (sep >= 0) ? entry.substring(0, sep) : entry;
        for (Plot plot : canvas.getPlots()) {
            if (!label.equals(plot.getName())) {
                continue;
            }
            int[] proj = null;
            for (double[] point : plot.getData()) {
                proj = canvas.project(point);
                if (canvas.selectedPlots(proj).size() == 1) {
                    break;
                }
            }
            // A plot with no points yields nothing to click.
            if (proj != null) {
                simulateClick(proj);
            }
        }
    }

    /** Feeds the canvas a synthetic left-button press, release and click at the given screen position. */
    private void simulateClick(int[] proj) {
        canvas.mousePressed(leftButtonEvent(MouseEvent.MOUSE_PRESSED, proj));
        canvas.mouseReleased(leftButtonEvent(MouseEvent.MOUSE_RELEASED, proj));
        canvas.mouseClicked(leftButtonEvent(MouseEvent.MOUSE_CLICKED, proj));
    }

    /** Builds a single-click mouse event carrying the legacy BUTTON1_MASK modifier, as the original code did. */
    private MouseEvent leftButtonEvent(int id, int[] proj) {
        return new MouseEvent(this, id, System.currentTimeMillis(), MouseEvent.BUTTON1_MASK,
                proj[0], proj[1], 1, false);
    }

    /**
     * Redraws the canvas with one scatter plot per unique value of a cluster column.
     *
     * <p>Also refills the cluster list with "label - rowCount" entries and updates the total row
     * count label. Labels that parse as integers sort numerically and come before all other
     * labels, which sort lexicographically.
     *
     * @param columnName column whose unique values define the clusters, or null for a single
     *                   "All" group (passed to {@code df.getRowsMatching} with column index -1 —
     *                   assumes the data frame treats -1 as "all rows"; verify)
     */
    public void updateCanvas(String columnName) {
        clusterColumnName = columnName;
        canvas.removeAllPlots();

        Set<String> clusterLabels;
        if (columnName != null) {
            clusterLabels = df.columnUniqueValues(columnName);
        } else {
            clusterLabels = new TreeSet<>();
            clusterLabels.add("All");
        }
        List<String> sortedLabels = new ArrayList<>(clusterLabels);
        sortedLabels.sort(TsneFrame::compareClusterLabels);

        ColorBrewer[] brewers = ColorBrewer.getQualitativeColorPalettes(true);
        Color[] colors = brewers[0].getColorPalette(sortedLabels.size());

        int col = (columnName != null) ? df.column(columnName) : -1;
        int totalRows = 0;
        DefaultListModel<String> model = new DefaultListModel<>();
        for (int c = 0; c < sortedLabels.size(); ++c) {
            String label = sortedLabels.get(c);
            int[] rows = df.getRowsMatching(col, label);
            model.addElement(String.format("%s - %d", label, rows.length));

            String[] labels = new String[rows.length];
            double[][] xy = new double[rows.length][];
            for (int j = 0; j < rows.length; ++j) {
                List<String> rowData = df.getRow(rows[j]);
                xy[j] = new double[] {Double.parseDouble(rowData.get(xCol)), Double.parseDouble(rowData.get(yCol))};
                labels[j] = rowData.get(geneCol);
            }
            totalRows += rows.length;

            // Wrap defensively in case the palette returns fewer colors than requested.
            GeneScatterPlot scatter = new GeneScatterPlot(label, colors[c % colors.length], xy, labels);
            scatter.addChangeListener(geneDisplayPanel);
            canvas.addPlot(scatter);
        }
        this.clusterList.setModel(model);
        nGenes.setText(String.format("Total Genes in Clusters=%d", totalRows));
    }

    /**
     * Orders cluster labels: integer labels first in numeric order, then all other labels
     * lexicographically.
     */
    private static int compareClusterLabels(String a, String b) {
        Integer ia = parseIntOrNull(a);
        Integer ib = parseIntOrNull(b);
        if (ia != null && ib != null) {
            return ia.compareTo(ib);
        }
        if (ia != null) {
            return -1;
        }
        if (ib != null) {
            return 1;
        }
        return a.compareTo(b);
    }

    /** Parses {@code s} as an Integer, returning null when it is not an integer. */
    private static Integer parseIntOrNull(String s) {
        try {
            return Integer.valueOf(s);
        } catch (NumberFormatException notAnInteger) {
            return null;
        }
    }

    /** Launches the viewer on a hard-coded data file (developer entry point). */
    public static void main(String[] args) throws Exception {
//        FileDataFrame df = new FileDataFrame(new File("/nfs/waterston/tsne-filtered.out"));
        FileDataFrame df = new FileDataFrame(new File("/nfs/waterston/FACS.mRNA.filtered.genes.pca.tsne.csv"));
        // Pass the contiguous block of sample columns between these two headers.
        int firstCol = df.column("RiboZeroBitSeq1.16.0_ceh32rep2T0_1_1_Merged1");
        int lastCol = df.column("RiboZeroBitSeq1.16.0_tbx37repT4_1_2_Merged1");
        int[] cols = new int[lastCol - firstCol + 1];
        for (int i = 0; i < cols.length; ++i) {
            cols[i] = firstCol + i;
        }
        TsneFrame tsne = new TsneFrame(df, cols, "CommonName", "DensityCluster", "KMeansCluster");

        tsne.pack();
        tsne.setSize(900, 900);
        tsne.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        tsne.setVisible(true);
    }
}
