/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.chipseq;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.commons.math3.ml.clustering.Cluster;
import org.apache.commons.math3.ml.clustering.DBSCANClusterer;
import org.rhwlab.formats.GFF3;
import org.rhwlab.formats.GFF3.Entry;

/**
 * Reader and processor for peak BED files.
 *
 * <p>Records are read from a file whose fields may be separated by spaces, tabs or commas, and
 * are grouped by chromosome. Chromosome names are normalized: a leading "chr" is dropped and
 * "chrM" becomes "MtDNA". The class can:
 * <ul>
 *   <li>cluster peaks, with DBSCAN or a simple gap-based algorithm;</li>
 *   <li>write cluster-colored BED tracks;</li>
 *   <li>annotate peaks with neighbouring GFF3 features;</li>
 *   <li>append its records to another file.</li>
 * </ul>
 *
 * @author gevirl
 */
public class ModernBed {
    // itemRgb colors indexed by cluster-size bucket (see clusterSizes)
    static String[] clusterColors ={"0,0,255","0,255,0","255,0,0"};
    // inclusive upper bounds on cluster size for color buckets 0 and 1; larger clusters use bucket 2
    int[] clusterSizes = {10,100};
    File file;      // the BED file that was read
    String gene;    // set by parseFilename(); used by enhance()
    String stage;   // set by parseFilename(); used by enhance()
    String strain;  // set by parseFilename(); used by enhance()
    TreeMap<String,List<BindSite>> records; // bed records stored by chromosome

    /**
     * Reads the given BED file.
     *
     * @param file BED file; fields may be separated by spaces, tabs or commas
     * @throws Exception if the file cannot be read or a line cannot be turned into a BindSite
     */
    public ModernBed(File file)throws Exception {
        this.file = file;
        init();
    }

