/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.encode.ChipSeq.peaks;

import com.apporiented.algorithm.clustering.AverageLinkageStrategy;
import com.apporiented.algorithm.clustering.Cluster;
import com.apporiented.algorithm.clustering.ClusteringAlgorithm;
import com.apporiented.algorithm.clustering.DefaultClusteringAlgorithm;
import com.apporiented.algorithm.clustering.visualization.DendrogramPanel;
import java.awt.Color;
import java.awt.Dimension;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;
import javax.swing.JFrame;
import javax.swing.JScrollPane;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.rhwlab.heatmap.MappableHeatMap;
import org.rhwlab.heatmap.SaveableFrame;
import org.rhwlab.imaging.binarytree.Gradient;

/**
 * Renders transcription-factor (TF) enrichment results as a clustered heat map.
 *
 * <p>{@link #main} reads enrichment results with {@code ClusterEnrichment.readFiles},
 * keeps only the TFs that have at least one p-value below {@link #MIN_P}, and
 * hierarchically clusters the TFs. Clustering uses average linkage on the Euclidean
 * distance between the TFs' per-file {@code logP} vectors.
 *
 * <p>It then opens three windows: a colour-scale legend, the heat map with TFs in
 * dendrogram leaf order, and the dendrogram itself.
 *
 * @author gevirl
 */
public class EnrichmentHeatmap {

    /** Enrichment directory used when no directory is given as {@code args[0]}. */
    private static final String DEFAULT_DIR = "/net/waterston/vol2/home/gevirl/FACS/enrichment/gene";

    /**
     * Second argument to {@code ClusterEnrichment.readFiles}, used when none is given
     * as {@code args[1]}. Presumably this is a file-name filter; confirm against readFiles.
     */
    private static final String DEFAULT_FILE_PATTERN = "geneEnrich";

    /** A TF is kept only if at least one file reports {@code pVal < MIN_P}. */
    private static final double MIN_P = 0.01;

    /**
     * Series named along the x axis.
     *
     * <p>NOTE(review): this assumes the files, in sorted-name order, form these series
     * with {@link #TIMEPOINTS_PER_SERIES} consecutive files per series. Confirm.
     */
    private static final String[] SERIES_NAMES = {"ceh-32", "cnd-1", "end-1", "hlh-1", "nhr-25", "pha-4", "tbx-37"};

    /** Number of consecutive file columns labelled T0, T1, ... within each series. */
    private static final int TIMEPOINTS_PER_SERIES = 5;

