/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.LMS.dataframe;

import java.awt.Color;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.function.Predicate;
import javax.swing.JFrame;
import javax.swing.table.AbstractTableModel;
import org.apache.commons.math3.ml.clustering.Clusterable;
import org.math.plot.FrameView;
import org.math.plot.Plot2DPanel;

/**
 * A rectangular table of string cells (rows and columns) that doubles as a Swing table model.
 *
 * <p>Cells are stored as strings; {@link #getValueAt(int, int)} presents cells that parse as
 * numbers as {@link Double}. Columns can be indexed by value ({@link #find}, {@link #findAll})
 * to look up the row numbers holding a value; those indexes are built lazily and discarded
 * whenever rows, cell values or header names change.
 *
 * <p>Optionally one column (see {@link #setRowNameColumn(int)}) identifies each row; its
 * values should be unique.
 *
 * <p>NOTE: {@link #addHeader(String)} silently ignores a header name that already exists, so
 * adding a data column under a duplicate name leaves headers and data columns out of step.
 *
 * @author gevirl
 */
public class DataFrame extends AbstractTableModel {
    protected ArrayList<ArrayList<String>> frame = new ArrayList<>();  // row-major cell storage
    protected ArrayList<String> headers = new ArrayList<>();           // one name per column
    // column name -> (cell value -> row numbers holding that value); cleared on mutation
    protected TreeMap<String, TreeMap<String, List<Integer>>> indexes = new TreeMap<>();
    protected Integer rowNameCol = null;  // column identifying the rows, or null if none
    protected String name;

    /**
     * Transposes the data frame in place.
     *
     * <p>Each original row becomes a column. If a row-name column is set, its values become the
     * new headers and that column is dropped from the data; otherwise the new headers are
     * {@code C0, C1, ...}. A final column named {@code RowID} holds the original headers
     * (minus the row-name column's header) and becomes the new row-name column. All value
     * indexes are discarded and table listeners are told the structure changed.
     */
    public void transpose() {
        DataFrame xpose = new DataFrame();
        for (int i = 0; i < headers.size(); ++i) {
            xpose.frame.add(new ArrayList<>());
        }
        // the original headers become the row identifiers of the transposed frame
        List<String> rowIds = new ArrayList<>(headers);
        if (rowNameCol == null) {
            for (int i = 0; i < frame.size(); ++i) {
                xpose.addColumn(getRow(i), String.format("C%s", i));
            }
        } else {
            for (ArrayList<String> row : frame) {
                xpose.addColumn(row, row.get(rowNameCol));
            }
            // the row names are now headers: drop them from the data AND from the id column,
            // otherwise RowID would be shifted against the data rows
            xpose.removeRow(rowNameCol);
            rowIds.remove(rowNameCol.intValue());
        }
        xpose.addColumn(rowIds, "RowID");

        indexes.clear();
        rowNameCol = frame.size();  // RowID follows one column per original row
        this.frame = xpose.frame;
        this.headers = xpose.headers;
        this.fireTableStructureChanged();
    }

    @Override
    public String toString() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getName() {
        return this.name;
    }

    /**
     * Sets the column whose values identify the rows.
     *
     * @param col column index
     */
    public void setRowNameColumn(int col) {
        this.rowNameCol = col;
    }

    /**
     * Returns the column whose values identify the rows.
     *
     * @return the row-name column index, or -1 if none has been set
     */
    public int getRowNameColumn() {
        return rowNameCol == null ? -1 : rowNameCol;
    }

    @Override
    public int getRowCount() {
        return frame.size();
    }

    @Override
    public int getColumnCount() {
        return headers.size();
    }

    /**
     * Returns a cell as a {@link Double} when it parses as a number, otherwise as the raw string.
     *
     * @return the parsed number, the original string, or null for a null cell
     */
    @Override
    public Object getValueAt(int rowIndex, int columnIndex) {
        String v = frame.get(rowIndex).get(columnIndex);
        if (v == null) {
            return null;
        }
        try {
            return Double.valueOf(v);
        } catch (NumberFormatException notNumeric) {
            return v;
        }
    }

    /**
     * Reports the class of the column's first cell, or {@code Object} when the frame is empty
     * or that cell is null.
     */
    @Override
    public Class<?> getColumnClass(int c) {
        if (frame.isEmpty()) {
            return Object.class;
        }
        Object first = getValueAt(0, c);
        return first == null ? Object.class : first.getClass();
    }

    @Override
    public String getColumnName(int col) {
        return headers.get(col);
    }

