/* 
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.gene.model;

import htsjdk.samtools.util.Interval;
import htsjdk.samtools.util.IntervalTree;
import htsjdk.samtools.util.IntervalTreeMap;
import java.io.File;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

/**
 * In-memory index of genome annotations, grouped by gene and by transcript.
 *
 * <p>Annotations are registered through {@link #addAnnotation}. Every annotation is kept, in
 * insertion order, in {@link #getAll()}. It is additionally indexed under its gene and/or
 * transcript identifiers when those are supplied.
 *
 * @author gevirl
 */
public class AnnotationModel {

    // gene identifier -> annotations registered for that gene, in insertion order
    TreeMap<String, List<Annotation>> geneMap = new TreeMap<>();
    // transcript identifier -> annotations registered for that transcript, in insertion order
    TreeMap<String, List<Annotation>> transcriptMap = new TreeMap<>();
    // every annotation passed to addAnnotation, in insertion order
    ArrayList<Annotation> all = new ArrayList<>();
    // chromosome names of every annotation passed to addAnnotation
    TreeSet<String> chromosomes = new TreeSet<>();
    // number of annotations added with a non-null gene identifier
    int count;

    /**
     * Returns every annotation added to this model, in insertion order.
     *
     * @return the model's internal list (not a copy)
     */
    public List<Annotation> getAll() {
        return this.all;
    }

    /**
     * Builds an interval map of all annotations of the given type, on every chromosome.
     * Equivalent to {@code asTree(type, null, null)}.
     *
     * @param type annotation type to select
     * @return interval map whose values are {@code Annotation} or {@code List<Annotation>}
     */
    public IntervalTreeMap asTree(String type) {
        return asTree(type, null, null);
    }

    /**
     * Builds a map from genomic interval to the selected annotation(s) on that interval.
     *
     * <p>The first annotation stored for an interval is kept as a single {@code Annotation}.
     * When another selected annotation maps to an interval that already has a value, the value
     * becomes a {@code List<Annotation>}. Callers must therefore check the value's runtime type.
     *
     * @param type annotation type to select (compared with {@code Annotation.getType()})
     * @param chromo chromosome to restrict to, or {@code null} for all chromosomes
     * @param attrValue value the annotation must carry (per {@code hasAttributeValue}), or
     *     {@code null} for no attribute restriction
     * @return interval map whose values are {@code Annotation} or {@code List<Annotation>}
     */
    public IntervalTreeMap asTree(String type, String chromo, String attrValue) {
        IntervalTreeMap ret = new IntervalTreeMap();
        for (Annotation annot : this.getAll()) {
            if (!annot.getType().equals(type)) {
                continue;
            }
            if (chromo != null && !annot.getChromosome().equals(chromo)) {
                continue;
            }
            if (attrValue != null && !annot.hasAttributeValue(attrValue)) {
                continue;
            }
            // null-safe: an annotation without a strand value is treated as non-negative
            boolean negativeStrand = "-".equals(annot.getStrand());
            Interval interval = new Interval(annot.getChromosome(), annot.getStart(), annot.getEnd(), negativeStrand, null);
            Object prior = ret.get(interval);
            if (prior == null) {
                ret.put(interval, annot);
            } else if (prior instanceof Annotation) {
                // second annotation for this interval: promote the stored value to a list
                List<Annotation> list = new ArrayList<>();
                list.add((Annotation) prior);
                list.add(annot);
                ret.put(interval, list);
            } else {
                // only Annotation or List<Annotation> values are ever stored by this method
                @SuppressWarnings("unchecked")
                List<Annotation> list = (List<Annotation>) prior;
                list.add(annot);
            }
        }
        return ret;
    }

    /**
     * Splits an interval map (e.g. one produced by {@link #asTree}) into one {@link IntervalTree}
     * per contig.
     *
     * @param intervalMap map whose keys are {@link Interval}s
     * @return contig name -> tree holding, for each interval, the value the map held for it
     */
    public TreeMap<String, IntervalTree> intervalTreesFromMap(IntervalTreeMap intervalMap) {
        TreeMap<String, IntervalTree> ret = new TreeMap<>();
        for (Object key : intervalMap.keySet()) {
            Interval interval = (Interval) key;
            IntervalTree tree = ret.computeIfAbsent(interval.getContig(), contig -> new IntervalTree());
            tree.put(interval.getStart(), interval.getEnd(), intervalMap.get(interval));
        }
        return ret;
    }

