/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.LMS.dataframe;

import java.awt.Color;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.ItemEvent;
import java.awt.event.ItemListener;
import java.util.List;
import java.util.TreeMap;
import javax.swing.BorderFactory;
import javax.swing.Box;
import javax.swing.BoxLayout;
import javax.swing.JButton;
import javax.swing.JCheckBox;
import javax.swing.JComboBox;
import javax.swing.JLabel;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.JTextField;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.rhwlab.LMS.dataframe.AngleMeasure;
import org.rhwlab.LMS.dataframe.DBSCAN;
import org.rhwlab.LMS.dataframe.DataFrame;
import org.rhwlab.LMS.dataframe.DensityPeakCluster;
import org.rhwlab.LMS.dataframe.JensenShannonMetric;
import org.rhwlab.LMS.dataframe.Kmeans;
import org.rhwlab.LMS.dataframe.Spearman;

/**
 * Swing panel for defining and running one clustering of a {@link DataFrame}.
 * <p>
 * The user picks a data set (a fixed group of data-frame columns), an optional
 * per-row normalization, a clustering algorithm, a distance measure and the
 * algorithm-specific parameters. Pressing "Cluster" runs the algorithm, adds the
 * resulting cluster labels to the data frame as a new column (named by the hash
 * code of {@link #toString()}) and passes that description to
 * {@link ClusteringPanel#addCluster}.
 *
 * @author gevirl
 */
public class ClusteringDefinitionPanel extends JPanel implements ActionListener {
    ClusteringPanel parent;
    DataFrame df;
    JComboBox<String> normalizationBox;
    JComboBox<String> dataSetBox;
    JComboBox<String> algorithmBox;
    JComboBox<String> distanceBox;
    JTextField epsilonField = new JTextField();
    JTextField minPtsField = new JTextField();
    JTextField nClustersField = new JTextField();
    JTextField cutoffField = new JTextField();
    JTextField densityThreshField = new JTextField();
    JTextField distThreshField = new JTextField();
    JTextField minClusterSizeField = new JTextField();
    JTextField minDistanceField = new JTextField();

    JPanel kmeansPanel = new JPanel();
    JPanel dbscanPanel = new JPanel();
    JPanel dpcPanel = new JPanel();
    JPanel parametersPanel = new JPanel();
    JPanel qtPanel = new JPanel();

    JButton clusterButton = new JButton("Cluster");

    String[] dataSets = {"Raw Samples","Rep Averaged","T-sne","T-sne(Avg)","PCA"};
    String[] algorithms = {"Kmeans","DBSCAN","DPC","QT"};
    String[] distances = {"Euclidian","JS metric","Angle","Spearmans"};
    String[] normalizations = {"None","Sum=1","Length=1"};

    /**
     * Keys used by {@link #parse(String)} for the leading, un-keyed fields that
     * {@link #toString()} writes positionally.
     */
    private static final String[] HEADER_KEYS = {"DataSet", "Normalization", "Algorithm", "Distance"};

