/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.LMS.dataframe;

import java.io.File;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.TreeMap;
import jsat.SimpleDataSet;
import jsat.linear.DenseVector;
import org.apache.commons.math3.stat.StatUtils;
import org.jdom2.Element;

/**
 * A titled subset of the columns of a {@link DataFrame} in which every selected column can
 * carry string-valued attributes.
 *
 * <p>Columns are addressed in two ways throughout this class: by their column number in the
 * underlying data frame (the values stored in {@code cols}) and by their position inside this
 * data set (an index into {@code cols}, running from 0 to the number of columns - 1).
 *
 * <p>The XML form written by {@link #toXML()} stores column numbers and attribute values as
 * comma-separated lists, so attribute values containing a comma do not survive a round trip.
 *
 * @author gevirl
 */
public class DataSet {

    String title;
    DataFrame df;
    int[] cols;  // the column numbers in the dataframe
    TreeMap<String, String[]> attributeMap = new TreeMap<>();  // attribute label -> array of column identifiers for each column

    /**
     * Builds a data set over {@code frame} from its XML description.
     *
     * @param ele element carrying the title, columns and attributes, as read by {@link #fromXML(Element)}
     * @param frame the data frame the column numbers refer to
     */
    public DataSet(Element ele, DataFrame frame) {
        this.df = frame;
        fromXML(ele);
    }

    /**
     * Builds a data set over the given columns of {@code frame}.
     *
     * @param title name of the data set
     * @param frame the data frame the column numbers refer to
     * @param columns column numbers in the data frame; the array is stored, not copied
     * @param columnAttributeKeys attribute labels to create, or {@code null} for none; each label
     *        starts with one unset ({@code null}) value per column
     */
    public DataSet(String title, DataFrame frame, int[] columns, String[] columnAttributeKeys) {
        this.title = title;
        this.df = frame;
        this.cols = columns;

        // initialize the attribute map for any keys
        if (columnAttributeKeys != null) {
            for (String key : columnAttributeKeys) {
                attributeMap.put(key, new String[cols.length]);
            }
        }
    }

    /**
     * Saves this data set to {@code <title>.df} inside {@code dir}, labelling rows with the
     * data frame's row name column.
     *
     * @param dir directory to write into
     * @throws Exception if the file cannot be written
     */
    public void saveAsDataFrame(File dir) throws Exception {
        saveAs(new File(dir, title + ".df"), null);
    }

    /**
     * Writes the data set as comma-separated text: a header line followed by one line per row
     * of the data frame. Each line holds the data set's columns first and the label columns last.
     *
     * @param file destination file
     * @param rowID data frame column numbers to append as label columns, or {@code null} to use
     *        the data frame's row name column
     * @throws Exception if the file cannot be opened
     */
    public void saveAs(File file, int[] rowID) throws Exception {
        int[] labelCols = rowID;
        if (labelCols == null) {
            // use the rowNameColumn from the dataframe
            labelCols = new int[]{this.df.getRowNameColumn()};
        }

        // try-with-resources guarantees the stream is closed even if a row lookup fails part way
        try (PrintStream stream = new PrintStream(file)) {
            // the header line is laid out exactly like a data line, using the header names
            printDelimitedLine(stream, df.getHeaders(), labelCols);

            for (int r = 0; r < df.getSize(); ++r) {
                printDelimitedLine(stream, df.getRow(r), labelCols);
            }
        }
    }

    // Prints one comma-separated line: the fields at this set's columns, then those at labelCols.
    private void printDelimitedLine(PrintStream stream, List<String> fields, int[] labelCols) {
        StringBuilder line = new StringBuilder();
        int written = 0;  // counts fields so the separator is right even when there are no data columns
        for (int c : cols) {
            if (written++ > 0) {
                line.append(',');
            }
            line.append(fields.get(c));
        }
        for (int c : labelCols) {
            if (written++ > 0) {
                line.append(',');
            }
            line.append(fields.get(c));
        }
        stream.println(line.toString());
    }

    /**
     * Builds a name for a column by joining its value for every attribute with {@code "_"},
     * in the sorted order of the attribute labels.
     *
     * @param c index of the column within this data set
     * @return the joined attribute values; empty when no attributes are defined
     */
    public String columnNameFromAttributes(int c) {
        StringBuilder builder = new StringBuilder();
        boolean first = true;
        for (String[] values : attributeMap.values()) {
            if (!first) {
                builder.append("_");
            }
            builder.append(values[c]);
            first = false;
        }
        return builder.toString();
    }

    /**
     * Converts the numeric data to a JSAT data set, turning each element of the outer array
     * returned by {@link #getAsDouble()} into one dense-vector data point.
     *
     * @return a {@code SimpleDataSet} holding the data points
     */
    public jsat.DataSet toJsatDataSet() {
        List<jsat.classifiers.DataPoint> points = new ArrayList<>();
        for (double[] values : this.getAsDouble()) {
            points.add(new jsat.classifiers.DataPoint(new DenseVector(values)));
        }
        return new SimpleDataSet(points);
    }