    /**
     * Averages groups of columns row by row.
     *
     * @param columns groups of column indices; output column {@code c} is the mean of group
     *     {@code columns[c]}
     * @return one array per row, each of length {@code columns.length}
     * @throws NumberFormatException if a referenced cell is not numeric
     */
    public double[][] getAvergaged(int[][] columns) {
        double[][] ret = new double[frame.size()][];
        int r = 0;
        for (ArrayList<String> rowData : frame) {
            ret[r] = new double[columns.length];
            for (int c = 0; c < columns.length; ++c) {
                double sum = 0.0;
                for (int col : columns[c]) {
                    sum += Double.parseDouble(rowData.get(col));
                }
                ret[r][c] = sum / columns[c].length;
            }
            ++r;  // advance to the next output row
        }
        return ret;
    }

    /**
     * Removes rows in which no value of the named column range (inclusive) reaches
     * {@code maxValue}.
     */
    public void removeRowsByMax(String first, String last, double maxValue) {
        removeRowsByMax(column(first), column(last), maxValue);
    }

    /** Removes rows in which no value of columns {@code first..last} reaches {@code maxValue}. */
    public void removeRowsByMax(int first, int last, double maxValue) {
        removeRowsByMax(columns(first, last), maxValue);
    }

    /**
     * Removes every row whose maximum over {@code columns} is strictly below {@code maxValue}.
     *
     * @throws NumberFormatException if a referenced cell is not numeric
     */
    public void removeRowsByMax(int[] columns, double maxValue) {
        Predicate<ArrayList<String>> belowThreshold = rowData -> {
            // NEGATIVE_INFINITY, not Double.MIN_VALUE (the smallest *positive* double), so
            // rows whose values are all negative are judged correctly
            double max = Double.NEGATIVE_INFINITY;
            for (int c : columns) {
                double v = Double.parseDouble(rowData.get(c));
                if (v > max) {
                    max = v;
                }
            }
            return max < maxValue;
        };
        if (frame.removeIf(belowThreshold)) {
            indexes.clear();  // row numbers shifted
        }
    }

    /** Removes row {@code i}; cached value indexes are discarded since row numbers shift. */
    public void removeRow(int i) {
        frame.remove(i);
        indexes.clear();
    }

    /** Sets every cell of {@code column} to {@code value}. */
    public void setColumn(String value, int column) {
        for (ArrayList<String> rowData : frame) {
            rowData.set(column, value);
        }
        indexes.clear();  // indexed values may have changed
    }

    /**
     * Finds the first row whose {@code columnName} cell equals {@code value}.
     *
     * @return the row number, or null if {@code value} is null or not present
     * @throws IllegalArgumentException if the column does not exist
     */
    public Integer find(String columnName, String value) {
        if (value == null) {
            return null;
        }
        List<Integer> ret = indexFor(columnName).get(value);
        return ret == null ? null : ret.get(0);
    }

    /**
     * Finds all rows whose {@code columnName} cell equals {@code value}.
     *
     * @return the row numbers in ascending order, or null if {@code value} is null or absent
     * @throws IllegalArgumentException if the column does not exist
     */
    public List<Integer> findAll(String columnName, String value) {
        if (value == null) {
            return null;
        }
        return indexFor(columnName).get(value);
    }

    /** Returns the cached index for a column, building it on first use. */
    private TreeMap<String, List<Integer>> indexFor(String columnName) {
        TreeMap<String, List<Integer>> index = indexes.get(columnName);
        return index != null ? index : indexColumn(columnName);
    }

    /**
     * (Re)builds and caches the value index of a column. Null cells are not indexed.
     *
     * @return map from cell value to the ascending row numbers holding it
     * @throws IllegalArgumentException if the column does not exist
     */
    public TreeMap<String, List<Integer>> indexColumn(String columnName) {
        int c = column(columnName);
        if (c < 0) {
            throw new IllegalArgumentException("No such column: " + columnName);
        }
        // always start fresh so repeated calls do not duplicate row numbers
        TreeMap<String, List<Integer>> index = new TreeMap<>();
        for (int r = 0; r < frame.size(); ++r) {
            String value = frame.get(r).get(c);
            if (value != null) {  // TreeMap cannot hold null keys
                index.computeIfAbsent(value, k -> new ArrayList<>()).add(r);
            }
        }
        indexes.put(columnName, index);
        return index;
    }

    /** Returns row {@code i} (the live list, not a copy). */
    public ArrayList<String> getRow(int i) {
        return frame.get(i);
    }

    /** Returns the given columns of row {@code r} parsed as doubles. */
    public double[] getRowAsDouble(int r, int[] cols) {
        ArrayList<String> rowData = getRow(r);
        double[] ret = new double[cols.length];
        for (int c = 0; c < cols.length; ++c) {
            ret[c] = Double.parseDouble(rowData.get(cols[c]));
        }
        return ret;
    }

    /** Returns the sorted distinct non-null values of the named column. */
    public Set<String> columnUniqueValues(String head) {
        return columnUniqueValues(column(head));
    }