    /**
     * Builds the panel's controls.
     *
     * @param df     data frame the clustering reads from and adds label columns to
     * @param parent owning panel; its {@code refresh()} is called when the
     *               parameter sub-panel changes and {@code addCluster()} after
     *               each completed clustering
     */
    public ClusteringDefinitionPanel(DataFrame df, ClusteringPanel parent) {
        this.setBorder(BorderFactory.createLineBorder(Color.black, 3));
        this.parent = parent;
        this.df = df;
        this.setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
        dataSetBox = new JComboBox<>(dataSets);
        algorithmBox = new JComboBox<>(algorithms);
        distanceBox = new JComboBox<>(distances);
        normalizationBox = new JComboBox<>(normalizations);

        this.add(Box.createVerticalStrut(10));
        this.add(new JLabel("DataSet"));
        this.add(dataSetBox);
        this.add(Box.createVerticalStrut(10));

        this.add(new JLabel("Normalization"));
        this.add(normalizationBox);
        this.add(Box.createVerticalStrut(10));

        this.add(new JLabel("Algorithm"));
        this.add(algorithmBox);
        this.add(Box.createVerticalStrut(10));

        algorithmBox.addItemListener(new ItemListener() {
            @Override
            public void itemStateChanged(ItemEvent e) {
                // JComboBox fires DESELECTED for the old item and SELECTED for the
                // new one; only the newly selected algorithm is of interest.
                if (e.getStateChange() != ItemEvent.SELECTED) {
                    return;
                }
                showParameters(String.valueOf(e.getItem()));
            }
        });

        this.add(new JLabel("Distance"));
        this.add(distanceBox);
        this.add(Box.createVerticalStrut(10));

        dbscanPanel.setLayout(new BoxLayout(dbscanPanel, BoxLayout.Y_AXIS));
        dbscanPanel.add(Box.createVerticalStrut(10));
        dbscanPanel.add(new JLabel("Epsilon"));
        dbscanPanel.add(epsilonField);
        dbscanPanel.add(Box.createVerticalStrut(10));
        dbscanPanel.add(new JLabel("minPts"));
        dbscanPanel.add(minPtsField);
        dbscanPanel.add(Box.createVerticalStrut(10));

        kmeansPanel.setLayout(new BoxLayout(kmeansPanel, BoxLayout.Y_AXIS));
        kmeansPanel.add(Box.createVerticalStrut(10));
        kmeansPanel.add(new JLabel("# of Clusters"));
        kmeansPanel.add(nClustersField);
        kmeansPanel.add(Box.createVerticalStrut(10));

        dpcPanel.setLayout(new BoxLayout(dpcPanel, BoxLayout.Y_AXIS));
        dpcPanel.add(Box.createVerticalStrut(10));
        dpcPanel.add(new JLabel("Cutoff"));
        dpcPanel.add(cutoffField);
        dpcPanel.add(Box.createVerticalStrut(10));
        dpcPanel.add(new JLabel("DensityThreshold"));
        dpcPanel.add(densityThreshField);
        dpcPanel.add(Box.createVerticalStrut(10));
        dpcPanel.add(new JLabel("DistanceThreshold"));
        dpcPanel.add(distThreshField);
        dpcPanel.add(Box.createVerticalStrut(10));

        qtPanel.setLayout(new BoxLayout(qtPanel, BoxLayout.Y_AXIS));
        qtPanel.add(Box.createVerticalStrut(10));
        // The "MinimumClusterSize" input (minClusterSizeField) is currently not
        // shown, although toString() still reports its (normally empty) value.
        qtPanel.add(new JLabel("MinimumDistance"));
        qtPanel.add(minDistanceField);
        qtPanel.add(Box.createVerticalStrut(10));

        parametersPanel.add(kmeansPanel);

        this.add(parametersPanel);
        this.add(clusterButton);
        clusterButton.addActionListener(this);
    }

    /**
     * Replaces the contents of the parameter area with the sub-panel for the
     * given algorithm and asks the parent to refresh.
     */
    private void showParameters(String algorithm) {
        parametersPanel.removeAll();
        JPanel panel = parameterPanelFor(algorithm);
        if (panel != null) {
            parametersPanel.add(panel);
        }
        parametersPanel.revalidate();
        parametersPanel.repaint();
        this.invalidate();
        parent.refresh();
    }

    /** Returns the parameter sub-panel for an algorithm name, or null if unknown. */
    private JPanel parameterPanelFor(String algorithm) {
        switch (algorithm) {
            case "Kmeans":
                return kmeansPanel;
            case "DBSCAN":
                return dbscanPanel;
            case "DPC":
                return dpcPanel;
            case "QT":
                return qtPanel;
            default:
                return null;
        }
    }

    /**
     * Describes the current clustering definition.
     * <p>
     * Format: {@code "<data set> ; <normalization> ; <algorithm> ; <distance>"}
     * followed by {@code " ; Key: value"} pairs for the selected algorithm's
     * parameters. {@link #parse(String)} accepts this format. The hash code of this
     * string names the label column added to the data frame.
     */
    @Override
    public String toString() {
        String ret = String.format("%s ; %s ; %s ; %s",
                dataSetBox.getSelectedItem(), normalizationBox.getSelectedItem(),
                algorithmBox.getSelectedItem(), distanceBox.getSelectedItem());

        switch (selected(algorithmBox)) {
            case "DBSCAN":
                return String.format("%s ; Epsilon: %s ; minPts: %s",
                        ret, text(epsilonField), text(minPtsField));
            case "Kmeans":
                return String.format("%s ; # of Clusters: %s",
                        ret, text(nClustersField));
            case "DPC":
                // Fields are separated by ';' (an earlier version used ':' between
                // Cutoff and Rho, which made the Rho value unparseable).
                return String.format("%s ; Cutoff: %s ; Rho: %s ; Delta: %s",
                        ret, text(cutoffField), text(densityThreshField), text(distThreshField));
            case "QT":
                return String.format("%s ; MinCluster: %s ; MinDistance: %s ",
                        ret, text(minClusterSizeField), text(minDistanceField));
            default:
                return ret;
        }
    }

