/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.LMS.dataframe.investigator;

import java.awt.event.ActionListener;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import org.apache.commons.math3.linear.OpenMapRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.jdom2.Element;
import org.rhwlab.LMS.dataframe.DataSet;
import org.rhwlab.LMS.dataframe.investigator.algorithm.Silhouette;
import org.rhwlab.LMS.dataframe.investigator.measure.NamedEuclidean;

/**
 * An assignment of data records to clusters, possibly a soft (probabilistic) one.
 *
 * <p>The assignment is stored as a matrix whose rows are records and whose columns are
 * clusters. Entry {@code (r, c)} is the probability that record {@code r} belongs to cluster
 * {@code c}; hard clusterings store 1.0 in the assigned column. Alongside the matrix the class
 * keeps a display order of the clusters, a probability threshold used when extracting hard
 * clusters, and any silhouettes computed for the clustering.
 *
 * @author gevirl
 */
public class Clustering {

    /**
     * Membership matrix: rows are records, columns are clusters, entries are probabilities.
     * Rows are intended to sum to one; a record listed in no cluster has an all-zero row.
     */
    RealMatrix mat;
    /** Display name of this clustering, also returned by {@link #toString()}. */
    String title;
    /** Display order of the clusters: {@code order[k]} is the matrix column shown k-th. */
    int[] order;
    /** Minimum probability a record needs in its best cluster to count as clustered. */
    double probThresh = Double.MIN_VALUE;
    /** Silhouettes of this clustering; {@code null} until one is loaded or formed. */
    List<Silhouette> silhs;

    /**
     * Creates a clustering from an explicit membership matrix.
     *
     * @param title display name of the clustering
     * @param m     membership matrix (rows = records, columns = clusters); it is not copied
     */
    public Clustering(String title, RealMatrix m) {
        this.mat = m;
        this.title = title;
        initOrder();
    }

    /**
     * Creates a hard clustering from per-cluster lists of record numbers.
     *
     * @param title   display name of the clustering
     * @param indexes {@code indexes[c]} holds the record numbers in cluster {@code c}; the
     *                lists may differ in length (the array is not square)
     */
    public Clustering(String title, int[][] indexes) {
        this.title = title;

        // the matrix needs one row per record up to the largest record number seen
        int max = 0;
        for (int[] rows : indexes) {
            for (int row : rows) {
                max = Math.max(max, row);
            }
        }
        mat = new OpenMapRealMatrix(max + 1, indexes.length);

        for (int c = 0; c < indexes.length; ++c) {
            for (int row : indexes[c]) {
                mat.setEntry(row, c, 1.0);
            }
        }
        initOrder();
    }

    /**
     * Restores a clustering from the XML produced by {@link #toXML()}.
     *
     * <p>Each {@code Cluster} child has a {@code label} attribute, which is its matrix column.
     * Its text is a comma-separated list of {@code record[:probability]} tokens, where a
     * missing probability means 1.0. Empty clusters and whitespace around tokens are accepted.
     * If the {@code records} attribute is absent, the row count is inferred from the largest
     * record number. {@code Silhouette} children are restored as well.
     *
     * @param ele a {@code Clustering} element
     * @throws NumberFormatException if a label, record number, probability or record count
     *                               cannot be parsed
     */
    public Clustering(Element ele) {
        this.title = ele.getAttributeValue("title");
        List<Element> clElements = ele.getChildren("Cluster");

        int nRecords;
        String recordsAttr = ele.getAttributeValue("records");
        if (recordsAttr == null) {
            // infer the number of records from the largest record number present
            int maxRecord = 0;
            for (Element clEle : clElements) {
                for (String token : splitTokens(clEle)) {
                    int rec = Integer.parseInt(token.split(":")[0].trim());
                    maxRecord = Math.max(maxRecord, rec);
                }
            }
            nRecords = maxRecord + 1;
        } else {
            nRecords = Integer.parseInt(recordsAttr.trim());
        }
        mat = new OpenMapRealMatrix(nRecords, clElements.size());

        for (Element clEle : clElements) {
            int label = Integer.parseInt(clEle.getAttributeValue("label"));
            for (String token : splitTokens(clEle)) {
                String[] vals = token.split(":");
                int rec = Integer.parseInt(vals[0].trim());
                double p = 1.0;
                if (vals.length == 2) {
                    p = Double.parseDouble(vals[1].trim());
                }
                // zero entries are left unset to keep the matrix sparse
                if (p > 0.0) {
                    mat.setEntry(rec, label, p);
                }
            }
        }
        initOrder();

        silhs = new ArrayList<>();
        for (Element resEle : ele.getChildren("Silhouette")) {
            silhs.add(new Silhouette(resEle));
        }
    }