    /**
     * Replaces the title, columns and attributes of this data set with those described by
     * {@code ele}: the {@code title} and {@code columns} attributes and one {@code Attribute}
     * child (with {@code key} and {@code values}) per attribute label.
     *
     * @param ele the element to read
     * @throws NullPointerException if the {@code columns} attribute, or the {@code key} or
     *         {@code values} attribute of an {@code Attribute} child, is missing
     * @throws NumberFormatException if a column entry is not an integer
     */
    public final void fromXML(Element ele) {
        title = ele.getAttributeValue("title");

        String columnText = ele.getAttributeValue("columns");
        if (columnText.trim().isEmpty()) {
            // an empty list is what toXML writes for a data set without columns
            cols = new int[0];
        } else {
            String[] tokens = columnText.split(",");
            cols = new int[tokens.length];
            for (int i = 0; i < cols.length; ++i) {
                cols[i] = Integer.parseInt(tokens[i].trim());
            }
        }

        attributeMap = new TreeMap<>();
        for (Element attrEle : ele.getChildren("Attribute")) {
            String key = attrEle.getAttributeValue("key");
            String valueText = attrEle.getAttributeValue("values");
            String[] values;
            if (valueText.isEmpty() && cols.length == 0) {
                values = new String[0];
            } else {
                // limit -1 keeps trailing empty values so the array stays aligned with the columns
                values = valueText.split(",", -1);
            }
            attributeMap.put(key, values);
        }
    }

    /**
     * Serializes this data set to a {@code DataSet} element in the form read by
     * {@link #fromXML(Element)}. Unset ({@code null}) attribute values are written as the
     * text {@code null}.
     *
     * @return the new element
     */
    public Element toXML() {
        Element ret = new Element("DataSet");
        ret.setAttribute("title", title);

        StringBuilder builder = new StringBuilder();
        for (int i = 0; i < cols.length; ++i) {
            if (i > 0) {
                builder.append(",");
            }
            builder.append(cols[i]);
        }
        ret.setAttribute("columns", builder.toString());

        for (String key : attributeMap.keySet()) {
            Element attEle = new Element("Attribute");
            attEle.setAttribute("key", key);
            attEle.setAttribute("values", joinWithCommas(attributeMap.get(key)));
            ret.addContent(attEle);
        }
        return ret;
    }

    // Joins the values with commas; an empty array gives an empty string.
    private static String joinWithCommas(String[] values) {
        StringBuilder builder = new StringBuilder();
        for (int i = 0; i < values.length; ++i) {
            if (i > 0) {
                builder.append(",");
            }
            builder.append(values[i]);
        }
        return builder.toString();
    }

    /**
     * Sets the value of an attribute for one column.
     *
     * @param key attribute label
     * @param dataSetColumn index into the array of columns, not the column in the dataframe,
     *        so it goes from 0 to the number of columns - 1
     * @param attributeValue the value to store
     * @return {@code true} if the value was stored; {@code false} if the label is unknown or
     *         the index is out of range
     */
    public boolean setAttribute(String key, int dataSetColumn, String attributeValue) {
        String[] attrs = attributeMap.get(key);

        // the attribute array can differ in length from cols (setColumns, fromXML), so check both
        if (attrs != null && dataSetColumn >= 0
                && dataSetColumn < cols.length && dataSetColumn < attrs.length) {
            attrs[dataSetColumn] = attributeValue;
            return true;
        }
        return false;
    }

    /**
     * Returns the value of an attribute for one column.
     *
     * @param key attribute label
     * @param dataSetColumn index of the column within this data set
     * @return the stored value, or {@code null} if the label is unknown, the index is out of
     *         range, or no value has been set
     */
    public String getAttribute(String key, int dataSetColumn) {
        String[] attrs = attributeMap.get(key);
        if (attrs == null || dataSetColumn < 0 || dataSetColumn >= attrs.length) {
            return null;
        }
        return attrs[dataSetColumn];
    }

    /**
     * Tells whether a data frame column belongs to this data set.
     *
     * @param column column number in the data frame
     * @return {@code true} if the column is one of this set's columns
     */
    public boolean columnPartOfSet(int column) {
        for (int c : cols) {
            if (c == column) {
                return true;
            }
        }
        return false;
    }

    /** Returns the underlying data frame. */
    public DataFrame getDataFrame() {
        return df;
    }

    /** Returns the data frame column numbers of this set; this is the internal array, not a copy. */
    public int[] getColumns() {
        return this.cols;
    }

    /**
     * Replaces the data frame column numbers of this set. Existing attribute arrays are not
     * resized.
     *
     * @param columns the new column numbers; the array is stored, not copied
     */
    public void setColumns(int[] columns) {
        this.cols = columns;
    }

    /** Returns the attribute labels, as a view of the attribute map's key set. */
    public Set<String> getAttributes() {
        return this.attributeMap.keySet();
    }

    /** Returns the data of this set's columns as produced by the data frame's {@code getAsDouble}. */
    public double[][] getAsDouble() {
        return df.getAsDouble(cols);
    }