    /**
     * Registers an annotation with this model.
     *
     * <p>The annotation is always appended to {@link #getAll()} and its chromosome is recorded.
     * When {@code gene} is non-null, three things happen:
     * <ul>
     *   <li>the annotation is indexed under that gene;</li>
     *   <li>{@code count} is incremented;</li>
     *   <li>{@code annot.setGene(gene)} is called.</li>
     * </ul>
     * When {@code transcript} is non-null, the annotation is indexed under each transcript
     * identifier and {@code annot.setTranscript(transcript)} is called.
     *
     * @param annot annotation to add
     * @param gene gene identifier, or {@code null}
     * @param transcript a single transcript identifier ({@code String}), a {@code List} of
     *     transcript identifier strings, or {@code null}
     * @throws ClassCastException if {@code transcript} is neither a String nor a List of Strings
     */
    public void addAnnotation(Annotation annot, String gene, Object transcript) {
        if (gene != null) {
            geneMap.computeIfAbsent(gene, key -> new ArrayList<>()).add(annot);
            ++count;
            annot.setGene(gene);
        }

        if (transcript != null) {
            if (transcript instanceof String) {
                addTranscriptAnnotation((String) transcript, annot);
            } else {
                // any non-String transcript value is expected to be a list of identifiers
                for (Object id : (List<?>) transcript) {
                    addTranscriptAnnotation((String) id, annot);
                }
            }
            annot.setTranscript(transcript);
        }
        all.add(annot);
        chromosomes.add(annot.chromo);
    }

    /**
     * Returns the chromosome names of all annotations added so far, in sorted order.
     *
     * @return the model's internal set (not a copy)
     */
    public Set<String> getChromosomes() {
        return chromosomes;
    }

    // Appends annot to the list stored under the given transcript identifier, creating it if needed.
    private void addTranscriptAnnotation(String transcript, Annotation annot) {
        transcriptMap.computeIfAbsent(transcript, key -> new ArrayList<>()).add(annot);
    }

    /**
     * Returns the annotations registered under a gene identifier.
     *
     * @param gene gene identifier
     * @return annotations in insertion order, or {@code null} if the gene is unknown
     */
    public List<Annotation> getGeneAnnotations(String gene) {
        return geneMap.get(gene);
    }

    /**
     * Returns the first annotation of type {@code "gene"} registered under a gene identifier.
     *
     * @param wbGene gene identifier
     * @return the gene annotation, or {@code null} if the gene is unknown or has none
     */
    public Annotation getGeneAnnotation(String wbGene) {
        List<Annotation> annots = this.getGeneAnnotations(wbGene);
        if (annots != null) {
            for (Annotation annot : annots) {
                if (annot.getType().equals("gene")) {
                    return annot;
                }
            }
        }
        return null;
    }

    /**
     * Returns the annotations registered under a transcript identifier.
     *
     * @param transcript transcript identifier
     * @return annotations in insertion order, or {@code null} if the transcript is unknown
     */
    public List<Annotation> getTranscriptAnnotations(String transcript) {
        return transcriptMap.get(transcript);
    }

    /**
     * Returns the first annotation of a transcript whose {@code ID} attribute contains
     * {@code "Transcript"}.
     *
     * @param transcript transcript identifier
     * @return the matching annotation, or {@code null} if the transcript is unknown or no
     *     annotation matches
     */
    public Annotation getTranscriptAnnotation(String transcript) {
        List<Annotation> annotList = this.getTranscriptAnnotations(transcript);
        if (annotList == null) {
            return null;
        }
        for (Annotation annot : annotList) {
            Object id = annot.getAttributeValue("ID");
            if (id != null && ((String) id).contains("Transcript")) {
                return annot;
            }
        }
        return null;
    }

    /**
     * Returns every gene identifier, in sorted order.
     *
     * @return a live key-set view of the gene index
     */
    public Set<String> getAllGenes() {
        return geneMap.keySet();
    }

    /**
     * Returns every transcript identifier, in sorted order.
     *
     * @return a live key-set view of the transcript index
     */
    public Set<String> getAllTranscripts() {
        return transcriptMap.keySet();
    }

    /**
     * Intended to return the gene that a transcript belongs to.
     *
     * <p>TODO(review): the lookup was never implemented. This always returns an empty string,
     * including for unknown transcripts.
     *
     * @param transcript transcript identifier
     * @return currently always {@code ""}
     */
    public String getGeneForTranscript(String transcript) {
        return "";
    }