    /** Returns the sorted distinct values of column {@code c}, skipping null and "null". */
    public Set<String> columnUniqueValues(int c) {
        TreeSet<String> ret = new TreeSet<>();
        for (ArrayList<String> rowData : frame) {
            String v = rowData.get(c);
            if (v != null && !v.equals("null")) {
                ret.add(v);
            }
        }
        return ret;
    }

    /** Returns the header list (the live list, not a copy). */
    public List<String> getHeaders() {
        return headers;
    }

    /** Returns the number of rows. */
    public int getSize() {
        return frame.size();
    }

    /** Appends a column holding {@code value} in every row. */
    public void addConstantColumn(String value, String head) {
        for (ArrayList<String> rowData : frame) {
            rowData.add(value);
        }
        addHeader(head);
    }

    /** Overwrites column {@code col} with {@code data} (one entry per row). */
    public void replaceColumn(String[] data, int col) {
        int i = 0;
        for (ArrayList<String> rowData : frame) {
            rowData.set(col, data[i]);
            ++i;
        }
        indexes.clear();  // indexed values may have changed
    }

    /**
     * Removes the named column, keeping the row-name column pointing at the same data.
     *
     * @return the removed column's former index, or -1 if no such column
     */
    public int removeColumn(String colName) {
        int c = column(colName);
        if (c != -1) {
            for (ArrayList<String> rowData : frame) {
                rowData.remove(c);
            }
            headers.remove(c);
            indexes.remove(colName);
            if (rowNameCol != null) {
                if (rowNameCol == c) {
                    rowNameCol = null;  // the identifying column itself is gone
                } else if (rowNameCol > c) {
                    rowNameCol = rowNameCol - 1;  // shifted left by the removal
                }
            }
        }
        return c;
    }

    /** Index that the next appended column will occupy. */
    private int nextColumnIndex() {
        return frame.isEmpty() ? headers.size() : frame.get(0).size();
    }

    /** Appends a column from a list; see {@link #addColumn(String[], String)}. */
    public int addColumn(List<String> rowAsList, String head) {
        return addColumn(rowAsList.toArray(new String[0]), head);
    }

    /**
     * Appends a column.
     *
     * @param data one value per row
     * @return the index of the new column
     */
    public int addColumn(String[] data, String head) {
        int ret = nextColumnIndex();
        int i = 0;
        for (ArrayList<String> rowData : frame) {
            rowData.add(data[i]);
            ++i;
        }
        addHeader(head);
        return ret;
    }

    /**
     * Appends a numeric column, formatting values with six decimals.
     *
     * @param data one value per row
     * @return the index of the new column
     */
    public int addColumn(double[] data, String head) {
        int ret = nextColumnIndex();
        int i = 0;
        for (ArrayList<String> rowData : frame) {
            // Locale.ROOT: a default locale using ',' as decimal separator would produce text
            // that Double.valueOf cannot parse and that breaks the CSV written by save()
            rowData.add(String.format(Locale.ROOT, "%.6f", data[i]));
            ++i;
        }
        addHeader(head);
        return ret;
    }

    /** Appends every numeric column of a JSAT data set; returns the new column indices. */
    public int[] addColumns(jsat.DataSet jsatDS, String[] heads) {
        int[] ret = new int[jsatDS.getNumNumericalVars()];
        for (int c = 0; c < jsatDS.getNumNumericalVars(); ++c) {
            ret[c] = this.addColumn(jsatDS.getNumericColumn(c).arrayCopy(), heads[c]);
        }
        return ret;
    }

    /**
     * Appends several numeric columns.
     *
     * @param data one array per row, holding that row's new values
     * @return the indices of the new columns, one per header
     */
    public int[] addColumns(double[][] data, String[] heads) {
        int base = nextColumnIndex();
        int[] ret = new int[heads.length];
        for (int j = 0; j < ret.length; ++j) {
            ret[j] = base + j;
        }
        int r = 0;
        for (ArrayList<String> rowData : frame) {
            for (double v : data[r]) {
                rowData.add(Double.toString(v));
            }
            ++r;
        }
        setHeaders(heads);
        return ret;
    }

    /** Replaces all headers; cached indexes (keyed by header name) are discarded. */
    public void replaceHeaders(String[] heads) {
        headers.clear();
        indexes.clear();
        for (String head : heads) {
            addHeader(head);
        }
    }

    /**
     * Appends a header name unless it already exists.
     *
     * @return the new header's index, or null if the name was already present
     */
    public Integer addHeader(String head) {
        if (headers.contains(head)) {
            return null;
        }
        headers.add(head);
        return headers.size() - 1;
    }

    /** Appends each header via {@link #addHeader(String)}; returns the per-header results. */
    public Integer[] setHeaders(String[] heads) {
        Integer[] ret = new Integer[heads.length];
        for (int i = 0; i < heads.length; ++i) {
            ret[i] = addHeader(heads[i]);
        }
        return ret;
    }

