/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.rhwlab.alignment;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMFileWriter;
import htsjdk.samtools.SAMFileWriterFactory;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMRecord.SAMTagAndValue;
import htsjdk.samtools.SAMRecordIterator;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.SamReader;
import htsjdk.samtools.SamReaderFactory;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

/**
 * Wrapper around a single SAM/BAM alignment file providing filtering, merging,
 * read-counting and read-extraction utilities built on htsjdk.
 *
 * <p>Several methods compare a record's reference name against a set of names. When the
 * reference name contains a ':' only the segment between the first and second ':' is
 * compared (see {@link #stripReferencePrefix(String)}); {@code keepReads} compares the full name.
 *
 * <p>The counting methods store their results in the mutable {@code matchedReads} and
 * {@code allReads} fields, which are read back through the getters.
 *
 * @author gevirl
 */
public class BAM {

    // file listing histone sequence names, one per line (used by getHistoneNames)
    private static final String HISTONE_NAMES_FILE = "/net/waterston/vol9/References/WS245/HistoneSequenceNames";

    File bam;          // the SAM/BAM file wrapped by this object
    int matchedReads;  // the count of distinct reads that matched the reference names (last countReads call)
    int allReads;      // the count of distinct reads seen (last countReads call)

    public BAM(File bam) {
        this.bam = bam;
    }

    /**
     * Writes {@code bam} to {@code outBam} without the alignments to the species' rRNA references.
     * Removed alignments are written to {@code <outBam without ".bam">.culled.bam}.
     *
     * @param species species code passed to {@link #getRibosomalNames(String)}
     * @return a BAM wrapping {@code outBam}
     */
    static public BAM filterRiboZero(String species, BAM bam, String outBam) throws Exception {
        Set<String> refs = BAM.getRibosomalNames(species);
        return BAM.filterMerge(new BAM[]{bam}, refs, outBam);
    }

    /**
     * Like {@link #filterRiboZero} but also removes alignments to histone references.
     *
     * @return a BAM wrapping {@code outBam}
     */
    static public BAM filterPolyA(String species, BAM bam, String outBam) throws Exception {
        Set<String> refs = BAM.getRibosomalNames(species);
        refs.addAll(BAM.getHistoneNames(species));
        return BAM.filterMerge(new BAM[]{bam}, refs, outBam);
    }

    /**
     * Merges an array of BAMs into {@code outBAM} without removing anything.
     *
     * @return a BAM wrapping {@code outBAM}
     */
    static public BAM merge(BAM[] bams, String outBAM) throws Exception {
        return filterMerge(bams, new TreeSet<>(), outBAM);
    }

    /**
     * Merges the BAMs into {@code outBAM}, dropping alignments whose (prefix-stripped) reference
     * name is in {@code referenceNames}. Dropped alignments are written to
     * {@code <outBAM without ".bam">.culled.bam}. The header of the first input is used for both outputs.
     *
     * @param bams           input files; must contain at least one element
     * @param referenceNames reference names whose alignments are removed
     * @param outBAM         output BAM path
     * @return a BAM wrapping {@code outBAM}
     * @throws IllegalArgumentException if {@code bams} is null or empty
     */
    static public BAM filterMerge(BAM[] bams, Set<?> referenceNames, String outBAM) throws Exception {
        if (bams == null || bams.length == 0) {
            throw new IllegalArgumentException("filterMerge requires at least one input BAM");
        }
        SAMFileHeader header;
        try (SamReader samReader = SamReaderFactory.makeDefault().open(bams[0].bam)) {
            header = samReader.getFileHeader();  // use the first bams header for the output header
        }

        // A single input is already in its header's sort order. Records from several inputs are
        // concatenated, so they must not be declared presorted: htsjdk rejects out-of-order records
        // for a presorted coordinate-sorted writer, and sorts them itself otherwise.
        boolean presorted = bams.length == 1;

        SAMFileWriterFactory factory = new SAMFileWriterFactory();
        File outBAMFile = new File(outBAM);
        String outName = stripSuffix(outBAM, ".bam");
        try (SAMFileWriter writer = factory.makeBAMWriter(header, presorted, outBAMFile);
                SAMFileWriter removed = factory.makeBAMWriter(header, presorted, new File(String.format("%s.culled.bam", outName)))) {
            for (BAM bamFile : bams) {
                bamFile.removeReads(referenceNames, writer, removed);
            }
        }
        return new BAM(outBAMFile);
    }

