/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.chipseq;

import htsjdk.samtools.util.Interval;
import htsjdk.samtools.util.IntervalTree;
import htsjdk.samtools.util.IntervalTree.Node;
import htsjdk.samtools.util.IntervalTreeMap;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import org.rhwlab.gene.model.Annotation;
import org.rhwlab.tfs.AllTFs;

/**
 * A cluster of peaks on one chromosome, together with the genes around it.
 *
 * <p>A cluster is populated in one of two ways:
 * <ul>
 * <li>incrementally: {@link #addPeak(Peak)} for each member peak, then
 * {@link #locateInGenome(IntervalTreeMap, IntervalTree)} to compute the apex
 * position summary, the per-TF peak counts and the neighbouring genes; or</li>
 * <li>by parsing one (already tokenised) line of the report written by
 * {@link #report(PrintStream, Set, String)} under the header written by
 * {@link #reportHeader(PrintStream, Set, String)}. Such clusters have no
 * member peaks.</li>
 * </ul>
 *
 * <p>Naming convention, as implemented by {@code locateInGenome}: the "up" gene
 * lies at higher coordinates than the cluster's mean position, the "down" gene
 * at lower coordinates.
 *
 * @author gevirl
 */
public class PeakCluster {

    /** Number of fixed columns {@link #reportHeader} writes before the per-TF columns. */
    private static final int FIXED_REPORT_COLUMNS = 15;

    /** Token that {@code printf("%s", null)} produces; read back as an absent value. */
    private static final String NULL_TOKEN = "null";

    /** Member peaks; empty for clusters parsed from a report line. */
    ArrayList<Peak> peaks = new ArrayList<>();
    String chromo;
    /** Smallest, largest and (truncated) mean apex position of the member peaks. */
    int minPos;
    int maxPos;
    int meanPos;
    /** Number of member peaks per TF name, kept sorted by TF name. */
    Map<String, Short> tfs = new TreeMap<>();

    /** Name of a gene overlapping {@code meanPos}, or null if none was found. */
    String intraGene;

    /** First gene (in interval-tree order) starting after meanPos; its strand; start - meanPos. */
    String upGene;
    String upStrand;
    int upDist;

    /** A gene ending before meanPos (see locateInGenome); its strand; meanPos - end. */
    String downGene;
    String downStrand;
    int downDist;

    /** Gene chosen by {@link #possibleTarget()}, and the transcript set by {@link #setTranscript}. */
    String targetGene;
    String targetTranscript;

    /**
     * Creates an empty cluster on the given chromosome; add peaks with
     * {@link #addPeak(Peak)} and then call {@link #locateInGenome}.
     *
     * @param chromo chromosome name
     */
    public PeakCluster(String chromo) {
        this.chromo = chromo;
    }

    /**
     * Parses a cluster report line, keeping every TF listed in its TF_Names column.
     *
     * @param heads  the report header, split into column names
     * @param tokens the report line, split into fields
     */
    public PeakCluster(String[] heads, String[] tokens) {
        this(heads, tokens, null);
    }

    /**
     * Parses a cluster report line written by {@link #report}.
     *
     * <p>Header names are matched after removing their last {@code _suffix}.
     * {@link #reportHeader} appends {@code _<column>} to every TF name, and the
     * last column is {@code TF_Names}, which is keyed as {@code TF}. Fields written
     * as {@code null} (absent genes/strands) are read back as {@code null}.
     *
     * @param heads  the report header, split into column names
     * @param tokens the report line, split into fields
     * @param all    if non-null, only TFs for which {@code all.containsGeneName(tf)}
     *               is true are kept
     * @throws IllegalArgumentException if a required column is missing from
     *                                  {@code heads} or {@code tokens}
     * @throws NumberFormatException    if a numeric column cannot be parsed
     */
    public PeakCluster(String[] heads, String[] tokens, AllTFs all) {
        Map<String, Integer> columns = indexColumns(heads);

        chromo = field(columns, tokens, "Chromo");
        minPos = Integer.parseInt(field(columns, tokens, "MinPos"));
        maxPos = Integer.parseInt(field(columns, tokens, "MaxPos"));
        meanPos = Integer.parseInt(field(columns, tokens, "MeanPos"));
        intraGene = nullIfNullToken(field(columns, tokens, "IntraGene"));
        upGene = nullIfNullToken(field(columns, tokens, "UpGene"));
        upStrand = nullIfNullToken(field(columns, tokens, "UpStrand"));
        upDist = Integer.parseInt(field(columns, tokens, "UpDist"));
        downGene = nullIfNullToken(field(columns, tokens, "DownGene"));
        downStrand = nullIfNullToken(field(columns, tokens, "DownStrand"));
        downDist = Integer.parseInt(field(columns, tokens, "DownDist"));
        targetGene = nullIfNullToken(field(columns, tokens, "Target"));

        for (String rawName : field(columns, tokens, "TF").split(",")) {
            // trim() guards against a trailing '\r' on the last field of the line
            String tf = rawName.trim();
            if (tf.isEmpty()) {
                continue;
            }
            if (all == null || all.containsGeneName(tf)) {
                tfs.put(tf, Short.valueOf(field(columns, tokens, tf).trim()));
            }
        }
    }