    /**
     * Builds and displays the colour-scale, heat-map and dendrogram windows.
     *
     * @param args optional {@code [directory [filePattern]]}; they default to
     *             {@link #DEFAULT_DIR} and {@link #DEFAULT_FILE_PATTERN}
     * @throws IllegalStateException if no TF passes the p-value filter, or if the
     *         surviving TFs were not all reported in the same set of files
     * @throws Exception propagated from {@code ClusterEnrichment.readFiles}
     */
    public static void main(String[] args) throws Exception {
        String dir = args.length > 0 ? args[0] : DEFAULT_DIR;
        String filePattern = args.length > 1 ? args[1] : DEFAULT_FILE_PATTERN;
        double thresh = -Math.log10(MIN_P);

        TreeMap<String, List<ClusterEnrichment>> byFileName = ClusterEnrichment.readFiles(dir, filePattern);
        TreeMap<String, TreeMap<String, ClusterEnrichment>> byTF = groupByTf(byFileName);
        retainSignificant(byTF, MIN_P);
        if (byTF.isEmpty()) {
            throw new IllegalStateException("no TF has pVal < " + MIN_P + " in " + dir);
        }
        checkSameFiles(byTF);

        // One row per TF (sorted by TF name) and one column per file (sorted by file name).
        int n = byTF.size();
        String[] tfNames = byTF.keySet().toArray(new String[0]);
        double[][] m = new double[n][];
        int row = 0;
        for (TreeMap<String, ClusterEnrichment> byFile : byTF.values()) {
            m[row++] = logPRow(byFile);
        }

        ClusteringAlgorithm alg = new DefaultClusteringAlgorithm();
        Cluster cluster = alg.performClustering(distanceMatrix(m), tfNames, new AverageLinkageStrategy());
        List<String> leaves = cluster.getLeafNames();

        // Reorder the rows into dendrogram leaf order and track the value range.
        // The range starts at [0, 0], so it always includes 0.
        double largest = 0.0;
        double smallest = 0.0;
        double[][] z = new double[n][];
        for (int i = 0; i < leaves.size(); ++i) {
            z[i] = logPRow(byTF.get(leaves.get(i)));
            for (double v : z[i]) {
                if (v < smallest) {
                    smallest = v;
                }
                if (v > largest) {
                    largest = v;
                }
            }
        }

        // Colour-scale legend: label every entry whose integer value is a multiple of 10.
        double[][] x = scaleMatrix(largest, smallest);
        TreeMap<Double, String> scaleLabels = new TreeMap<>();
        for (int i = 0; i < x[0].length; ++i) {
            int v = (int) x[0][i];
            if (Math.abs(v) % 10 == 0) {
                scaleLabels.put((double) i, Integer.toString(v));
            }
        }

        // Each TF label sits at the centre of its row.
        TreeMap<Double, String> yLabels = new TreeMap<>();
        for (int i = 0; i < z.length; ++i) {
            yLabels.put(i + 0.5, leaves.get(i));
        }

        // Label each column with its time point, and centre each series name over its block.
        TreeMap<Double, String> xLabels = new TreeMap<>();
        TreeMap<Double, String> seriesLabels = new TreeMap<>();
        for (int i = 0; i < SERIES_NAMES.length; ++i) {
            int first = i * TIMEPOINTS_PER_SERIES;
            seriesLabels.put(first + TIMEPOINTS_PER_SERIES / 2.0, SERIES_NAMES[i]);
            for (int j = 0; j < TIMEPOINTS_PER_SERIES; ++j) {
                xLabels.put(first + j + 0.5, String.format("T%d", j));
            }
        }
        int labelledColumns = SERIES_NAMES.length * TIMEPOINTS_PER_SERIES;
        if (z[0].length != labelledColumns) {
            System.err.printf("warning: %d file columns but %d x-axis labels%n", z[0].length, labelledColumns);
        }

        // Transpose so the file index comes first and the TF index second. This matches
        // xLabels (files) and yLabels (TFs); presumably MappableHeatMap indexes data[x][y].
        z = new Array2DRowRealMatrix(z).transpose().getData();

        Color[][] gradients = buildGradients();
        double[] bounds = {-thresh, thresh};

        MappableHeatMap zheatMap = new MappableHeatMap(z, true, gradients, bounds);
        zheatMap.setDrawYTicks(true);
        zheatMap.setYLabels(yLabels);
        zheatMap.setDrawXTicks(true);
        zheatMap.setXLabels(xLabels);
        // NOTE(review): if setXLabels replaces rather than merges, the T0..Tn labels set on
        // the previous line are discarded and only the series names are drawn. Confirm.
        zheatMap.setXLabels(seriesLabels);
        zheatMap.setPreferredSize(new Dimension(600, 1500));

        MappableHeatMap xheatMap = new MappableHeatMap(x, true, gradients, bounds);
        xheatMap.setDrawYTicks(true);
        xheatMap.setYLabels(scaleLabels);
        SaveableFrame xheatChartFrame = new SaveableFrame();
        xheatChartFrame.setContentPane(xheatMap);
        xheatChartFrame.setSize(200, 700);
        xheatChartFrame.setVisible(true);

        SaveableFrame zheatChartFrame = new SaveableFrame();
        zheatChartFrame.setTitle("Both");
        zheatChartFrame.setContentPane(zheatMap);
        zheatChartFrame.pack();
        zheatChartFrame.setSize(1000, 1100);
        zheatChartFrame.setVisible(true);
        zheatChartFrame.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);