    /**
     * Copies the alignments of this file to {@code writer}, except those whose (prefix-stripped)
     * reference name is in {@code referenceNames}.
     *
     * @param referenceNames reference names whose alignments are removed
     * @param writer         receives the alignments that are kept
     * @param removed        receives the removed alignments; may be null to discard them
     */
    public void removeReads(Set<?> referenceNames, SAMFileWriter writer, SAMFileWriter removed) throws Exception {
        try (SamReader samReader = SamReaderFactory.makeDefault().open(bam)) {
            for (SAMRecord record : samReader) {
                String ref = stripReferencePrefix(record.getReferenceName());
                if (!referenceNames.contains(ref)) {
                    writer.addAlignment(record);
                } else if (removed != null) {
                    removed.addAlignment(record);
                }
            }
        }
    }

    /**
     * Reads a FASTA file and adds one {@link SAMSequenceRecord} per sequence to {@code dict}.
     * The sequence name is the whole header line after '>'. The length is the total number of
     * characters on the following non-header lines. Blank lines, and any lines before the first
     * header, are ignored.
     *
     * @param fasta path of the FASTA file
     * @param dict  dictionary the sequence records are appended to
     * @return the sequence names, sorted
     */
    static public TreeSet<String> formSAMDictionary(String fasta, SAMSequenceDictionary dict) throws Exception {
        TreeSet<String> refNames = new TreeSet<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(fasta))) {
            String name = null;  // null until the first header line is seen
            int len = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    continue;
                }
                if (line.charAt(0) == '>') {
                    if (name != null) {
                        dict.addSequence(new SAMSequenceRecord(name, len));
                        refNames.add(name);
                    }
                    name = line.substring(1);
                    len = 0;
                } else {
                    len += line.length();
                }
            }
            if (name != null) {
                dict.addSequence(new SAMSequenceRecord(name, len));
                refNames.add(name);
            }
        }
        return refNames;
    }

    /**
     * Makes a new SAM file containing the alignments that match transcripts in the given
     * FASTA transcriptome. See {@link #keepReads(SAMSequenceDictionary, Set, String, Set)}.
     *
     * @return {distinct kept read names, distinct kept read names aligned to countReferences}
     */
    public int[] keepReads(File xomeFasta, String outBAM, Set<String> countReferences) throws Exception {
        SAMSequenceDictionary dict = new SAMSequenceDictionary();
        // formSAMDictionary already returns every header name in the FASTA
        TreeSet<String> refNames = formSAMDictionary(xomeFasta.getPath(), dict);
        return keepReads(dict, refNames, outBAM, countReferences);
    }

    /**
     * Writes a SAM file (despite the parameter name, {@code makeSAMWriter} is used) containing the
     * alignments whose full reference name is in {@code referenceNames}. The output header is this
     * file's header with its sequence dictionary replaced by {@code dict}. All other alignments go to
     * {@code <outBAM without ".sam">.culled.sam}.
     *
     * @param countReferences references whose kept reads are counted separately; null counts none
     * @return {distinct kept read names, distinct kept read names aligned to countReferences}
     */
    public int[] keepReads(SAMSequenceDictionary dict, Set<?> referenceNames, String outBAM, Set<String> countReferences) throws Exception {
        TreeSet<String> readNames = new TreeSet<>();
        TreeSet<String> countNames = new TreeSet<>();

        try (SamReader samReader = SamReaderFactory.makeDefault().open(bam)) {
            SAMFileHeader header = samReader.getFileHeader().clone();
            header.setSequenceDictionary(dict);

            SAMFileWriterFactory factory = new SAMFileWriterFactory();
            String outName = stripSuffix(outBAM, ".sam");
            try (SAMFileWriter writer = factory.makeSAMWriter(header.clone(), false, new File(outBAM));
                    SAMFileWriter removed = factory.makeSAMWriter(header.clone(), false, new File(String.format("%s.culled.sam", outName)))) {
                for (SAMRecord record : samReader) {
                    String ref = record.getReferenceName();
                    if (!referenceNames.contains(ref)) {
                        removed.addAlignment(record);
                        continue;
                    }
                    writer.addAlignment(record);
                    String name = record.getReadName();
                    readNames.add(name);
                    if (countReferences != null && countReferences.contains(ref)) {
                        countNames.add(name);
                    }
                }
            }
        }
        return new int[]{readNames.size(), countNames.size()};
    }

    // Restricts a header's sequence dictionary to the sequences whose names are in referenceNames.
    private void matchHeader(SAMFileHeader header, Set<?> referenceNames) {
        SAMSequenceDictionary dict = header.getSequenceDictionary();
        List<SAMSequenceRecord> outList = new ArrayList<>();
        for (SAMSequenceRecord record : dict.getSequences()) {
            if (referenceNames.contains(record.getSequenceName())) {
                outList.add(record);
            }
        }
        dict.setSequences(outList);
    }

    /**
     * Counts distinct reads aligned to histone references.
     *
     * @return the number of distinct matching read names
     */
    public int countHistoneReads() throws Exception {
        countReads(getHistoneNames(null));
        return this.matchedReads;
    }

    /**
     * Reads the histone sequence names, one per line, from a fixed reference file.
     *
     * @param species currently ignored; the same list is returned for every species
     * @return a mutable sorted set of names
     */
    static public Set<String> getHistoneNames(String species) throws Exception {
        TreeSet<String> ret = new TreeSet<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(HISTONE_NAMES_FILE))) {
            String line;
            while ((line = reader.readLine()) != null) {
                ret.add(line);
            }
        }
        return ret;
    }

    /**
     * Returns the rRNA reference names for a species: the fly list for "dmel" (case-insensitive),
     * otherwise (including null) the worm list.
     *
     * @return a mutable sorted set of names
     */
    static public Set<String> getRibosomalNames(String species) {
        String[] rRNAs = "dmel".equalsIgnoreCase(species) ? rRNAfly : rRNAworm;
        TreeSet<String> ret = new TreeSet<>();
        for (String rRNA : rRNAs) {
            ret.add(rRNA);
        }
        return ret;
    }

    /**
     * Counts the distinct read names in the file. This also resets {@code matchedReads} to 0.
     *
     * @return the number of distinct read names
     */
    public int countReads() throws Exception {
        this.countReads(new TreeSet<>());
        return this.allReads;
    }

    /**
     * Counts the reads that belong to a set of references. The counts are stored in
     * {@code matchedReads} (distinct read names with an alignment whose prefix-stripped reference is
     * in the set) and {@code allReads} (all distinct read names). Read names are trimmed.
     */
    public void countReads(Set<?> referenceNames) throws Exception {
        Set<String> match = new HashSet<>();
        Set<String> all = new HashSet<>();
        try (SamReader samReader = SamReaderFactory.makeDefault().open(bam)) {
            for (SAMRecord record : samReader) {
                String readName = record.getReadName().trim();
                all.add(readName);
                if (referenceNames.contains(stripReferencePrefix(record.getReferenceName()))) {
                    match.add(readName);
                }
            }
        }
        this.matchedReads = match.size();
        this.allReads = all.size();
    }

    /**
     * Returns per-histone read counts normalized to one million reads.
     * See {@link #normalizedReadCounts(Set, int)}.
     */
    public TreeMap<String, Integer> histoneReadCounts() throws Exception {
        return normalizedReadCounts(getHistoneNames(null), 1000000);
    }

    /**
     * Counts reads per reference for the given references and normalizes the counts.
     *
     * <p>Only one read of a pair is considered: first-of-pair reads, and every read when the data
     * is unpaired. Each mapped alignment to a reference contributes 1/NH, so a multi-mapping read
     * contributes 1 in total across its alignments. A missing NH tag is treated as 1. Unmapped
     * records are not attributed to any reference. The per-reference sum is scaled by
     * {@code normalizeTo / (distinct considered read names)} and truncated to int.
     *
     * @param refs        references to report; every one appears in the result
     * @param normalizeTo target library size (e.g. 1,000,000 for per-million counts)
     * @return reference -> normalized count (0 for every reference if no reads were considered)
     */
    public TreeMap<String, Integer> normalizedReadCounts(Set<String> refs, int normalizeTo) throws Exception {
        // ref -> (NH value -> number of alignments with that NH)
        TreeMap<String, TreeMap<Integer, Integer>> counts = new TreeMap<>();
        for (String ref : refs) {
            counts.put(ref, new TreeMap<>());
        }
        Set<String> allreads = new HashSet<>();

        try (SamReader samReader = SamReaderFactory.makeDefault().open(bam)) {
            for (SAMRecord record : samReader) {
                // getFirstOfPairFlag() throws for unpaired reads, so check pairing first
                if (record.getReadPairedFlag() && !record.getFirstOfPairFlag()) {
                    continue;  // only counting one read of the pair
                }
                allreads.add(record.getReadName());
                if (record.getReadUnmappedFlag()) {
                    continue;  // an unmapped record may carry its mate's reference name and NH:0
                }
                TreeMap<Integer, Integer> nhMap = counts.get(stripReferencePrefix(record.getReferenceName()));
                if (nhMap != null) {
                    nhMap.merge(alignmentCount(record), 1, Integer::sum);
                }
            }
        }

        double f = allreads.isEmpty() ? 0.0 : (double) normalizeTo / (double) allreads.size();
        TreeMap<String, Integer> ret = new TreeMap<>();
        for (String ref : refs) {
            double sum = 0.0;
            for (Entry<Integer, Integer> e : counts.get(ref).entrySet()) {
                sum += (double) e.getValue() / (double) e.getKey();
            }
            ret.put(ref, (int) (f * sum));
        }
        return ret;
    }

    /**
     * Counts reads aligned to the species' rRNA references. The result is available from
     * {@link #getRibosomalReadCount()} and {@link #getTotalReadCount()}.
     */
    public void countRibosomalReads(String species) throws Exception {
        countReads(getRibosomalNames(species));
    }

    // matchedReads from the last countReads call
    public int getRibosomalReadCount() {
        return this.matchedReads;
    }

    // allReads from the last countReads call
    public int getTotalReadCount() {
        return this.allReads;
    }

    /**
     * Writes to {@code outBAM} every alignment whose read name is listed (one per line) in
     * {@code readNameFile}. The input header is reused and the output is declared presorted.
     */
    public void extractReads(String outBAM, String readNameFile) throws Exception {
        Set<String> readsToExtract = new HashSet<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(readNameFile))) {
            String line;
            while ((line = reader.readLine()) != null) {
                readsToExtract.add(line);
            }
        }

        try (SamReader samReader = SamReaderFactory.makeDefault().open(bam)) {
            SAMFileHeader header = samReader.getFileHeader();
            SAMFileWriterFactory factory = new SAMFileWriterFactory();
            try (SAMFileWriter writer = factory.makeBAMWriter(header, true, new File(outBAM))) {
                for (SAMRecord record : samReader) {
                    if (readsToExtract.contains(record.getReadName())) {
                        writer.addAlignment(record);
                    }
                }
            }
        }
    }

    /**
     * Removes multimapping reads from a BAM file. Only alignments with an NH tag equal to 1 are
     * written to {@code outBam}; alignments without an NH tag are dropped as well.
     *
     * @return the number of distinct read names included in the output BAM
     */
    static public int removeMultiMappers(File inBam, File outBam) throws Exception {
        HashSet<String> readNames = new HashSet<>();
        try (SamReader samReader = SamReaderFactory.makeDefault().open(inBam)) {
            SAMFileHeader header = samReader.getFileHeader();
            SAMFileWriterFactory factory = new SAMFileWriterFactory();
            try (SAMFileWriter writer = factory.makeBAMWriter(header, true, outBam)) {
                for (SAMRecord record : samReader) {
                    Integer nh = record.getIntegerAttribute("NH");
                    if (nh != null && nh == 1) {
                        writer.addAlignment(record);
                        readNames.add(record.getReadName());
                    }
                }
            }
        }
        return readNames.size();
    }

    // Returns the NH tag (number of reported alignments), treating a missing or non-positive value as 1.
    private static int alignmentCount(SAMRecord record) {
        Integer nh = record.getIntegerAttribute("NH");
        return (nh == null || nh < 1) ? 1 : nh;
    }

    // Returns the segment between the first and second ':' of a reference name when it exists,
    // otherwise the name unchanged (e.g. "x:F31C3.7" -> "F31C3.7", "abc:" -> "abc:").
    private static String stripReferencePrefix(String ref) {
        if (ref.indexOf(':') < 0) {
            return ref;
        }
        String[] parts = ref.split(":");
        return parts.length > 1 ? parts[1] : ref;
    }

    // Removes the last occurrence of suffix (and everything after it) from path; returns path
    // unchanged when suffix does not occur.
    private static String stripSuffix(String path, String suffix) {
        int idx = path.lastIndexOf(suffix);
        return idx < 0 ? path : path.substring(0, idx);
    }

    // Ad-hoc driver: removes multimappers from a hard-coded BAM and prints the unique read count.
    static public void main(String[] args) throws Exception {
        int n = removeMultiMappers(
                new File("/net/waterston/vol9/RNASeq/adwarner/hlh1original_DS1/nonhlh1T0_1_1/Merged_1/Aligned.toGenome.sorted.bam"),
                new File("/net/waterston/vol9/RNASeq/adwarner/hlh1original_DS1/nonhlh1T0_1_1/Merged_1/Aligned.toGenome.unique.bam"));
        System.out.println(n);
    }

    // C. elegans rRNA reference names (used by getRibosomalNames for every species except "dmel")
    static String[] rRNAworm = {"F31C3.7", "F31C3.8", "F31C3.9", "F31C3.11", "MTCE.7", "MTCE.33", "T09B4.23", "T27C5.18", "Y102A5D.5", "Y102A5D.6", "Y102A5D.7",
        "Y102A5D.8", "Y102A5D.9", "Y102A5D.10", "Y102A5D.11", "Y102A5D.12", "ZK218.12", "ZK218.16", "ZK218.17", "ZK218.18", "ZK218.19", "ZK218.20"};
    // D. melanogaster rRNA transcript ids (used by getRibosomalNames for "dmel")
    static String[] rRNAfly = {
        "FBtr0086345",
        "FBtr0086346",
        "FBtr0086347",
        "FBtr0086349",
        "FBtr0086350",
        "FBtr0086351",
        "FBtr0086352",
        "FBtr0086353",
        "FBtr0086354",
        "FBtr0086356",
        "FBtr0086357",
        "FBtr0086358",
        "FBtr0086359",
        "FBtr0086360",
        "FBtr0086361",
        "FBtr0086362",
        "FBtr0086364",
        "FBtr0086365",
        "FBtr0086366",
        "FBtr0086367",
        "FBtr0086368",
        "FBtr0086369",
        "FBtr0086370",
        "FBtr0086371",
        "FBtr0086372",
        "FBtr0086373",
        "FBtr0086374",
        "FBtr0086375",
        "FBtr0086376",
        "FBtr0086377",
        "FBtr0086378",
        "FBtr0086379",
        "FBtr0086380",
        "FBtr0086381",
        "FBtr0086382",
        "FBtr0086383",
        "FBtr0086384",
        "FBtr0086385",
        "FBtr0086386",
        "FBtr0086387",
        "FBtr0086388",
        "FBtr0086389",
        "FBtr0086390",
        "FBtr0086391",
        "FBtr0086392",
        "FBtr0086393",
        "FBtr0086394",
        "FBtr0086395",
        "FBtr0086396",
        "FBtr0086397",
        "FBtr0086398",
        "FBtr0086399",
        "FBtr0086400",
        "FBtr0086401",
        "FBtr0086402",
        "FBtr0086403",
        "FBtr0086404",
        "FBtr0086405",
        "FBtr0086406",
        "FBtr0086407",
        "FBtr0086409",
        "FBtr0086410",
        "FBtr0086411",
        "FBtr0086412",
        "FBtr0086413",
        "FBtr0086414",
        "FBtr0086415",
        "FBtr0086416",
        "FBtr0086417",
        "FBtr0086418",
        "FBtr0086419",
        "FBtr0086420",
        "FBtr0086421",
        "FBtr0086422",
        "FBtr0086423",
        "FBtr0086424",
        "FBtr0086425",
        "FBtr0086426",
        "FBtr0086427",
        "FBtr0086428",
        "FBtr0086429",
        "FBtr0086430",
        "FBtr0086431",
        "FBtr0086432",
        "FBtr0086433",
        "FBtr0086434",
        "FBtr0086435",
        "FBtr0086436",
        "FBtr0086437",
        "FBtr0086438",
        "FBtr0086439",
        "FBtr0086440",
        "FBtr0086441",
        "FBtr0086442",
        "FBtr0086443",
        "FBtr0086444",
        "FBtr0114187",
        "FBtr0114194",
        "FBtr0114196",
        "FBtr0114198",
        "FBtr0114201",
        "FBtr0114202",
        "FBtr0114205",
        "FBtr0114206",
        "FBtr0114207",
        "FBtr0114208",
        "FBtr0114209",
        "FBtr0114210",
        "FBtr0114211",
        "FBtr0114214",
        "FBtr0114216",
        "FBtr0114218",
        "FBtr0114222",
        "FBtr0114223",
        "FBtr0114228",
        "FBtr0114249",
        "FBtr0114253",
        "FBtr0346874",
        "FBtr0114257",
        "FBtr0114259",
        "FBtr0114261",
        "FBtr0114262",
        "FBtr0114274",
        "FBtr0114275",
        "FBtr0114280",
        "FBtr0114283",
        "FBtr0114284",
        "FBtr0114285",
        "FBtr0114286",
        "FBtr0346901",
        "FBtr0100888",
        "FBtr0100890",
        "FBtr0343433",
        "FBtr0346875",
        "FBtr0346876",
        "FBtr0346878",
        "FBtr0346879",
        "FBtr0346880",
        "FBtr0346882",
        "FBtr0346883",
        "FBtr0346884",
        "FBtr0346885",
        "FBtr0346873",
        "FBtr0346877",
        "FBtr0346881",
        "FBtr0346887",
        "FBtr0346898"};
}