    /** Maps each header name, minus its last "_suffix" (if any), to its column index. */
    private static Map<String, Integer> indexColumns(String[] heads) {
        Map<String, Integer> columns = new TreeMap<>();
        for (int i = 0; i < heads.length; ++i) {
            String head = heads[i];
            int underscore = head.lastIndexOf('_');
            columns.put(underscore >= 0 ? head.substring(0, underscore) : head, i);
        }
        return columns;
    }

    /** Returns the token of the named column, failing with a descriptive message if absent. */
    private static String field(Map<String, Integer> columns, String[] tokens, String name) {
        Integer index = columns.get(name);
        if (index == null) {
            throw new IllegalArgumentException("cluster report header has no column '" + name + "'");
        }
        if (index >= tokens.length) {
            throw new IllegalArgumentException("cluster report line has no value for column '"
                    + name + "' (index " + index + ", " + tokens.length + " fields)");
        }
        return tokens[index];
    }

    /** Converts the literal "null" written for absent values back into {@code null}. */
    private static String nullIfNullToken(String token) {
        return NULL_TOKEN.equals(token) ? null : token;
    }

    public String getChromosome() {
        return this.chromo;
    }

    /** Returns the cluster identifier {@code <chromo>_<meanPos>}. */
    public String getID() {
        return String.format("%s_%d", this.chromo, this.meanPos);
    }

    /** Returns the most recently chosen (or parsed) target gene; may be null. */
    public String getTarget() {
        return targetGene;
    }

    public int getMeanPosition() {
        return this.meanPos;
    }

    /** Returns the up gene name, or null if none was found. */
    public String getUpGene() {
        return this.upGene;
    }

    /** Returns the down gene name, or null if none was found. */
    public String getDownGene() {
        return this.downGene;
    }

    /** Returns the name of a gene overlapping the mean position, or null. */
    public String getIntraGene() {
        return this.intraGene;
    }

    /** Returns "5P" if the cluster faces the 5' end of the up gene, otherwise "3P". */
    public String upOrientation() {
        return upIs5_Prime() ? "5P" : "3P";
    }

    /** Returns "5P" if the cluster faces the 5' end of the down gene, otherwise "3P". */
    public String downOrientation() {
        return downIs5_Prime() ? "5P" : "3P";
    }

    /**
     * Returns whether the up gene (at higher coordinates) is on the "+" strand,
     * i.e. begins near the cluster. Returns false when no up gene was found.
     */
    public boolean upIs5_Prime() {
        return "+".equals(this.upStrand);
    }

    /**
     * Returns whether the down gene (at lower coordinates) is on the "-" strand,
     * i.e. begins near the cluster. Returns false when no down gene was found.
     */
    public boolean downIs5_Prime() {
        return "-".equals(this.downStrand);
    }

    /** Returns a live view of the TF names present in this cluster. */
    public Set<String> getTFs() {
        return this.tfs.keySet();
    }

    public boolean containsTF(String tf) {
        return this.tfs.containsKey(tf);
    }

    /**
     * Writes the cluster as one BED line: chromosome, minPos, maxPos, target name,
     * score 0, strand ".", thick start/end = minPos/maxPos, and an item RGB colour
     * from {@link #rgb()}. Recomputes the target via {@link #possibleTarget()}.
     *
     * @param stream destination stream
     */
    public void reportAsBedFormat(PrintStream stream) {
        stream.printf("%s\t%d\t%d\t%s\t0\t.\t%d\t%d\t%s\n",
                chromo, minPos, maxPos, possibleTarget(), minPos, maxPos, rgb());
    }