    /**
     * Returns the gene annotations whose biotype equals {@code biotype}.
     *
     * <p>The biotype is read from the {@code biotype} attribute, falling back to
     * {@code gene_biotype} when the former is absent. Genes without a {@code "gene"} annotation
     * are skipped.
     *
     * @param biotype biotype to match
     * @return matching gene annotations, in gene-identifier order
     */
    public List<Annotation> getGenesByBiotype(String biotype) {
        ArrayList<Annotation> ret = new ArrayList<>();
        for (String wbGene : this.getAllGenes()) {
            Annotation geneAnnot = this.getGeneAnnotation(wbGene);
            if (geneAnnot != null) {
                Object annotType = geneAnnot.getAttributeValue("biotype");
                if (annotType == null) {
                    annotType = geneAnnot.getAttributeValue("gene_biotype");
                }
                if (biotype.equals(annotType)) {
                    ret.add(geneAnnot);
                }
            }
        }
        return ret;
    }

    /**
     * Returns one copy of each exon of a gene, sorted by start then length.
     *
     * <p>Duplicates are removed using {@code Exon.sameExons}. Each surviving exon's ID is then
     * set, via {@code setID}, to its 1-based position in the result. Note that this mutates the
     * shared annotation objects.
     *
     * @param wbGene gene identifier
     * @return the gene's exons; empty if the gene is unknown or has no exon annotations
     */
    public List<Annotation> getExons(String wbGene) {
        ArrayList<Annotation> ret = new ArrayList<>();
        List<Annotation> geneAnnots = this.getGeneAnnotations(wbGene);
        if (geneAnnots == null) {
            return ret;
        }
        for (Annotation annot : geneAnnots) {
            if (annot.getType().equals("exon")) {
                ret.add(annot);
            }
        }
        if (!ret.isEmpty()) {
            sortAnnotations(ret);
            // sameExons yields the index of an exon to drop as a duplicate, or -1 when none remain
            for (int same = Exon.sameExons(ret); same != -1; same = Exon.sameExons(ret)) {
                ret.remove(same);
            }
            int i = 1;
            for (Annotation exon : ret) {
                exon.setID(Integer.toString(i));
                ++i;
            }
        }
        return ret;
    }

    /**
     * Sorts annotations in place by ascending start position, breaking ties by ascending length.
     *
     * @param exons annotations to sort
     */
    public static void sortAnnotations(List<Annotation> exons) {
        exons.sort(Comparator.<Annotation>comparingInt(a -> a.getStart())
                .thenComparingInt(a -> a.getLength()));
    }

    /**
     * Returns the start and end boundaries of a gene's distinct exons (see {@link #getExons}).
     * The boundaries are sorted with {@code AnnotationBoundary.getComparator()}.
     *
     * @param wbGene gene identifier
     * @return two boundaries per exon
     */
    public List<AnnotationBoundary> getExonBoundaries(String wbGene) {
        List<AnnotationBoundary> ret = new ArrayList<>();
        for (Annotation exon : this.getExons(wbGene)) {
            ret.add(new AnnotationStart(exon.start));
            ret.add(new AnnotationEnd(exon.end));
        }
        ret.sort(AnnotationBoundary.getComparator());
        return ret;
    }

    /**
     * Prints every annotation, in insertion order, using {@code Annotation.print}.
     *
     * @param stream destination stream (not closed)
     */
    public void save(PrintStream stream) {
        for (Annotation annot : all) {
            annot.print(stream);
        }
    }