    /**
     * Creates a hard clustering from one cluster label per record.
     *
     * @param title         display name of the clustering
     * @param clusterLabels {@code clusterLabels[r]} is the (non-negative) cluster of record
     *                      {@code r}; the number of clusters is the largest label plus one
     */
    public Clustering(String title, int[] clusterLabels) {
        this.title = title;

        int maxLabel = 0;
        for (int label : clusterLabels) {
            maxLabel = Math.max(maxLabel, label);
        }

        mat = new OpenMapRealMatrix(clusterLabels.length, maxLabel + 1);
        for (int r = 0; r < clusterLabels.length; ++r) {
            mat.setEntry(r, clusterLabels[r], 1.0);
        }
        initOrder();
    }

    // Split a Cluster element's normalized text into its non-empty, trimmed record tokens.
    // An empty cluster is serialized as empty text and must yield no tokens.
    private static List<String> splitTokens(Element clEle) {
        List<String> tokens = new ArrayList<>();
        for (String token : clEle.getTextNormalize().split(",")) {
            String trimmed = token.trim();
            if (!trimmed.isEmpty()) {
                tokens.add(trimmed);
            }
        }
        return tokens;
    }

    // Reset the display order to the identity (column k is shown k-th).
    private void initOrder() {
        order = new int[mat.getColumnDimension()];
        for (int i = 0; i < order.length; ++i) {
            order[i] = i;
        }
    }

    @Override
    public String toString() {
        return title;
    }

    /**
     * Serializes this clustering to a {@code Clustering} element that
     * {@link #Clustering(Element)} can read back.
     *
     * <p>Clusters are written in the current display order. Each is labelled with its matrix
     * column and lists its non-zero entries as {@code record:probability} with three decimals.
     * Numbers are formatted with {@link Locale#ROOT}, because a locale whose decimal separator
     * is a comma would collide with the token separator. The row count is stored in the
     * {@code records} attribute, so records without any membership survive a round trip.
     *
     * @return the XML element
     */
    public Element toXML() {
        Element ret = new Element("Clustering");
        ret.setAttribute("title", title);
        ret.setAttribute("records", Integer.toString(mat.getRowDimension()));

        for (int j = 0; j < mat.getColumnDimension(); ++j) {
            int cl = order[j];
            Element clEle = new Element("Cluster");
            clEle.setAttribute("label", Integer.toString(cl));
            StringBuilder builder = new StringBuilder();
            for (int i = 0; i < mat.getRowDimension(); ++i) {
                double v = mat.getEntry(i, cl);
                if (v != 0.0) {
                    if (builder.length() > 0) {
                        builder.append(',');
                    }
                    builder.append(i).append(':').append(String.format(Locale.ROOT, "%.3f", v));
                }
            }
            clEle.addContent(builder.toString());
            ret.addContent(clEle);
        }
        if (silhs != null) {
            for (Silhouette silh : silhs) {
                ret.addContent(silh.toXML());
            }
        }
        return ret;
    }

    /**
     * Writes the membership probabilities as CSV, one line per gene.
     *
     * <p>Columns are the gene name, the 1-based best cluster (0 if the gene has no positive
     * probability), that cluster's probability, and the probability of every cluster in
     * matrix-column order. The probability threshold is not applied. Numbers use
     * {@link Locale#ROOT} so decimals never contain the CSV separator.
     *
     * @param file  destination file (created or truncated)
     * @param genes gene name of each record; {@code genes[r]} describes matrix row {@code r}
     * @throws Exception if the file cannot be opened or written
     */
    public void saveAsCSV(File file, String[] genes) throws Exception {
        int nClusters = mat.getColumnDimension();
        try (PrintStream stream = new PrintStream(file)) {
            stream.print("Gene,BestCluster,BestProbability");
            for (int c = 0; c < nClusters; ++c) {
                stream.printf(Locale.ROOT, ",Cluster%d", c + 1);
            }
            stream.println();

            for (int r = 0; r < genes.length; ++r) {
                int bestC = getMaxColumn(r);
                double bestP = bestC < 0 ? 0.0 : mat.getEntry(r, bestC);
                stream.printf(Locale.ROOT, "%s,%d,%.4f", genes[r], bestC + 1, bestP);

                for (int c = 0; c < nClusters; ++c) {
                    stream.printf(Locale.ROOT, ",%.4f", mat.getEntry(r, c));
                }
                stream.println();
            }

            // PrintStream swallows I/O errors; surface them instead of leaving a truncated file
            if (stream.checkError()) {
                throw new IOException("error writing CSV file " + file);
            }
        }
    }