    /**
     * Gets the data for a set of rows.
     *
     * @param rows row numbers in the data frame
     * @return one array per requested row, in the order given
     */
    public double[][] getDataForRows(int[] rows) {
        double[][] ret = new double[rows.length][];
        for (int i = 0; i < rows.length; ++i) {
            ret[i] = getRow(rows[i]);
        }
        return ret;
    }

    /**
     * Gets an entire row of data, parsed as doubles.
     *
     * @param row row number in the data frame
     * @return the values of this set's columns in that row
     * @throws NumberFormatException if a cell is not a parsable number
     */
    public double[] getRow(int row) {
        List<String> rowData = df.getRow(row);
        double[] ret = new double[cols.length];
        for (int i = 0; i < ret.length; ++i) {
            ret[i] = Double.parseDouble(rowData.get(cols[i]));
        }
        return ret;
    }

    /**
     * Gets a single data item.
     *
     * @param row row number in the data frame
     * @param i index of the column within this data set
     * @return the parsed value
     * @throws NumberFormatException if the cell is not a parsable number
     */
    public double getData(int row, int i) {
        List<String> rowData = df.getRow(row);
        return Double.parseDouble(rowData.get(cols[i]));
    }

    /** Returns the title of the data set. */
    @Override
    public String toString() {
        return title;
    }

    /** Returns the data frame header of each column in this set, in column order. */
    public String[] getColumnNames() {
        List<String> heads = this.df.getHeaders();
        String[] ret = new String[cols.length];
        for (int i = 0; i < ret.length; ++i) {
            ret[i] = heads.get(cols[i]);
        }
        return ret;
    }

    /**
     * Returns the data with the rows grouped by cluster: all rows of the first cluster, then
     * all rows of the second, and so on.
     *
     * @param clusters for each cluster, the data frame row numbers it contains
     * @return one array per listed row
     */
    public double[][] getClustered(int[][] clusters) {
        List<double[]> list = new ArrayList<>();
        for (int[] cluster : clusters) {
            for (int row : cluster) {
                list.add(getRow(row));
            }
        }
        return list.toArray(new double[list.size()][]);
    }

    /**
     * Returns the transpose of the data clustered by rows, so the outer index is the column
     * within this data set and the inner index is the position in the clustered row order.
     *
     * @param clusters for each cluster, the data frame row numbers it contains
     * @return one array per column; the arrays are empty when the clusters list no rows
     */
    public double[][] getTransposedClustered(int[][] clusters) {
        double[][] x = getClustered(clusters);

        // every row from getRow has cols.length entries; using that instead of x[0].length
        // keeps this working when the clusters contain no rows at all
        double[][] ret = new double[cols.length][x.length];

        for (int i = 0; i < x.length; ++i) {
            double[] v = x[i];
            for (int j = 0; j < v.length; ++j) {
                ret[j][i] = v[j];
            }
        }

        return ret;
    }

    /**
     * Characterizes a set of rows by the attribute values of the columns whose mean over those
     * rows is highest. A column qualifies when its mean is at least {@code thresh} standard
     * deviations above the mean of all column means. Qualifying columns are reported from the
     * highest mean down; columns with equal means are all reported, in column order.
     *
     * @param rows row numbers in the data frame
     * @param thresh number of standard deviations above the mean required to qualify
     * @param attr label of the attribute whose values are reported
     * @return the attribute values of the qualifying columns joined by {@code "/"}; empty when
     *         no column qualifies
     */
    public String characterizeByMeans(int[] rows, double thresh, String attr) {
        double[][] x = this.getDataForRows(rows);

        // compute the mean of each column
        double[] means = new double[cols.length];
        for (int c = 0; c < means.length; ++c) {
            double sum = 0.0;
            for (int r = 0; r < x.length; ++r) {
                sum = sum + x[r][c];
            }
            means[c] = sum / x.length;
        }

        // sort the columns by the means; each mean keeps a list so tied columns are not lost
        TreeMap<Double, List<Integer>> columnsByMean = new TreeMap<>();
        for (int c = 0; c < means.length; ++c) {
            List<Integer> sameMean = columnsByMean.get(means[c]);
            if (sameMean == null) {
                sameMean = new ArrayList<>();
                columnsByMean.put(means[c], sameMean);
            }
            sameMean.add(c);
        }

        double mu = StatUtils.mean(means);
        double sd = Math.sqrt(StatUtils.variance(means, mu));

        // find columns at least thresh*sd above the mu
        List<Integer> colsToUse = new ArrayList<>();
        for (Double mean : columnsByMean.descendingKeySet()) {
            if (mean - mu >= thresh * sd) {
                colsToUse.addAll(columnsByMean.get(mean));
            } else {
                // keys are visited from largest to smallest, so nothing further can qualify
                break;
            }
        }

        StringBuilder builder = new StringBuilder();
        for (int i = 0; i < colsToUse.size(); ++i) {
            if (i > 0) {
                builder.append("/");
            }
            builder.append(this.getAttribute(attr, colsToUse.get(i)));
        }

        return builder.toString();
    }
}
