/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.alignment;

import htsjdk.samtools.Cigar;
import htsjdk.samtools.CigarElement;
import htsjdk.samtools.CigarOperator;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMFileWriter;
import htsjdk.samtools.SAMFileWriterFactory;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMRecordIterator;
import htsjdk.samtools.SamReader;
import htsjdk.samtools.SamReaderFactory;
import htsjdk.samtools.util.Interval;
import htsjdk.samtools.util.IntervalTree;
import htsjdk.samtools.util.IntervalTreeMap;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.TreeSet;
import org.rhwlab.gene.model.Annotation;
import org.rhwlab.gene.model.Exon;
import org.rhwlab.gene.model.ModelFromGFF;
import org.rhwlab.gene.model.exons.NonOverlapping;

/**
 * Reads a BAM file two records at a time, treating each consecutive pair of
 * records as the two mates of a paired-end alignment, and provides static
 * utilities that filter or count such pairs against exon annotations.
 *
 * <p>The pairing is purely positional: records are taken from the underlying
 * iterator in file order, two at a time. No check is made here that the two
 * records actually share a read name.
 *
 * @author gevirl
 */
public class PairedEndBAM extends BAM {

    SamReader samReader;
    SAMRecordIterator iter;
    SAMFileHeader header;
    // Not used anywhere in this class; kept because the field is package-visible.
    ArrayList<SAMRecord> records = new ArrayList<>();

    /**
     * Opens the given BAM file and positions an iterator at its first record.
     *
     * @param bam the BAM file to read
     */
    public PairedEndBAM(File bam) {
        super(bam);
        samReader = SamReaderFactory.makeDefault().open(bam);
        iter = samReader.iterator();
        header = samReader.getFileHeader();
    }

    /**
     * @return the header of the opened BAM file
     */
    public SAMFileHeader getSamHeader() {
        return this.header;
    }

    /**
     * Returns the next two consecutive records that are both flagged as paired
     * and as a proper pair. Consecutive record pairs failing that test are
     * skipped (both records are consumed).
     *
     * <p>End of input is detected with {@code hasNext()}; errors raised by the
     * underlying iterator while reading are allowed to propagate instead of
     * being reported as end of input.
     *
     * @return a two-element array holding the pair, or {@code null} when the
     *         input is exhausted (including when only a single trailing record
     *         remains)
     */
    public SAMRecord[] nextPair() {
        while (iter.hasNext()) {
            SAMRecord first = iter.next();
            if (!iter.hasNext()) {
                // a dangling final record cannot form a pair
                return null;
            }
            SAMRecord second = iter.next();
            if (first == null || second == null) {
                return null;
            }
            boolean bothPaired = first.getReadPairedFlag() && second.getReadPairedFlag();
            if (bothPaired && first.getProperPairFlag() && second.getProperPairFlag()) {
                return new SAMRecord[]{first, second};
            }
        }
        return null;
    }

    /**
     * Tests whether two records have the same read name, the same reference
     * name and the same value of the {@code HI} attribute. Two records that
     * both lack the attribute are considered to agree on it.
     */
    private boolean areMates(SAMRecord rec1, SAMRecord rec2) {
        if (!rec1.getReadName().equals(rec2.getReadName())) {
            return false;
        }
        if (!rec1.getReferenceName().equals(rec2.getReferenceName())) {
            return false;
        }
        // compare as objects so that a missing attribute cannot cause an unboxing NullPointerException
        Object hi1 = rec1.getAttribute("HI");
        Object hi2 = rec2.getAttribute("HI");
        return hi1 == null ? hi2 == null : hi1.equals(hi2);
    }

    /**
     * Copies to {@code outFile} every read pair of {@code file} in which all
     * aligned bases of both reads fall inside an annotation held in
     * {@code map}. Pairs failing the test are discarded.
     *
     * <p>The writer is created with the input header and the presorted flag
     * set. Both the writer and the input reader are closed before returning,
     * also when an error interrupts the copy.
     *
     * @param file    the input BAM file
     * @param map     annotations keyed by interval; each value is either an
     *                {@code Annotation} or a {@code List} of them
     * @param outFile path of the BAM file to write
     */
    static public void removeIntronicReads(File file, IntervalTreeMap map, String outFile) {
        PairedEndBAM inputBAM = new PairedEndBAM(file);
        try {
            SAMFileWriterFactory factory = new SAMFileWriterFactory();
            SAMFileWriter exonicWriter = factory.makeBAMWriter(inputBAM.getSamHeader(), true, new File(outFile));
            try {
                SAMRecord[] reads = inputBAM.nextPair();
                while (reads != null) {
                    if (isFullyExonic(reads[0], map) && isFullyExonic(reads[1], map)) {
                        exonicWriter.addAlignment(reads[0]);
                        exonicWriter.addAlignment(reads[1]);
                    }
                    reads = inputBAM.nextPair();
                }
            } finally {
                exonicWriter.close();
            }
        } finally {
            closeQuietly(inputBAM);
        }
    }