    /**
     * Restores the panel's selections and parameter fields from a description
     * string, e.g. one produced by {@link #toString()}, optionally prefixed by '#'.
     * <p>
     * Tokens are separated by ';'. A token without ':' is taken positionally as
     * DataSet, Normalization, Algorithm and Distance, in that order. Other tokens
     * are "Key: value" pairs; a token may hold several pairs (for example the legacy
     * DPC form "Cutoff: 1 : Rho: 2"). Keys that are absent leave the corresponding
     * control unchanged, and combo-box values not in the box's list are ignored.
     *
     * @param s description string; null or blank input is ignored
     */
    public void parse(String s) {
        if (s == null) {
            return;
        }
        s = s.trim();
        if (s.startsWith("#")) {
            s = s.substring(1);
        }

        TreeMap<String, String> map = new TreeMap<>();
        int positional = 0;
        for (String token : s.split(";")) {
            if (token.trim().isEmpty()) {
                continue;
            }
            if (token.indexOf(':') < 0) {
                if (positional < HEADER_KEYS.length) {
                    map.put(HEADER_KEYS[positional], token.trim());
                }
                ++positional;
                continue;
            }
            // -1 keeps a trailing empty value ("minPts:") as a pair.
            String[] parts = token.split(":", -1);
            for (int i = 0; i + 1 < parts.length; i += 2) {
                map.put(parts[i].trim(), parts[i + 1].trim());
            }
        }

        selectIfPresent(dataSetBox, map.get("DataSet"));
        selectIfPresent(normalizationBox, map.get("Normalization"));
        selectIfPresent(algorithmBox, map.get("Algorithm"));
        selectIfPresent(distanceBox, map.get("Distance"));

        setIfPresent(epsilonField, map.get("Epsilon"));
        setIfPresent(minPtsField, map.get("minPts"));
        setIfPresent(nClustersField, map.get("# of Clusters"));
        setIfPresent(cutoffField, map.get("Cutoff"));
        setIfPresent(densityThreshField, map.get("Rho"));
        setIfPresent(distThreshField, map.get("Delta"));
        setIfPresent(minClusterSizeField, map.get("MinCluster"));
        setIfPresent(minDistanceField, map.get("MinDistance"));
    }

    /**
     * Runs the clustering defined by the panel's current state.
     * <p>
     * On success, the cluster labels are added to the data frame as a column named
     * by the hash code of {@link #toString()}, and that string is passed to
     * {@code parent.addCluster}. In two cases an error dialog is shown instead and
     * nothing is added: a numeric parameter field does not parse, or the selected
     * data set or algorithm has no implementation here.
     */
    @Override
    public void actionPerformed(ActionEvent e) {
        int[] columns = selectedColumns();
        if (columns == null) {
            showError("Unknown data set: " + selected(dataSetBox));
            return;
        }

        double[][] x = df.getAsDouble(columns);
        String normalization = selected(normalizationBox);
        if ("Sum=1".equals(normalization)) {
            normalizeSum(x);
        } else if ("Length=1".equals(normalization)) {
            normalizeLength(x);
        }

        DistanceMeasure distMeasure = selectedDistance();

        String[] labels;
        try {
            labels = runClustering(x, distMeasure);
        } catch (NumberFormatException ex) {
            showError("Invalid clustering parameter: " + ex.getMessage());
            return;
        }
        if (labels == null) {
            showError("No clustering implementation for algorithm: " + selected(algorithmBox));
            return;
        }

        // add a new column of cluster labels to the df
        String s = this.toString();
        df.addColumn(labels, Integer.toString(s.hashCode()));
        parent.addCluster(s);
    }

    /**
     * Returns the data-frame column indices for the selected data set, or null if
     * the selection is not recognised.
     */
    private int[] selectedColumns() {
        switch (selected(dataSetBox)) {
            case "T-sne":
                return new int[]{df.column("TsneX"), df.column("TsneY")};
            case "Raw Samples":
                return columnRange("RiboZeroBitSeq1.16.0_ceh32rep2T0_1_1_Merged1",
                        "RiboZeroBitSeq1.16.0_tbx37repT4_1_2_Merged1");
            case "Rep Averaged":
                return columnRange("ceh32T0", "tbx37T4");
            case "PCA":
                return new int[]{df.column("PCA1"), df.column("PCA2")};
            case "T-sne(Avg)":
                return new int[]{df.column("AvgTsneX"), df.column("AvgTsneY")};
            default:
                return null;
        }
    }