    /**
     * Finds the cluster a record is assigned to under the current probability threshold.
     *
     * @param r record (matrix row) number
     * @return the matrix column with the highest probability that is at least the threshold
     *         (ties go to the lowest column), or {@code mat.getColumnDimension()} if no column
     *         qualifies, meaning the record is unclustered
     */
    public int bestClusterForRow(int r) {
        double bestP = 0.0;  // best probability seen so far
        int bestC = -1;      // column holding bestP, -1 while none qualified
        for (int c = 0; c < mat.getColumnDimension(); ++c) {
            double p = mat.getEntry(r, c);
            if (p >= probThresh && p > bestP) {
                bestP = p;
                bestC = c;
            }
        }
        return bestC < 0 ? mat.getColumnDimension() : bestC;
    }

    /**
     * Returns the clusters in display order, with an extra final group of unclustered records.
     *
     * @return see {@link #getClusters(boolean)} with {@code uncluster = true}
     */
    public int[][] getClusters() {
        return getClusters(true);
    }

    /**
     * Assigns every record to its best cluster (see {@link #bestClusterForRow(int)}) and
     * returns the record numbers of each cluster in the current display order.
     *
     * @param uncluster whether to append a final group holding the records that no cluster
     *                  claims under the current threshold
     * @return one array per cluster (plus one if {@code uncluster}); element {@code k} holds the
     *         ascending record numbers of the k-th displayed cluster, and the lengths vary
     */
    public int[][] getClusters(boolean uncluster) {
        int nClusters = mat.getColumnDimension();
        List<List<Integer>> rowLists = new ArrayList<>(nClusters);
        for (int c = 0; c < nClusters; ++c) {
            rowLists.add(new ArrayList<Integer>());
        }
        List<Integer> unclustered = new ArrayList<>();

        for (int r = 0; r < mat.getRowDimension(); ++r) {
            int bestC = bestClusterForRow(r);
            if (bestC == nClusters) {
                unclustered.add(r);
            } else {
                rowLists.get(bestC).add(r);
            }
        }

        int[][] ret = new int[uncluster ? nClusters + 1 : nClusters][];
        for (int k = 0; k < nClusters; ++k) {
            ret[k] = toIntArray(rowLists.get(order[k]));
        }
        if (uncluster) {
            ret[nClusters] = toIntArray(unclustered);
        }
        return ret;
    }

    // Unbox a list of integers into a primitive array, preserving order.
    private static int[] toIntArray(List<Integer> values) {
        int[] ret = new int[values.size()];
        int i = 0;
        for (Integer v : values) {
            ret[i++] = v;
        }
        return ret;
    }

    // Column with the strictly largest positive probability for a row, ignoring the threshold;
    // ties go to the lowest column, and -1 is returned when no entry is positive.
    private int getMaxColumn(int row) {
        double maxP = 0.0;
        int ret = -1;
        for (int c = 0; c < mat.getColumnDimension(); ++c) {
            double p = mat.getEntry(row, c);
            if (maxP < p) {
                maxP = p;
                ret = c;
            }
        }
        return ret;
    }