    /**
     * Returns true when every reference position covered by a CIGAR element
     * that consumes both read and reference bases lies inside an annotation
     * overlapping the read.
     */
    private static boolean isFullyExonic(SAMRecord read, IntervalTreeMap map) {
        String chromo = read.getReferenceName();
        Interval readInterval = new Interval(chromo, read.getAlignmentStart(), read.getAlignmentEnd());
        Collection<?> overlapping = map.getOverlapping(readInterval);  // all annotations overlapping the read

        int refPos = read.getAlignmentStart();
        for (CigarElement ce : read.getCigar().getCigarElements()) {
            CigarOperator op = ce.getOperator();
            if (!op.consumesReferenceBases()) {
                continue;
            }
            int len = ce.getLength();
            if (op.consumesReadBases()) {
                for (int i = 0; i < len; ++i) {
                    if (!isExonicPosition(overlapping, chromo, refPos + i)) {
                        return false;
                    }
                }
            }
            // elements that consume only the reference advance the position without being checked
            refPos += len;
        }
        return true;
    }

    /**
     * Returns true when at least one of the given annotations contains the
     * position. Each element of {@code overlapping} is either an
     * {@code Annotation} or a {@code List} of them.
     */
    private static boolean isExonicPosition(Collection<?> overlapping, String chromo, int pos) {
        for (Object obj : overlapping) {
            if (obj instanceof Annotation) {
                if (((Annotation) obj).contains(chromo, pos)) {
                    return true;
                }
            } else {
                for (Object item : (List<?>) obj) {
                    if (((Annotation) item).contains(chromo, pos)) {
                        return true;
                    }
                }
            }
        }
        return false;
    }

    /** Closes the reader, reporting rather than propagating a failure to close. */
    private static void closeQuietly(PairedEndBAM bam) {
        try {
            bam.close();
        } catch (Exception exc) {
            // the handle was only read from, so a failed close does not affect any output
            System.err.println("Failed to close BAM reader: " + exc);
        }
    }

    /**
     * Closes the underlying reader.
     *
     * @throws Exception if the reader fails to close
     */
    public void close() throws Exception {
        this.samReader.close();
    }

    /**
     * Counts alignment bases on non-overlapping exons: for each read pair that
     * can be assigned a gene, the {@code count} of an exon of that gene is
     * incremented once for every aligned base that falls on it.
     *
     * @param file the BAM file to read
     * @param gff  the non-overlapping exon model
     * @return the number of read pairs with at least one base on an exon
     * @throws Exception if the reader fails to close
     */
    public static int nonOverlappingExonCounts(File file, NonOverlapping gff) throws Exception {
        IntervalTreeMap map = gff.asTree("exon");
        PairedEndBAM bam = new PairedEndBAM(file);
        int n = 0;
        try {
            SAMRecord[] pair = bam.nextPair();
            while (pair != null) {
                Annotation gene = findGene(pair, map);
                if (gene != null) {
                    List<Exon> hits = alignedExonHits(pair, buildExonTree(gff.getExons(gene.getGeneID())));
                    for (Exon exon : hits) {
                        ++exon.count;
                    }
                    if (!hits.isEmpty()) {
                        ++n;
                    }
                }
                pair = bam.nextPair();
            }
        } finally {
            bam.close();
        }
        return n;
    }