    /**
     * Returns a BED itemRgb colour by peak count: red for a single peak, green for
     * up to 50 peaks (including zero, e.g. clusters parsed from a report), else blue.
     */
    public String rgb() {
        int size = peaks.size();
        if (size == 1) {
            return "255,0,0";
        }
        return size <= 50 ? "0,255,0" : "0,0,255";
    }

    /**
     * Writes this cluster as one report line matching {@link #reportHeader}: the
     * 15 fixed columns, one peak-count column per TF in {@code allTFs} (0 if
     * absent), and a final comma-separated list of the TFs from {@code allTFs}
     * present in this cluster. Absent values are printed as "null".
     * Recomputes the target via {@link #possibleTarget()}.
     *
     * @param stream destination stream
     * @param allTFs TF columns to write, in iteration order
     * @param sep    column separator
     */
    public void report(PrintStream stream, Set<String> allTFs, String sep) {
        stream.printf("%s%s%d%s%d%s%d%s%d%s%d%s%s%s%s%s%s%s%s%s%s%s%d%s%s%s%s%s%d",
                chromo, sep, minPos, sep, meanPos, sep, maxPos, sep, peaks.size(), sep, tfs.size(), sep, possibleTarget(), sep, targetTranscript, sep, intraGene, sep, downGene, sep, downStrand, sep, downDist, sep, upGene, sep, upStrand, sep, upDist);
        StringBuilder presentNames = new StringBuilder();
        boolean first = true;
        for (String tf : allTFs) {
            Short count = tfs.get(tf);
            if (count == null) {
                count = 0;
            } else {
                if (!first) {
                    presentNames.append(",");
                }
                presentNames.append(tf);
                first = false;
            }
            stream.printf("%s%d", sep, count);
        }
        stream.printf("%s%s\n", sep, presentNames.toString());
    }

    /**
     * Writes the report header. Each TF column is named {@code <tf>_<n>}, where
     * {@code n} is its 1-based column number; the last column is "TF_Names".
     *
     * @param stream destination stream
     * @param allTFs TF columns to write, in iteration order
     * @param sep    column separator
     */
    static public void reportHeader(PrintStream stream, Set<String> allTFs, String sep) {
        stream.printf("Chromo%sMinPos%sMeanPos%sMaxPos%sPeaks%sTFs%sTarget%sTranscript%sIntraGene%sDownGene%sDownStrand%sDownDist%sUpGene%sUpStrand%sUpDist",
                sep, sep, sep, sep, sep, sep, sep, sep, sep, sep, sep, sep, sep, sep);
        int column = FIXED_REPORT_COLUMNS + 1;
        for (String tf : allTFs) {
            stream.printf("%s%s_%d", sep, tf, column);
            ++column;
        }
        stream.printf("%sTF_Names\n", sep);
    }

    public void addPeak(Peak peak) {
        peaks.add(peak);
    }

    public void setTranscript(String t) {
        this.targetTranscript = t;
    }

    /**
     * Summarises the member peaks and finds the genes around the cluster.
     *
     * <p>This computes min/max/mean apex position and adds each peak to the
     * per-TF counts (counts accumulate if called more than once). It tags every
     * peak's BED record with the cluster ID and number of distinct TFs, records an
     * overlapping gene and the up/down neighbouring genes, and finally chooses a
     * target via {@link #possibleTarget()}. An empty cluster is left unchanged,
     * since it has no apex positions to average.
     *
     * @param geneTreeMap gene annotations queried for overlap with the mean position;
     *                    values are assumed to be {@code Annotation}s
     * @param geneTree    gene annotations for this chromosome ordered by position;
     *                    values are assumed to be {@code Annotation}s
     */
    @SuppressWarnings("rawtypes")
    public void locateInGenome(IntervalTreeMap geneTreeMap, IntervalTree geneTree) {
        if (peaks.isEmpty()) {
            return;
        }
        summarisePeaks();
        tagPeakBedRecords();
        findIntraGene(geneTreeMap);
        findUpGene(geneTree);
        findDownGene(geneTree);
        possibleTarget();
    }