    /** Returns the contiguous column indices from column {@code first} to {@code last}, inclusive. */
    private int[] columnRange(String first, String last) {
        int firstCol = df.column(first);
        int lastCol = df.column(last);
        int[] cols = new int[lastCol - firstCol + 1];
        for (int i = 0; i < cols.length; ++i) {
            cols[i] = firstCol + i;
        }
        return cols;
    }

    /** Returns a new distance measure for the selected distance, or null if unknown. */
    private DistanceMeasure selectedDistance() {
        switch (selected(distanceBox)) {
            case "Euclidian":
                return new EuclideanDistance();
            case "JS metric":
                return new JensenShannonMetric();
            case "Angle":
                return new AngleMeasure();
            case "Spearmans":
                return new Spearman();
            default:
                return null;
        }
    }

    /**
     * Runs the selected algorithm on {@code x} and returns its cluster labels.
     *
     * @return the labels, or null when the selected algorithm (e.g. "QT") has no
     *         implementation here
     * @throws NumberFormatException if a required parameter field is not a number
     */
    private String[] runClustering(double[][] x, DistanceMeasure distMeasure) {
        switch (selected(algorithmBox)) {
            case "DBSCAN": {
                double eps = Double.parseDouble(text(epsilonField));
                int minPts = Integer.parseInt(text(minPtsField));
                DBSCAN dbscan = new DBSCAN(df);
                dbscan.cluster(eps, minPts, x, distMeasure);
                return dbscan.getClusterLabels();
            }
            case "Kmeans": {
                int nClusters = Integer.parseInt(text(nClustersField));
                Kmeans kmeans = new Kmeans(df);
                kmeans.cluster(nClusters, x, distMeasure);
                return kmeans.getClusterLabels();
            }
            case "DPC": {
                double cutoff = Double.parseDouble(text(cutoffField));
                double rho = Double.parseDouble(text(densityThreshField));
                double delta = Double.parseDouble(text(distThreshField));
                DensityPeakCluster dpc = new DensityPeakCluster(df);
                dpc.cluster(cutoff, delta, rho, x, distMeasure);
                return dpc.getClusterLabels();
            }
            default:
                return null;
        }
    }

    /** Shows a modal error dialog relative to this panel. */
    private void showError(String message) {
        JOptionPane.showMessageDialog(this, message, "Clustering", JOptionPane.ERROR_MESSAGE);
    }

    /** Returns the box's selected item as a string ("null" if nothing is selected). */
    private static String selected(JComboBox<String> box) {
        return String.valueOf(box.getSelectedItem());
    }

    /** Returns the field's text with surrounding whitespace removed. */
    private static String text(JTextField field) {
        return field.getText().trim();
    }

    /** Selects {@code value} in {@code box} when it is non-null. */
    private static void selectIfPresent(JComboBox<String> box, String value) {
        if (value != null) {
            // A non-editable JComboBox ignores values that are not in its list.
            box.setSelectedItem(value);
        }
    }

    /** Sets the field's text to {@code value} when it is non-null. */
    private static void setIfPresent(JTextField field, String value) {
        if (value != null) {
            field.setText(value);
        }
    }

    /**
     * Replaces the data frame used for subsequent clusterings.
     *
     * @param df the new data frame
     */
    public void setDataframe(DataFrame df) {
        this.df = df;
    }

    /**
     * Scales each row of {@code x} in place so its entries sum to 1.
     * Rows whose sum is exactly 0 are left unchanged rather than divided by zero.
     *
     * @param x matrix to normalize, row by row
     */
    public static void normalizeSum(double[][] x) {
        for (double[] row : x) {
            double sum = 0.0;
            for (double v : row) {
                sum += v;
            }
            if (sum == 0.0) {
                continue; // dividing would produce NaN/Infinity
            }
            for (int c = 0; c < row.length; ++c) {
                row[c] = row[c] / sum;
            }
        }
    }

    /**
     * Scales each row of {@code x} in place to unit Euclidean length.
     * All-zero rows are left unchanged rather than divided by zero.
     *
     * @param x matrix to normalize, row by row
     */
    public static void normalizeLength(double[][] x) {
        for (double[] row : x) {
            double sumsqr = 0.0;
            for (double v : row) {
                sumsqr += v * v;
            }
            double length = Math.sqrt(sumsqr);
            if (length == 0.0) {
                continue; // dividing would produce NaN
            }
            for (int c = 0; c < row.length; ++c) {
                row[c] = row[c] / length;
            }
        }
    }
}