    /**
     * Reorders the clusters so that clusters whose centroids are close in {@code ds} are
     * displayed next to each other.
     *
     * <p>Centroids are computed from the clusters under the current threshold (unclustered
     * records are ignored). They are merged into a tree with {@link #toTree}, and the leaf order
     * of that tree becomes the new display order.
     *
     * @param ds      data set whose rows correspond to the records of this clustering
     * @param measure distance used to compare cluster nodes
     */
    public void reorganizeClusters(DataSet ds, DistanceMeasure measure) {
        initOrder();  // identity order, so index k below refers to matrix column k
        int nClusters = mat.getColumnDimension();
        int[][] clusterIndexes = getClusters(false);
        double[][] centers = centroids(ds, clusterIndexes);

        List<ClusterNode> nodes = new ArrayList<>();
        for (int i = 0; i < nClusters; ++i) {
            nodes.add(new ClusterNode(centers[i], i));
        }
        ClusterNode root = toTree(nodes, measure);

        List<ClusterNode> leaves = root.getLeaves();
        order = new int[nClusters];
        for (int i = 0; i < order.length; ++i) {
            order[i] = leaves.get(i).leaf;
        }
    }

    /**
     * Computes the centroid (per-column mean of the data) of each cluster.
     *
     * @param ds             data set whose rows are the records
     * @param clusterIndexes record numbers per cluster, e.g. from {@link #getClusters(boolean)}
     * @return one centroid per entry of {@code clusterIndexes}, each with one value per data
     *         set column; an empty cluster yields NaN coordinates (0/0)
     */
    public double[][] centroids(DataSet ds, int[][] clusterIndexes) {
        double[][] x = ds.getAsDouble();
        int n = ds.getColumns().length;  // dimension of each centroid
        double[][] ret = new double[clusterIndexes.length][];

        for (int cl = 0; cl < clusterIndexes.length; ++cl) {
            int[] idxs = clusterIndexes[cl];
            double[] centroid = new double[n];
            for (int row : idxs) {
                for (int col = 0; col < n; ++col) {
                    centroid[col] += x[row][col];
                }
            }
            for (int col = 0; col < n; ++col) {
                centroid[col] /= idxs.length;
            }
            ret[cl] = centroid;
        }
        return ret;
    }

    /**
     * Agglomerates nodes into a single tree by repeatedly merging the closest pair.
     *
     * <p>The given list is consumed: on return it holds only the root.
     *
     * @param nodes   nodes to merge; must be mutable and non-empty
     * @param measure distance passed to {@code ClusterNode.closestPair}
     * @return the root node
     * @throws IllegalArgumentException if {@code nodes} is empty
     */
    public ClusterNode toTree(List<ClusterNode> nodes, DistanceMeasure measure) {
        if (nodes.isEmpty()) {
            throw new IllegalArgumentException("cannot build a cluster tree from no nodes");
        }
        while (nodes.size() > 1) {
            int[] pair = ClusterNode.closestPair(nodes, measure);
            ClusterNode combined = new ClusterNode(nodes.get(pair[0]), nodes.get(pair[1]));

            // remove the higher index first so the lower index still refers to the right node,
            // whichever order closestPair reports the pair in
            nodes.remove(Math.max(pair[0], pair[1]));
            nodes.remove(Math.min(pair[0], pair[1]));

            nodes.add(combined);
        }
        return nodes.get(0);
    }

    /**
     * Sets the minimum probability a record needs in its best cluster to count as clustered.
     *
     * @param p threshold; values above 1 are clamped to 1, and non-positive values or NaN
     *          mean "any positive probability"
     */
    public void setProbabilityThreshold(double p) {
        if (!(p > 0.0)) {  // negated comparison also routes NaN here
            this.probThresh = Double.MIN_VALUE;
        } else {
            this.probThresh = Math.min(p, 1.0);
        }
    }

    /**
     * Returns the current probability threshold.
     *
     * @return the threshold, in {@code [Double.MIN_VALUE, 1]}
     */
    public double getProbabilityThreshold() {
        return this.probThresh;
    }

    /**
     * Starts computing a silhouette of this clustering on {@code ds} in a new thread and
     * records it in this clustering's silhouette list.
     *
     * @param ds     data set whose rows are the records
     * @param listen listener handed to the {@code Silhouette}
     */
    public void formSilhouette(DataSet ds, ActionListener listen) {
        Silhouette sil = new Silhouette(ds, this, listen);
        if (silhs == null) {
            silhs = new ArrayList<>();
        }
        // register before starting so the silhouette is already listed once it begins running
        silhs.add(sil);
        new Thread(sil).start();
    }

    /**
     * Returns the silhouettes of this clustering.
     *
     * @return the live list of silhouettes, or {@code null} if none were ever loaded or formed
     */
    public List<Silhouette> getSilhouettes() {
        return this.silhs;
    }
}