    /** Computes min/max/mean apex position and accumulates per-TF peak counts. */
    private void summarisePeaks() {
        minPos = Integer.MAX_VALUE;
        maxPos = Integer.MIN_VALUE;
        long apexSum = 0;
        for (Peak peak : peaks) {
            Short count = tfs.get(peak.tf);
            tfs.put(peak.tf, (short) (count == null ? 1 : count + 1));
            apexSum += peak.apex;
            maxPos = Math.max(maxPos, peak.apex);
            minPos = Math.min(minPos, peak.apex);
        }
        // truncating mean, as stored in the report
        meanPos = (int) ((double) apexSum / (double) peaks.size());
    }

    /** Stamps every member peak's BED record with this cluster's ID and distinct-TF count. */
    private void tagPeakBedRecords() {
        String clusterID = getID();
        for (Peak peak : peaks) {
            peak.getBedRecord().setCluster(clusterID);
            peak.getBedRecord().setClusterSize(this.tfs.size());
        }
    }

    /** Records the first gene overlapping the mean position, if any. */
    @SuppressWarnings("rawtypes")
    private void findIntraGene(IntervalTreeMap geneTreeMap) {
        Interval point = new Interval(this.chromo, this.meanPos, this.meanPos);
        Collection overlapping = geneTreeMap.getOverlapping(point);
        if (!overlapping.isEmpty()) {
            Annotation gene = (Annotation) overlapping.iterator().next();
            this.intraGene = gene.getGeneName();
        }
    }

    /** Records the first gene (in tree order) at or after position meanPos + 1. */
    @SuppressWarnings("rawtypes")
    private void findUpGene(IntervalTree geneTree) {
        Node node = geneTree.min(meanPos + 1, meanPos + 1);
        if (node != null) {
            Annotation gene = (Annotation) node.getValue();
            upGene = gene.getGeneName();
            upStrand = gene.getStrand();
            upDist = gene.getStart() - meanPos;
        }
    }

    /**
     * Records the first gene, walking the tree backwards from meanPos, whose end
     * lies before meanPos.
     *
     * <p>NOTE(review): the reverse walk is ordered by interval start, so a longer
     * gene that starts earlier but ends closer to meanPos is not preferred; confirm
     * this is the intended "nearest downstream gene".
     */
    @SuppressWarnings("rawtypes")
    private void findDownGene(IntervalTree geneTree) {
        Iterator nodes = geneTree.reverseIterator(meanPos, meanPos);
        while (nodes.hasNext()) {
            Annotation gene = (Annotation) ((Node) nodes.next()).getValue();
            if (gene.getEnd() < meanPos) {
                downGene = gene.getGeneName();
                downStrand = gene.getStrand();
                downDist = meanPos - gene.getEnd();
                return;
            }
        }
    }

    /**
     * Chooses the likely target gene, stores it in {@code targetGene}, and sets it
     * as the target on every member peak's BED record.
     *
     * @return the chosen gene name, or null if none could be chosen
     */
    public String possibleTarget() {
        this.targetGene = pickPossibleTarget();
        for (Peak peak : peaks) {
            peak.getBedRecord().setTarget(targetGene);
        }
        return this.targetGene;
    }

    /**
     * Target-selection heuristic.
     *
     * <ol>
     * <li>An overlapping gene wins.</li>
     * <li>If only one neighbour exists, it is chosen.</li>
     * <li>Both neighbours on "+": the cluster is at the 5' end of the up gene, so
     * the up gene is chosen.</li>
     * <li>Both on "-": the cluster is at the 5' end of the down gene, so the down
     * gene is chosen.</li>
     * <li>Opposite strands: the nearer gene is chosen (the up gene on ties).</li>
     * <li>Any other strand values (e.g. "." or null) give null.</li>
     * </ol>
     */
    private String pickPossibleTarget() {
        if (this.intraGene != null) {
            return intraGene;
        }
        if (downGene == null) {
            return upGene;
        }
        if (upGene == null) {
            return downGene;
        }
        boolean downPlus = "+".equals(downStrand);
        boolean downMinus = "-".equals(downStrand);
        boolean upPlus = "+".equals(upStrand);
        boolean upMinus = "-".equals(upStrand);
        if (downPlus && upPlus) {
            return upGene;
        }
        if (downMinus && upMinus) {
            return downGene;
        }
        if ((downMinus && upPlus) || (downPlus && upMinus)) {
            return downDist < upDist ? downGene : upGene;
        }
        return null;
    }
}