    /**
     * Writes one BED record per transcript, built by {@link #transcriptBedRecords}.
     *
     * <p>Only genes whose {@code gene_biotype} attribute is in {@code bioTypes} are written.
     * Genes are ordered by chromosome name, then by gene start. Genes sharing a start are
     * written in gene-identifier order, all of them (none are dropped). Gene identifiers
     * without a {@code "gene"} annotation are skipped.
     *
     * @param bioTypes gene biotypes to include
     * @param bedFile output file; overwritten
     * @throws Exception if the file cannot be opened for writing
     */
    public void toBedFile(Set<String> bioTypes, File bedFile) throws Exception {
        // chromosome -> gene start -> identifiers of the selected genes starting there
        TreeMap<String, TreeMap<Integer, List<String>>> sortedGenes = new TreeMap<>();
        for (String wbGene : this.getAllGenes()) {
            Annotation geneAnnot = this.getGeneAnnotation(wbGene);
            if (geneAnnot == null) {
                continue;
            }
            String bioType = (String) geneAnnot.getAttributeValue("gene_biotype");
            // explicit null check: some Set implementations (e.g. TreeSet) reject contains(null)
            if (bioType == null || !bioTypes.contains(bioType)) {
                continue;
            }
            sortedGenes.computeIfAbsent(geneAnnot.getChromosome(), chromo -> new TreeMap<>())
                    .computeIfAbsent(geneAnnot.getStart(), start -> new ArrayList<>())
                    .add(wbGene);
        }

        try (PrintStream stream = new PrintStream(bedFile)) {
            for (TreeMap<Integer, List<String>> chromoGenes : sortedGenes.values()) {
                for (List<String> genesAtStart : chromoGenes.values()) {
                    for (String wbGene : genesAtStart) {
                        // look up by the same key that yielded the gene annotation above
                        for (ExtendedBed bed : transcriptBedRecords(this.getGeneAnnotations(wbGene))) {
                            stream.println(bed.toString());
                        }
                    }
                }
            }
        }
    }

    /**
     * Builds one BED record per transcript from a gene's annotations.
     *
     * <p>Annotations are consumed in list order:
     * <ul>
     *   <li>a {@code gene} annotation supplies the {@code gene_name} stored in the
     *       {@code "GeneName"} field of later records;</li>
     *   <li>each {@code transcript} annotation starts a new record, with chromStart = start - 1;</li>
     *   <li>{@code exon}, {@code start_codon} and {@code stop_codon} annotations are applied to
     *       the most recently started record.</li>
     * </ul>
     * Other annotation types are ignored.
     *
     * @param annots annotations of one gene, each transcript listed before its exons/codons
     * @return one record per transcript annotation, in encounter order
     * @throws IllegalStateException if an exon or codon annotation precedes every transcript
     */
    public List<ExtendedBed> transcriptBedRecords(List<Annotation> annots) {
        String geneName = null;
        ExtendedBed rec = null;
        Annotation transcriptAnnot = null;
        ArrayList<ExtendedBed> recs = new ArrayList<>();

        for (Annotation annot : annots) {
            if (annot.getType().equals("gene")) {
                geneName = (String) annot.getAttributeValue("gene_name");
            } else if (annot.getType().equals("transcript")) {
                // start a new bed record
                transcriptAnnot = annot;
                rec = new ExtendedBed(annot.getChromosome(), annot.getStart() - 1, annot.getEnd(), annot.getTranscript(), 0, annot.getStrand());
                rec.setFieldValue("GeneName", geneName);
                recs.add(rec);
            } else if (annot.getType().equals("exon")) {
                requireTranscript(rec, annot);
                // size counts both end positions (1 + end - start); offset is relative to the transcript start
                rec.addBlock(1 + annot.getEnd() - annot.getStart(), annot.getStart() - transcriptAnnot.getStart());
            } else if (annot.getType().equals("start_codon")) {
                requireTranscript(rec, annot);
                // NOTE(review): unlike chromStart, no -1 is applied here, and start/stop codon roles are
                // not swapped for minus-strand transcripts; confirm against ExtendedBed's coordinate convention.
                rec.setThickStart(annot.getStart());
            } else if (annot.getType().equals("stop_codon")) {
                requireTranscript(rec, annot);
                rec.setThickEnd(annot.getEnd());
            }
        }
        return recs;
    }

    // Fails with a descriptive message when an exon/codon annotation has no enclosing transcript record.
    private static void requireTranscript(ExtendedBed rec, Annotation annot) {
        if (rec == null) {
            throw new IllegalStateException(annot.getType() + " annotation at " + annot.getChromosome() + ":"
                    + annot.getStart() + " precedes any transcript annotation");
        }
    }

    /**
     * Ad-hoc debugging entry point. It loads a hard-coded GFF3 file and looks up the transcripts
     * of the gene named {@code apl-1} via {@code geneNameTriplet}.
     */
    public static void main(String[] args) throws Exception {
        ModelFromGFF gff3 = new ModelFromGFF(new File("/net/waterston/vol2/home/gevirl/Downloads/c_elegans.PRJNA13758.WS264.annotations.wormabse.gff3"));
        String[] names = gff3.geneNameTriplet("apl-1");
        // kept in a local only so the result can be inspected in a debugger
        List<Annotation> annots = gff3.getTranscriptsForGene(names[2]);
    }
}