        DendrogramPanel dp = new DendrogramPanel();
        dp.setModel(cluster);
        dp.setPreferredSize(new Dimension(1000, 1500));
        JFrame frame = new JFrame();
        frame.setContentPane(new JScrollPane(dp));
        frame.pack();
        frame.setSize(1000, 1200);
        frame.setVisible(true);
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
    }

    /** Re-indexes the per-file enrichment lists as TF -> (file name -> enrichment). */
    private static TreeMap<String, TreeMap<String, ClusterEnrichment>> groupByTf(
            TreeMap<String, List<ClusterEnrichment>> byFileName) {
        TreeMap<String, TreeMap<String, ClusterEnrichment>> byTF = new TreeMap<>();
        for (String fileName : byFileName.keySet()) {
            for (ClusterEnrichment enrich : byFileName.get(fileName)) {
                TreeMap<String, ClusterEnrichment> byFile = byTF.get(enrich.tf);
                if (byFile == null) {
                    byFile = new TreeMap<>();
                    byTF.put(enrich.tf, byFile);
                }
                byFile.put(fileName, enrich);
            }
        }
        return byTF;
    }

    /** Removes every TF that has no file with {@code pVal < minP}. */
    private static void retainSignificant(TreeMap<String, TreeMap<String, ClusterEnrichment>> byTF, double minP) {
        List<String> toRemove = new ArrayList<>();
        for (String tf : byTF.keySet()) {
            boolean keep = false;
            for (ClusterEnrichment enrich : byTF.get(tf).values()) {
                if (enrich.pVal < minP) {
                    keep = true;
                    break;
                }
            }
            if (!keep) {
                toRemove.add(tf);
            }
        }
        for (String tf : toRemove) {
            byTF.remove(tf);
        }
    }

    /**
     * Fails fast unless every TF has entries for exactly the same set of files.
     *
     * <p>Each matrix row is built from that TF's own file map. Without this check, rows
     * could be ragged, which Array2DRowRealMatrix rejects. Worse, rows of equal length
     * could silently place values from different files in the same column.
     */
    private static void checkSameFiles(TreeMap<String, TreeMap<String, ClusterEnrichment>> byTF) {
        if (byTF.isEmpty()) {
            return;
        }
        String referenceTf = byTF.firstKey();
        TreeMap<String, ClusterEnrichment> reference = byTF.get(referenceTf);
        for (String tf : byTF.keySet()) {
            TreeMap<String, ClusterEnrichment> byFile = byTF.get(tf);
            if (!byFile.keySet().equals(reference.keySet())) {
                throw new IllegalStateException("TF " + tf + " is reported in " + byFile.size()
                        + " files that differ from the " + reference.size() + " files of TF " + referenceTf
                        + "; every TF must appear in the same files");
            }
        }
    }

    /** Returns the TF's logP values in file-name order (TreeMap iteration order). */
    private static double[] logPRow(TreeMap<String, ClusterEnrichment> byFile) {
        double[] values = new double[byFile.size()];
        int col = 0;
        for (ClusterEnrichment enrich : byFile.values()) {
            values[col++] = enrich.logP;
        }
        return values;
    }

    /** Returns the symmetric matrix of pairwise Euclidean distances between rows; the diagonal is 0. */
    private static double[][] distanceMatrix(double[][] rows) {
        EuclideanDistance dist = new EuclideanDistance();
        int n = rows.length;
        double[][] d = new double[n][n];
        for (int r = 0; r < n; ++r) {
            for (int c = r + 1; c < n; ++c) {
                double ed = dist.compute(rows[r], rows[c]);
                d[r][c] = ed;
                d[c][r] = ed;
            }
        }
        return d;
    }

    /**
     * Builds the 1 x nn colour-legend matrix, where nn = (int) (largest - smallest).
     *
     * <p>Entry i is {@code largest - i}. Negative entries are shifted down by one more unit.
     */
    private static double[][] scaleMatrix(double largest, double smallest) {
        int nn = (int) (largest - smallest);
        double[][] x = new double[1][nn];
        for (int i = 0; i < nn; ++i) {
            double v = largest - i;
            // NOTE(review): because of this shift, when largest is a whole number the scale
            // jumps from 0 straight to -2. Kept as in the original; confirm the intent.
            x[0][i] = v < 0.0 ? v - 1 : v;
        }
        return x;
    }

    /**
     * Returns the three colour gradients passed to MappableHeatMap: negative
     * (magenta to translucent blue), neutral (light grey), and positive (green to yellow).
     * Presumably the heat map uses gradient k for values in the k-th interval split by
     * {@code bounds}. Confirm against MappableHeatMap.
     */
    private static Color[][] buildGradients() {
        Color[] negColors = {Color.MAGENTA, new Color(125, 0, 125), new Color(0, 0, 255, 100)};
        Color[] posColors = {
            new Color(0, 255, 0, 50),
            new Color(0, 255, 0),
            new Color(255, 255, 0, 100),
            new Color(255, 255, 0, 150),
            new Color(255, 255, 0)
        };
        Color[] negative = Gradient.createMultiGradient(negColors, 100);
        Color[] positive = Gradient.createMultiGradient(posColors, 100);
        return new Color[][] {negative, {Color.LIGHT_GRAY}, positive};
    }
}