    /** Returns columns {@code first..last} (inclusive) of every row as doubles. */
    public double[][] getAsDouble(int first, int last) {
        return getAsDouble(columns(first, last));
    }

    /** Returns the consecutive indices {@code first..last} (inclusive). */
    static public int[] columns(int first, int last) {
        int[] columns = new int[last - first + 1];
        for (int i = 0; i < columns.length; ++i) {
            columns[i] = first + i;
        }
        return columns;
    }

    /** Returns the given columns of every row as doubles, one array per row. */
    public double[][] getAsDouble(int[] columns) {
        double[][] ret = new double[frame.size()][];
        for (int r = 0; r < ret.length; ++r) {
            ret[r] = getRowAsDouble(r, columns);
        }
        return ret;
    }

    /** Returns the index of the named column, or -1 if absent. */
    public int column(String head) {
        return headers.indexOf(head);
    }

    /** Returns a copy of column {@code i}. */
    public String[] getColumn(int i) {
        String[] ret = new String[frame.size()];
        int r = 0;
        for (ArrayList<String> rowData : frame) {
            ret[r] = rowData.get(i);
            ++r;
        }
        return ret;
    }

    /** Returns column {@code i} parsed as doubles. */
    public double[] getColumnAsDouble(int i) {
        double[] ret = new double[frame.size()];
        int r = 0;
        for (ArrayList<String> rowData : frame) {
            ret[r] = Double.parseDouble(rowData.get(i));
            ++r;
        }
        return ret;
    }

    /** Writes the whole frame as comma-separated text; see {@link #saveRows}. */
    public void save(PrintStream stream) throws Exception {
        int[] rows = new int[frame.size()];
        for (int r = 0; r < rows.length; ++r) {
            rows[r] = r;
        }
        saveRows(rows, stream);
    }

    /**
     * Writes the header line (if any headers exist) followed by the selected rows as
     * comma-separated text. NOTE: values are not quoted, so cells containing commas will not
     * round-trip.
     */
    public void saveRows(int[] rows, PrintStream stream) throws Exception {
        if (!headers.isEmpty()) {
            stream.print(String.join(",", headers));
            stream.println();
        }
        for (int row : rows) {
            stream.print(String.join(",", frame.get(row)));
            stream.println();
        }
    }

    /**
     * Returns the rows whose column {@code col} equals {@code value}; a negative {@code col}
     * matches every row. Null cells never match.
     */
    public int[] getRowsMatching(int col, String value) {
        ArrayList<Integer> matching = new ArrayList<>();
        for (int i = 0; i < frame.size(); ++i) {
            if (col < 0) {
                matching.add(i);
            } else {
                String data = frame.get(i).get(col);
                if (data != null && data.equals(value)) {
                    matching.add(i);
                }
            }
        }
        return matching.stream().mapToInt(Integer::intValue).toArray();
    }

    /**
     * Shows a scatter plot of two numeric columns. Closing the window exits the JVM; use
     * {@link #scatterPlot(int, int, int)} to choose another close operation.
     */
    public void scatterPlot(int xCol, int yCol) throws Exception {
        scatterPlot(xCol, yCol, JFrame.EXIT_ON_CLOSE);
    }

    /**
     * Shows a scatter plot of two numeric columns.
     *
     * @param closeOperation a {@link JFrame} default close operation, e.g.
     *     {@code JFrame.DISPOSE_ON_CLOSE} to keep the application running
     */
    public void scatterPlot(int xCol, int yCol, int closeOperation) throws Exception {
        Plot2DPanel plot = new Plot2DPanel();
        double[][] xy = getAsDouble(new int[] {xCol, yCol});  // one {x, y} pair per row
        plot.addScatterPlot("all", Color.black, xy);

        plot.plotCanvas.setNotable(true);
        plot.plotCanvas.setNoteCoords(true);
        plot.addLegend("SOUTH");

        FrameView plotframe = new FrameView(plot);
        plotframe.setDefaultCloseOperation(closeOperation);
        plotframe.setVisible(true);
    }

    /** A clustering point built from selected numeric columns of one row. */
    public class ClusterableFromColumns implements Clusterable {
        double[] x;  // the selected column values, parsed as doubles
        int row;     // the source row number

        public ClusterableFromColumns(int[] columns, ArrayList<String> rowData, int row) {
            this.row = row;
            x = new double[columns.length];
            for (int i = 0; i < x.length; ++i) {
                x[i] = Double.parseDouble(rowData.get(columns[i]));
            }
        }

        @Override
        public double[] getPoint() {
            return x;
        }

        /** Returns the row this point was built from. */
        public int getSourceRow() {
            return row;
        }
    }
}