    /**
     * Counts reads aligning to non-overlapping exons: for each read pair that
     * can be assigned a gene, the {@code count} of every distinct exon of that
     * gene touched by the pair is incremented once.
     *
     * @param file the BAM file to read
     * @param gff  the non-overlapping exon model
     * @return the number of read pairs touching at least one exon
     * @throws Exception if the reader fails to close
     */
    public static int nonOverlappingExonReadCounts(File file, NonOverlapping gff) throws Exception {
        IntervalTreeMap map = gff.asTree("exon");
        PairedEndBAM bam = new PairedEndBAM(file);
        int n = 0;
        try {
            SAMRecord[] pair = bam.nextPair();
            while (pair != null) {
                Annotation gene = findGene(pair, map);  // the gene the read pair is assigned to, if any
                if (gene != null) {
                    List<Exon> hits = alignedExonHits(pair, buildExonTree(gff.getExons(gene.getGeneID())));
                    // collapse the per-base hits to the distinct exons the pair overlaps
                    TreeSet<Exon> exonSet = new TreeSet<>(hits);
                    for (Exon exon : exonSet) {
                        ++exon.count;
                    }
                    if (!hits.isEmpty()) {
                        ++n;
                    }
                }
                pair = bam.nextPair();
            }
        } finally {
            bam.close();
        }
        return n;
    }

    /** Indexes the given exon annotations by their start and end coordinates. */
    private static IntervalTree<Annotation> buildExonTree(List<Annotation> exons) {
        IntervalTree<Annotation> tree = new IntervalTree<>();
        for (Annotation exon : exons) {
            tree.put(exon.getStart(), exon.getEnd(), exon);
        }
        return tree;
    }

    /**
     * Walks the CIGAR of both reads and returns one entry per aligned base
     * (an element consuming both read and reference bases) that falls on an
     * exon in the tree. The same exon therefore appears once for every base
     * that lands on it.
     */
    private static List<Exon> alignedExonHits(SAMRecord[] pair, IntervalTree<Annotation> tree) {
        List<Exon> hits = new ArrayList<>();
        for (SAMRecord read : pair) {
            int refPos = read.getAlignmentStart();
            for (CigarElement ce : read.getCigar().getCigarElements()) {
                CigarOperator op = ce.getOperator();
                if (!op.consumesReferenceBases()) {
                    continue;
                }
                int len = ce.getLength();
                if (op.consumesReadBases()) {
                    for (int i = 0; i < len; ++i) {
                        int pos = refPos + i;
                        IntervalTree.Node<Annotation> node = tree.minOverlapper(pos, pos);
                        if (node != null) {
                            hits.add((Exon) node.getValue());
                        }
                    }
                }
                refPos += len;
            }
        }
        return hits;
    }

    /**
     * Finds the gene to associate with the given alignment pair.
     *
     * <p>The reads are examined in order. For each read, every aligned CIGAR
     * element that overlaps exactly one entry of {@code map} contributes that
     * entry (when it is a single {@code Annotation}). The first read that
     * yields at least one such annotation decides the result: if all of its
     * annotations carry the same gene ID the first annotation is returned,
     * otherwise {@code null}. Later reads are consulted only when earlier ones
     * yield no annotation at all.
     *
     * @param pair the reads of the alignment
     * @param map  exon annotations keyed by interval
     * @return an annotation of the assigned gene, or {@code null} if no gene
     *         can be assigned unambiguously
     */
    static Annotation findGene(SAMRecord[] pair, IntervalTreeMap map) {
        for (SAMRecord read : pair) {
            String chromo = read.getReferenceName();
            int refPos = read.getAlignmentStart();
            ArrayList<Annotation> unique = new ArrayList<>();
            for (CigarElement ce : read.getCigar().getCigarElements()) {
                CigarOperator op = ce.getOperator();
                if (!op.consumesReferenceBases()) {
                    continue;
                }
                int len = ce.getLength();
                if (op.consumesReadBases()) {
                    Interval interval = new Interval(chromo, refPos, refPos + len - 1);
                    Collection<?> overlapping = map.getOverlapping(interval);
                    if (overlapping.size() == 1) {
                        Object obj = overlapping.iterator().next();
                        if (obj instanceof Annotation) {
                            unique.add((Annotation) obj);
                        }
                    }
                }
                refPos += len;
            }

            // make sure all the exons are from the same gene
            if (!unique.isEmpty()) {
                String gene = unique.get(0).getGeneID();
                for (int i = 1; i < unique.size(); ++i) {
                    // compare gene IDs; comparing the annotation object itself to the ID string never matches
                    String other = unique.get(i).getGeneID();
                    boolean sameGene = (gene == null) ? (other == null) : gene.equals(other);
                    if (!sameGene) {
                        return null;
                    }
                }
                return unique.get(0);
            }
        }
        return null;
    }

}