    /**
     * (Re)reads {@link #file} into {@link #records}, grouping sites by normalized chromosome.
     * Blank lines and BED header lines (starting with "#", "track" or "browser") are skipped.
     *
     * @throws Exception if the file cannot be read or a BindSite cannot be built from a line
     */
    final public void init() throws Exception {
        records = new TreeMap<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(this.file))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (isBlankOrHeader(line)) {
                    continue;
                }
                String[] tokens = line.split(" |\t|,");
                // fix the chromosome to one of (I,II,III,IV,V,X,MtDNA)
                tokens[0] = normalizeChromosome(tokens[0]);
                records.computeIfAbsent(tokens[0], k -> new ArrayList<>()).add(new BindSite(tokens));
            }
        }
    }

    // true for lines that carry no peak record: blank lines and BED comment/track/browser lines
    private static boolean isBlankOrHeader(String line) {
        return line.trim().isEmpty()
                || line.startsWith("#")
                || line.startsWith("track")
                || line.startsWith("browser");
    }

    // drop a leading "chr" and map the mitochondrial "chrM" to "MtDNA"
    private static String normalizeChromosome(String chr) {
        if (!chr.startsWith("chr")) {
            return chr;
        }
        String stripped = chr.substring(3);
        return stripped.equals("M") ? "MtDNA" : stripped;
    }

    /** Sorts the sites of every chromosome by coordinate (see {@link BindSiteComparator}). */
    public void sortSites(){
        for (List<BindSite> siteList : records.values()){
            siteList.sort(new BindSiteComparator());
        }
    }

    /**
     * Writes a genome-browser BED track in which all peaks of one cluster share a color.
     *
     * <p>The cluster id is read from the last column of each input record. A cluster is a run
     * of consecutive records with the same id within a chromosome, which is how
     * {@link #reportClusters} writes them. The color depends on the cluster size, using the
     * buckets in {@link #clusterSizes}. Chromosomes whose name starts with "M" (mitochondrial)
     * are skipped.
     *
     * <p>Each output line has these columns:
     * <ol>
     *   <li>"chr" + chromosome</li>
     *   <li>start</li>
     *   <li>end</li>
     *   <li>input column 15, used as the name</li>
     *   <li>0 (score)</li>
     *   <li>input column 6</li>
     *   <li>start again</li>
     *   <li>end again</li>
     *   <li>color</li>
     * </ol>
     *
     * @param stream destination
     * @param label  track name
     * @param desc   track description
     */
    public void clusterColorBed(PrintStream stream,String label,String desc){
        stream.printf("track name=\"%s\" description=\"%s\" visibility=2 itemRgb=\"On\"\n",label,desc);
        for (List<BindSite> recordList : this.records.values()){
            List<BindSite> bsInCluster = new ArrayList<>();
            int currentCluster = -1;
            for (BindSite bs : recordList){
                String[] record = bs.tokens;
                if (record[0].startsWith("M")){
                    continue;   // skip mitochondrial
                }
                int cluster = Integer.parseInt(record[record.length-1]);
                if (cluster != currentCluster){
                    // write out the previous cluster and start a new one
                    writeColoredCluster(stream, bsInCluster);
                    bsInCluster.clear();
                    currentCluster = cluster;
                }
                bsInCluster.add(bs);
            }
            // flush the last cluster of this chromosome; it has no successor to trigger the write
            writeColoredCluster(stream, bsInCluster);
        }
    }

    // writes the sites of one cluster, all colored by the size of the cluster; no-op when empty
    private void writeColoredCluster(PrintStream stream, List<BindSite> sites){
        if (sites.isEmpty()){
            return;
        }
        String color = clusterColors[colorIndex(sites.size())];
        for (BindSite site : sites){
            String[] r = site.tokens;
            stream.printf("chr%s\t%s\t%s\t%s\t0\t%s\t%s\t%s\t%s\n",
                    r[0],r[1],r[2],r[14],r[5],r[1],r[2],color);
        }
    }

    // index of the first size bucket that holds n sites, clamped to the available colors
    private int colorIndex(int n){
        int index = clusterSizes.length;
        for (int i=0 ; i<clusterSizes.length ; ++i){
            if (n <= clusterSizes[i]){
                index = i;
                break;
            }
        }
        return Math.min(index, clusterColors.length - 1);
    }

    /**
     * Expands every record to the format Carlos created.
     *
     * <p>All sites are pooled across chromosomes and ranked by ascending score
     * ({@link ScoreComparator}). NOTE(review): ascending order makes P1 the lowest-scoring
     * peak; confirm this is the intended ranking.
     *
     * <p>Each record's 4th column is replaced by {@code ce_<strain>_<gene>_<stage>_WA:P<rank>}.
     * Eight columns are then appended, in this order:
     * <ol>
     *   <li>end - start</li>
     *   <li>the id</li>
     *   <li>"ce"</li>
     *   <li>strain</li>
     *   <li>gene</li>
     *   <li>stage</li>
     *   <li>"WA"</li>
     *   <li>"stn"</li>
     * </ol>
     * Call {@link #parseFilename()} first; otherwise strain/gene/stage are null.
     */
    public void enhance(){
        // combine all the chromosomes and sort by score
        List<BindSite> all = new ArrayList<>();
        for (List<BindSite> recordList : records.values()){
            all.addAll(recordList);
            recordList.clear();     // refilled below with the expanded records
        }
        all.sort(new ScoreComparator());

        String id = String.format("ce_%s_%s_%s_WA",strain,gene,stage);
        int rank = 0;
        for (BindSite bs : all){
            ++rank;
            String[] tokens = bs.tokens;
            int width = Integer.parseInt(tokens[2]) - Integer.parseInt(tokens[1]);
            tokens[3] = String.format("%s:P%d",id,rank);

            String[] outTokens = new String[tokens.length+8];
            System.arraycopy(tokens, 0, outTokens, 0, tokens.length);
            int index = tokens.length;
            outTokens[index++] = Integer.toString(width);
            outTokens[index++] = id;
            outTokens[index++] = "ce";
            outTokens[index++] = strain;
            outTokens[index++] = gene;
            outTokens[index++] = stage;
            outTokens[index++] = "WA";
            outTokens[index] = "stn";
            records.get(tokens[0]).add(new BindSite(outTokens));
        }
    }

    /**
     * Adds neighbouring GFF3 features (gene targets) to every record.
     *
     * <p>Eight columns are appended, as four name/distance pairs:
     * <ol>
     *   <li>{@code gff.lowerPlus}</li>
     *   <li>{@code gff.lowerMinus}</li>
     *   <li>{@code gff.upperPlus}</li>
     *   <li>{@code gff.upperMinus}</li>
     * </ol>
     * Each is evaluated at the site coordinate. For the "lower" features the distance is
     * {@code coord - entry.end}; for the "upper" features it is {@code entry.start - coord}.
     * When there is no such feature the pair is "na", -1. Presumably these are the nearest
     * +/- strand features before and after the site; confirm against GFF3.
     *
     * @param gff annotation used for the lookups
     */
    public void addTargets(GFF3 gff){
        for (List<BindSite> recordList : records.values()){
            for (int rec=0 ; rec<recordList.size() ; ++rec){
                BindSite site = recordList.get(rec);
                String[] tokens = site.tokens;
                String[] outTokens = new String[tokens.length+8];
                System.arraycopy(tokens, 0, outTokens, 0, tokens.length);

                String chr = tokens[0];
                int loc = (int)site.getCoord();

                int index = tokens.length;
                index = putTarget(outTokens, index, gff, gff.lowerPlus(chr, loc), loc, true);
                index = putTarget(outTokens, index, gff, gff.lowerMinus(chr, loc), loc, true);
                index = putTarget(outTokens, index, gff, gff.upperPlus(chr, loc), loc, false);
                putTarget(outTokens, index, gff, gff.upperMinus(chr, loc), loc, false);

                recordList.set(rec,new BindSite(outTokens));
            }
        }
    }

    // stores a target name and its distance at out[index], out[index+1] ("na", -1 when target is
    // null); lower targets measure loc - end, upper targets start - loc; returns the next index
    private static int putTarget(String[] out, int index, GFF3 gff, String target, int loc, boolean lower){
        int dist = -1;
        if (target == null){
            target = "na";
        } else {
            Entry entry = gff.getEntry(target);
            dist = lower ? loc - entry.end : entry.start - loc;
        }
        out[index] = target;
        out[index+1] = Integer.toString(dist);
        return index + 2;
    }

    /**
     * Appends this bed's records, tab separated, to another file.
     * Records are written chromosome by chromosome; they are not merged with the records
     * already in {@code dest}.
     *
     * @param dest file to append to; it is created if it does not exist
     * @throws Exception on I/O failure
     */
    public void appendTo(File dest)throws Exception {
        try (PrintWriter writer = new PrintWriter(new FileWriter(dest,true))) {
            for (List<BindSite> recordList : records.values()){
                for (BindSite bs : recordList){
                    writer.print(String.join("\t", bs.tokens));
                    writer.print('\n');
                }
            }
        }
    }

    /**
     * Sets gene, strain and stage from the '_'-separated fields 8, 9 and 11 (0-based indices
     * 7, 8 and 10) of the file name.
     *
     * @throws IllegalArgumentException if the file name has fewer than 11 such fields
     */
    public void parseFilename(){
        String name = file.getName();
        String[] tokens = name.split("_");
        if (tokens.length < 11){
            throw new IllegalArgumentException(String.format(
                    "cannot parse gene/strain/stage from file name %s: expected at least 11 '_'-separated fields, found %d",
                    name, tokens.length));
        }
        this.gene = tokens[7];
        this.strain = tokens[8];
        this.stage = tokens[10];
    }

    /**
     * Reports the bind sites of {@code bed} that are not in any cluster of this clustered bed.
     * Only meaningful when this object was read from a clustered bed. Matching uses
     * {@code List.contains}, i.e. {@code BindSite.equals}. Sites on a chromosome absent from
     * this bed count as unclustered. Reported sites are written with {@code BindSite.shortPrint}.
     *
     * @param bed    the unclustered bed
     * @param stream destination
     */
    public void notClustered(ModernBed bed,PrintStream stream){
        for (Map.Entry<String,List<BindSite>> chrSites : bed.records.entrySet()){
            List<BindSite> clusteredList = this.records.get(chrSites.getKey());
            for (BindSite site : chrSites.getValue()){
                if (clusteredList == null || !clusteredList.contains(site)){
                    site.shortPrint(stream);
                }
            }
        }
    }

    /**
     * Clusters one chromosome's sites with a density clusterer (DBSCAN).
     * The input list is not modified. Sites that DBSCAN treats as noise are in no returned cluster.
     *
     * @param sites    sites to cluster
     * @param episilon DBSCAN neighbourhood radius
     * @param minPts   DBSCAN minimum number of points per cluster
     * @return the clusters found
     */
    public List<Cluster> cluster(List<BindSite> sites,double episilon,int minPts){
        DBSCANClusterer clusterer = new DBSCANClusterer(episilon,minPts);
        return clusterer.cluster(sites);
    }

    /**
     * Clusters each chromosome separately with DBSCAN and writes the clusters with
     * {@link #reportClusters}. Sites within each cluster are sorted by coordinate, and the
     * clusters by the coordinate of their first site.
     *
     * @param stream   destination
     * @param episilon DBSCAN neighbourhood radius
     * @param minPts   DBSCAN minimum number of points per cluster
     */
    public void cluster(PrintStream stream,double episilon,int minPts){
        for (List<BindSite> list : records.values()){
            List<Cluster> clusterList = cluster(list,episilon,minPts);
            for (Cluster cluster : clusterList){
                List<BindSite> bsList = cluster.getPoints();
                bsList.sort(new BindSiteComparator());
            }
            clusterList.sort(new ClusterComparator());
            reportClusters(stream,clusterList);
        }
    }

    /**
     * Writes every site of every cluster as its tab-separated tokens followed by a cluster
     * number. Numbers start at 1 for each call, i.e. for each chromosome when called from
     * {@link #cluster(PrintStream, double, int)}.
     *
     * @param stream      destination
     * @param clusterList clusters to write, in output order
     */
    public void reportClusters(PrintStream stream,List<Cluster> clusterList){
        int c =0;
        for (Cluster cl : clusterList){
            ++c;
            List<BindSite> bsList = cl.getPoints();
            for (BindSite bs : bsList){
                for (int i=0 ; i<bs.tokens.length ; ++i){
                    stream.printf("%s\t",bs.tokens[i]);
                }
                stream.printf("%d\n", c);
            }
        }
    }

    /**
     * Clusters the bind sites by a simple gap algorithm. A new cluster starts whenever the
     * distance to the previous site exceeds {@code delta}.
     *
     * @param sites sites sorted by coordinate
     * @param delta maximum gap between consecutive sites of one cluster
     * @return the clusters, in input order; empty when {@code sites} is empty
     */
    public List<Cluster> cluster(List<BindSite> sites,double delta){
        List<Cluster> ret = new ArrayList<>();
        if (sites.isEmpty()){
            return ret;
        }
        Cluster current = null;
        double previous = sites.get(0).getCoord();
        for (BindSite site : sites){
            double coord = site.getCoord();
            // the first site always opens a cluster, so no empty cluster is ever produced
            if (current == null || coord - previous > delta){
                current = new Cluster();
                ret.add(current);
            }
            current.addPoint(site);
            previous = coord;
        }
        return ret;
    }

    /**
     * Clusters every chromosome with the simple gap algorithm.
     * The bind sites must already be sorted (see {@link #sortSites()}).
     *
     * @param delta maximum gap between consecutive sites of one cluster
     * @return clusters keyed by chromosome
     */
    public Map<String,List<Cluster>> cluster(double delta){
        TreeMap<String,List<Cluster>> ret = new TreeMap<>();
        for (Map.Entry<String,List<BindSite>> chrSites : records.entrySet()){
            ret.put(chrSites.getKey(), cluster(chrSites.getValue(),delta));
        }
        return ret;
    }

    /**
     * Prints, for each gap threshold delta in [min, max] stepped by increment, one line to
     * System.out. The line holds delta, the total number of clusters, and a 50-bin histogram of
     * cluster sizes. Each bin covers 2 sizes; the last bin collects all larger clusters.
     *
     * @param args args[0] input bed, args[1] delta min, args[2] delta max, args[3] increment
     * @throws IllegalArgumentException if fewer than 4 arguments are given or increment is not positive
     * @throws Exception if the bed file cannot be read
     */
    static public void clusterAnalysis(String [] args)throws Exception {
        if (args.length < 4){
            throw new IllegalArgumentException(
                    "usage: clusterAnalysis <input bed> <delta min> <delta max> <increment>");
        }
        final int nBins = 50;
        final int topBin = 100;
        double ratio = (double)topBin/(double)nBins;   // cluster sizes per histogram bin

        ModernBed bed = new ModernBed(new File(args[0]));
        bed.sortSites();    // the gap clusterer requires coordinate-sorted sites

        double start = Double.parseDouble(args[1]);
        double end = Double.parseDouble(args[2]);
        double inc = Double.parseDouble(args[3]);
        if (!(inc > 0)){
            // a non-positive (or NaN) increment would never reach 'end'
            throw new IllegalArgumentException("increment must be positive: " + args[3]);
        }
        // derive each delta from a step count instead of repeated addition, so rounding error
        // cannot accumulate and silently drop the final value 'end'
        long steps = (end >= start) ? (long)Math.floor((end - start)/inc + 1e-9) : -1;
        for (long k=0 ; k<=steps ; ++k){
            double delta = start + k*inc;
            Map<String,List<Cluster>> map = bed.cluster(delta);
            int nClusters = 0;
            int[] histogram = new int[nBins];   // zero-initialized by Java
            for (List<Cluster> clusterList : map.values()){
                nClusters += clusterList.size();
                for (Cluster cluster : clusterList){
                    int index = Math.min(nBins-1, (int)(cluster.getPoints().size()/ratio));
                    ++histogram[index];
                }
            }
            System.out.printf("%f\t%d\t",delta,nClusters);
            for (int count : histogram){
                System.out.printf("%d\t",count);
            }
            System.out.println();
        }
    }

    /**
     * Clusters a bed file with DBSCAN. Output goes to
     * {@code <input without .bed>_<epsilon>_<minPts>.clustered.bed}.
     *
     * @param args args[0] input bed, args[1] epsilon, args[2] minPts
     * @return the name of the file written
     * @throws Exception on I/O or parse failure
     */
    static public String clusterBed(String[] args) throws Exception{
        ModernBed bed = new ModernBed(new File(args[0]));
        // strip the final ".bed" (not an earlier occurrence elsewhere in the path)
        int ext = args[0].lastIndexOf(".bed");
        String inFile = (ext < 0) ? args[0] : args[0].substring(0,ext);
        String outFile = String.format("%s_%s_%s.clustered.bed",inFile,args[1],args[2]);
        try (PrintStream stream = new PrintStream(outFile)) {
            bed.cluster(stream,Double.parseDouble(args[1]),Integer.parseInt(args[2]));
        }
        return outFile;
    }

    /**
     * Writes a cluster-colored track for a clustered bed to
     * {@code <base>.clustered.colored.bed}.
     *
     * @param args path of the input clustered bed; must contain ".clustered.bed"
     * @return the name of the file written
     * @throws IllegalArgumentException if the path does not contain ".clustered.bed"
     * @throws Exception on I/O or parse failure
     */
    static public String colorCluster(String args) throws Exception{
        int index = args.lastIndexOf(".clustered.bed");
        if (index < 0){
            throw new IllegalArgumentException("expected a *.clustered.bed file: " + args);
        }
        ModernBed bed = new ModernBed(new File(args));
        String base = args.substring(0,index);
        String fName = String.format("%s.clustered.colored.bed", base);
        try (PrintStream stream = new PrintStream(fName)) {
            bed.clusterColorBed(stream,base,base);
        }
        return fName;
    }

    /**
     * Writes the peaks of an unclustered bed that are in no cluster of the clustered bed to
     * {@code <base>unclustered.bed}, with a browser track header.
     *
     * @param args args[0] unclustered bed file, args[1] clustered bed file (name must contain
     *             "clustered.bed")
     * @throws IllegalArgumentException if args[1] does not contain "clustered.bed"
     * @throws Exception on I/O or parse failure
     */
    static public void notClustered(String[] args) throws Exception {
        ModernBed bed = new ModernBed(new File(args[0]));
        ModernBed clusteredBed = new ModernBed(new File(args[1]));
        int index = args[1].lastIndexOf("clustered.bed");
        if (index < 0){
            throw new IllegalArgumentException("expected a *clustered.bed file: " + args[1]);
        }
        String base = args[1].substring(0,index);
        String fName = String.format("%sunclustered.bed",base);
        try (PrintStream stream = new PrintStream(fName)) {
            stream.printf("track name=\"%sunclustered\" description=\"%s peaks not included in a cluster\" visibility=2 itemRgb=\"On\"\n",base,base);
            clusteredBed.notClustered(bed, stream);
        }
    }

    /**
     * Runs {@link #clusterAnalysis(String[])}.
     *
     * <p>Alternative DBSCAN pipeline, previously kept as commented-out code.
     * Its args are: input bed, epsilon, minPts.
     * <pre>
     *   String clusteredBed = clusterBed(args);
     *   colorCluster(clusteredBed);
     *   args[1] = clusteredBed;
     *   notClustered(args);
     * </pre>
     */
    static public void main(String[] args) throws Exception {
        clusterAnalysis(args);
    }

    /** Orders bind sites by ascending {@code getScore()}. */
    public class ScoreComparator implements Comparator {

        @Override
        public int compare(Object o1, Object o2) {
            BindSite bs1 = (BindSite)o1;
            BindSite bs2 = (BindSite)o2;
            return Double.compare(bs1.getScore(),bs2.getScore());
        }

    }

    /** Orders bind sites by ascending {@code getCoord()}. */
    public class BindSiteComparator implements Comparator {

        @Override
        public int compare(Object o1, Object o2) {
            BindSite bs1 = (BindSite)o1;
            BindSite bs2 = (BindSite)o2;
            return Double.compare(bs1.getCoord(), bs2.getCoord());
        }

    }

    /**
     * Orders clusters by the coordinate of the first bind site in each cluster.
     * Clusters must be non-empty.
     */
    public class ClusterComparator implements Comparator {

        @Override
        public int compare(Object o1, Object o2) {
            Cluster c1 = (Cluster)o1;
            List<BindSite> l1 = c1.getPoints();
            Cluster c2 = (Cluster)o2;
            List<BindSite> l2 = c2.getPoints();
            BindSite bs1 = l1.get(0);
            BindSite bs2 = l2.get(0);
            return Double.compare(bs1.getCoord(), bs2.getCoord());
        }

    }

}
