/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.gatk.utils.downsampling;

import htsjdk.variant.variantcontext.Allele;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.StringTokenizer;
import org.apache.log4j.Logger;
import org.broadinstitute.gatk.utils.BaseUtils;
import org.broadinstitute.gatk.utils.MathUtils;
import org.broadinstitute.gatk.utils.collections.DefaultHashMap;
import org.broadinstitute.gatk.utils.exceptions.GATKException;
import org.broadinstitute.gatk.utils.exceptions.UserException;
import org.broadinstitute.gatk.utils.pileup.PileupElement;
import org.broadinstitute.gatk.utils.pileup.ReadBackedPileup;
import org.broadinstitute.gatk.utils.pileup.ReadBackedPileupImpl;
import org.broadinstitute.gatk.utils.sam.GATKSAMRecord;
import org.broadinstitute.gatk.utils.text.XReadLines;

/**
 * Static helpers that pick reads / pileup elements to discard so that the
 * remaining per-allele counts minimise the score computed by
 * {@link #scoreAlleleCounts(int[])}, plus a loader for per-sample
 * contamination-fraction files.
 */
public class AlleleBiasedDownsamplingUtils {
    /**
     * Number of per-base strata a pileup is split into. Base indices other than
     * {@code -1} returned by {@code BaseUtils.simpleBaseToBaseIndex} are used to
     * index an array of this size.
     */
    private static final int NUM_BASE_STRATA = 4;

    /**
     * Builds a downsampled copy of a base pileup.
     *
     * <ul>
     *   <li>{@code downsamplingFraction <= 0}: the input pileup is returned as-is.</li>
     *   <li>{@code downsamplingFraction >= 1}: an empty pileup at the same location is returned.</li>
     *   <li>otherwise: {@code (int)(stratifiedCount * downsamplingFraction)} removals are
     *       distributed over the base strata by {@link #runSmartDownsampling(int[], int)}
     *       and the victims inside each stratum are picked at random.</li>
     * </ul>
     *
     * Elements whose base index is {@code -1} are not stratified and are therefore never removed.
     *
     * @param pileup               the pileup to downsample
     * @param downsamplingFraction fraction of the stratified elements to remove
     * @return the resulting pileup (the same instance as {@code pileup} when nothing is to be removed
     *         because the fraction is non-positive)
     */
    public static ReadBackedPileup createAlleleBiasedBasePileup(ReadBackedPileup pileup, double downsamplingFraction) {
        if (downsamplingFraction <= 0.0) {
            return pileup;
        }
        if (downsamplingFraction >= 1.0) {
            return new ReadBackedPileupImpl(pileup.getLocation(), new ArrayList<PileupElement>());
        }

        // Stratify the pileup elements by base index.
        PileupElementList[] elementsByBase = new PileupElementList[NUM_BASE_STRATA];
        for (int base = 0; base < NUM_BASE_STRATA; ++base) {
            elementsByBase[base] = new PileupElementList();
        }
        for (PileupElement element : pileup) {
            int baseIndex = BaseUtils.simpleBaseToBaseIndex(element.getBase());
            if (baseIndex != -1) {
                elementsByBase[baseIndex].add(element);
            }
        }

        int[] alleleCounts = AlleleBiasedDownsamplingUtils.calculateAlleleCounts(elementsByBase);
        int totalAlleleCount = (int)MathUtils.sum(alleleCounts);
        int numReadsToRemove = (int)((double)totalAlleleCount * downsamplingFraction);
        int[] targetAlleleCounts = AlleleBiasedDownsamplingUtils.runSmartDownsampling(alleleCounts, numReadsToRemove);

        // Collect the elements to drop from every stratum that has to shrink.
        HashSet<PileupElement> readsToRemove = new HashSet<PileupElement>(numReadsToRemove);
        for (int base = 0; base < NUM_BASE_STRATA; ++base) {
            int excess = alleleCounts[base] - targetAlleleCounts[base];
            if (excess > 0) {
                readsToRemove.addAll(AlleleBiasedDownsamplingUtils.downsampleElements(elementsByBase[base], alleleCounts[base], excess));
            }
        }

        // Walk the original pileup so that the surviving elements keep their original order.
        ArrayList<PileupElement> readsToKeep = new ArrayList<PileupElement>(totalAlleleCount - numReadsToRemove);
        for (PileupElement element : pileup) {
            if (!readsToRemove.contains(element)) {
                readsToKeep.add(element);
            }
        }
        return new ReadBackedPileupImpl(pileup.getLocation(), readsToKeep);
    }

    /**
     * Returns the size of each stratum, in the same order as the input array.
     *
     * @param alleleStratifiedElements one element list per allele
     * @return a new array holding {@code alleleStratifiedElements[i].size()} at index {@code i}
     */
    private static int[] calculateAlleleCounts(PileupElementList[] alleleStratifiedElements) {
        int numStrata = alleleStratifiedElements.length;
        int[] alleleCounts = new int[numStrata];
        for (int stratum = 0; stratum < numStrata; ++stratum) {
            alleleCounts[stratum] = alleleStratifiedElements[stratum].size();
        }
        return alleleCounts;
    }

    /**
     * Scores a set of allele counts; {@link #runSmartDownsampling(int[], int)} prefers lower scores.
     *
     * With {@code max} the largest count, {@code second} the second largest and {@code rest}
     * the sum of all other counts, the score is
     * {@code min(max - second + rest, |second + rest|)}. For non-negative counts the first
     * term is zero when the two largest counts are equal and every other count is zero, and
     * the second term is zero when at most one count is non-zero.
     *
     * @param alleleCounts the counts to score; not modified
     * @return the score, or {@code 0} when fewer than two counts are given
     */
    private static int scoreAlleleCounts(int[] alleleCounts) {
        int numCounts = alleleCounts.length;
        if (numCounts < 2) {
            return 0;
        }
        // Sort a copy so that the caller's array keeps its allele order.
        int[] sortedCounts = (int[])alleleCounts.clone();
        Arrays.sort(sortedCounts);

        int maxCount = sortedCounts[numCounts - 1];
        int nextBestCount = sortedCounts[numCounts - 2];
        int remainderCount = 0;
        for (int i = 0; i < numCounts - 2; ++i) {
            remainderCount += sortedCounts[i];
        }

        int distanceFromTwoEqualAlleles = maxCount - nextBestCount + remainderCount;
        int distanceFromSingleAllele = Math.abs(nextBestCount + remainderCount);
        return Math.min(distanceFromTwoEqualAlleles, distanceFromSingleAllele);
    }

    /**
     * Chooses how many reads each allele should keep after removing reads.
     *
     * Every candidate either takes all {@code numReadsToRemove} reads from a single allele or
     * takes {@code numReadsToRemove / 2} (integer division) from each of two different alleles;
     * counts are clamped at zero. The candidate with the lowest
     * {@link #scoreAlleleCounts(int[])} score wins, and a candidate only replaces the current
     * best when its score is strictly lower, so the earliest candidate wins ties.
     *
     * @param alleleCounts     current per-allele counts; not modified
     * @param numReadsToRemove total number of reads to remove
     * @return the target per-allele counts; this is the {@code alleleCounts} array itself when no
     *         candidate scores strictly lower than the input
     */
    protected static int[] runSmartDownsampling(int[] alleleCounts, int numReadsToRemove) {
        int numAlleles = alleleCounts.length;
        int bestScore = AlleleBiasedDownsamplingUtils.scoreAlleleCounts(alleleCounts);
        int[] bestCounts = alleleCounts;
        int numReadsToRemovePerAllele = numReadsToRemove / 2;

        for (int first = 0; first < numAlleles; ++first) {
            for (int second = first; second < numAlleles; ++second) {
                int[] candidate = (int[])alleleCounts.clone();
                if (first == second) {
                    // Single-allele case uses the full amount so nothing is lost to the division by two.
                    candidate[first] = Math.max(0, candidate[first] - numReadsToRemove);
                } else {
                    candidate[first] = Math.max(0, candidate[first] - numReadsToRemovePerAllele);
                    candidate[second] = Math.max(0, candidate[second] - numReadsToRemovePerAllele);
                }
                int candidateScore = AlleleBiasedDownsamplingUtils.scoreAlleleCounts(candidate);
                if (candidateScore < bestScore) {
                    bestScore = candidateScore;
                    bestCounts = candidate;
                }
            }
        }
        return bestCounts;
    }

    /**
     * Randomly picks pileup elements to remove.
     *
     * @param elements             the candidate elements
     * @param originalElementCount number of candidates to sample from
     * @param numElementsToRemove  how many elements to pick
     * @return an empty list when {@code numElementsToRemove} is zero, every element when
     *         {@code numElementsToRemove >= originalElementCount}, otherwise the randomly
     *         selected elements in their original iteration order
     */
    protected static List<PileupElement> downsampleElements(List<PileupElement> elements, int originalElementCount, int numElementsToRemove) {
        return AlleleBiasedDownsamplingUtils.sampleElementsToRemove(elements, originalElementCount, numElementsToRemove);
    }

    /**
     * Selects the reads that should be removed from a per-allele read map.
     *
     * The number of removals is {@code (int)(totalReads * downsamplingFraction)}, where
     * {@code totalReads} counts the reads of every key in the map. {@code Allele.NO_CALL} is
     * excluded from the alleles that reads may be removed from.
     *
     * @param alleleReadMap        reads grouped by allele
     * @param downsamplingFraction fraction of the reads to remove
     * @return the reads to remove; empty when the computed number of removals is not positive
     */
    public static <A extends Allele> List<GATKSAMRecord> selectAlleleBiasedReads(Map<A, List<GATKSAMRecord>> alleleReadMap, double downsamplingFraction) {
        int totalReads = 0;
        for (List<GATKSAMRecord> reads : alleleReadMap.values()) {
            totalReads += reads.size();
        }
        int numReadsToRemove = (int)((double)totalReads * downsamplingFraction);
        if (numReadsToRemove <= 0) {
            // Nothing to remove. A negative value would otherwise be used as a list capacity and throw.
            return new ArrayList<GATKSAMRecord>();
        }

        ArrayList<A> alleles = new ArrayList<A>(alleleReadMap.keySet());
        alleles.remove(Allele.NO_CALL);
        int numAlleles = alleles.size();

        // Look every allele's read list up once and remember it next to its count.
        List<List<GATKSAMRecord>> readsPerAllele = new ArrayList<List<GATKSAMRecord>>(numAlleles);
        int[] alleleCounts = new int[numAlleles];
        for (int i = 0; i < numAlleles; ++i) {
            List<GATKSAMRecord> alleleReads = alleleReadMap.get(alleles.get(i));
            readsPerAllele.add(alleleReads);
            alleleCounts[i] = alleleReads.size();
        }

        int[] targetAlleleCounts = AlleleBiasedDownsamplingUtils.runSmartDownsampling(alleleCounts, numReadsToRemove);
        ArrayList<GATKSAMRecord> readsToRemove = new ArrayList<GATKSAMRecord>(numReadsToRemove);
        for (int i = 0; i < numAlleles; ++i) {
            int excess = alleleCounts[i] - targetAlleleCounts[i];
            if (excess > 0) {
                readsToRemove.addAll(AlleleBiasedDownsamplingUtils.downsampleElements(readsPerAllele.get(i), excess));
            }
        }
        return readsToRemove;
    }

    /**
     * Randomly picks reads to remove.
     *
     * @param reads               the candidate reads
     * @param numElementsToRemove how many reads to pick
     * @return an empty list when {@code numElementsToRemove} is zero, every read when
     *         {@code numElementsToRemove >= reads.size()}, otherwise the randomly selected
     *         reads in their original iteration order
     */
    protected static List<GATKSAMRecord> downsampleElements(List<GATKSAMRecord> reads, int numElementsToRemove) {
        if (numElementsToRemove == 0) {
            return Collections.emptyList();
        }
        return AlleleBiasedDownsamplingUtils.sampleElementsToRemove(reads, reads.size(), numElementsToRemove);
    }

    /**
     * Shared implementation behind both {@code downsampleElements} overloads.
     *
     * @param elements             the candidate elements
     * @param originalElementCount number of candidates to sample indices from
     * @param numElementsToRemove  how many elements to pick
     * @return the picked elements, in the iteration order of {@code elements}
     */
    private static <T> List<T> sampleElementsToRemove(List<T> elements, int originalElementCount, int numElementsToRemove) {
        if (numElementsToRemove == 0) {
            return Collections.emptyList();
        }
        ArrayList<T> elementsToRemove = new ArrayList<T>(numElementsToRemove);
        if (numElementsToRemove >= originalElementCount) {
            elementsToRemove.addAll(elements);
            return elementsToRemove;
        }

        // Mark the sampled positions, then sweep the list once so the output keeps list order.
        BitSet selectedPositions = new BitSet(originalElementCount);
        for (Integer selectedIndex : MathUtils.sampleIndicesWithoutReplacement(originalElementCount, numElementsToRemove)) {
            selectedPositions.set(selectedIndex);
        }
        int position = 0;
        for (T element : elements) {
            if (selectedPositions.get(position)) {
                elementsToRemove.add(element);
            }
            ++position;
        }
        return elementsToRemove;
    }

    /**
     * Reads a two-column, tab-delimited file of {@code name<TAB>contaminationFraction} lines.
     *
     * Empty lines are skipped. Names contained in {@code AvailableSampleIDs} (or all names when
     * that set is {@code null}) are stored in the returned map; other names are ignored and
     * reported through {@code logger}.
     *
     * @param ContaminationFractionFile    the file to read
     * @param defaultContaminationFraction default value handed to the returned {@code DefaultHashMap}
     * @param AvailableSampleIDs           the known sample names, or {@code null} to accept every name
     * @param logger                       receives informational messages about what was loaded
     * @return the per-sample contamination fractions
     * @throws UserException.MalformedFile if a line does not have exactly two columns, repeats a
     *         name that was already stored, or has a second column that is not a number in
     *         {@code [0, 1]}
     * @throws GATKException if the file cannot be read
     */
    public static DefaultHashMap<String, Double> loadContaminationFile(File ContaminationFractionFile, Double defaultContaminationFraction, Set<String> AvailableSampleIDs, Logger logger) throws GATKException {
        DefaultHashMap<String, Double> sampleContamination = new DefaultHashMap<String, Double>(defaultContaminationFraction);
        HashSet<String> nonSamplesInContaminationFile = new HashSet<String>(sampleContamination.keySet());
        try {
            XReadLines reader = new XReadLines(ContaminationFractionFile, true);
            try {
                for (String line : reader) {
                    if (line.length() == 0) {
                        continue;
                    }
                    String[] fields = AlleleBiasedDownsamplingUtils.splitContaminationLine(line);
                    String sampleName = fields[0];
                    if (sampleContamination.containsKey(sampleName)) {
                        throw new UserException.MalformedFile("Contamination file contains duplicate entries for input name " + sampleName);
                    }
                    Double contamination = AlleleBiasedDownsamplingUtils.parseContaminationFraction(fields[1], line);
                    if (AvailableSampleIDs == null || AvailableSampleIDs.contains(sampleName)) {
                        sampleContamination.put(sampleName, contamination);
                    } else {
                        nonSamplesInContaminationFile.add(sampleName);
                    }
                }
            } finally {
                // Close explicitly so the file is released even when a malformed line aborts the loop.
                reader.close();
            }

            if (sampleContamination.size() > 0) {
                logger.info(String.format("The following samples were found in the Contamination file and will be processed at the contamination level therein: %s", sampleContamination.keySet().toString()));
                if (AvailableSampleIDs != null) {
                    HashSet<String> samplesNotInContaminationFile = new HashSet<String>(AvailableSampleIDs);
                    samplesNotInContaminationFile.removeAll(sampleContamination.keySet());
                    if (samplesNotInContaminationFile.size() > 0) {
                        logger.info(String.format("The following samples were NOT found in the Contamination file and will be processed at the default contamination level: %s", samplesNotInContaminationFile.toString()));
                    }
                }
            }
            if (nonSamplesInContaminationFile.size() > 0) {
                logger.info(String.format("The following entries were found in the Contamination file but were not SAMPLEIDs. They will be ignored: %s", nonSamplesInContaminationFile.toString()));
            }
            return sampleContamination;
        }
        catch (IOException e) {
            // Keep the original exception as the cause so the stack trace is not lost.
            throw new GATKException("I/O Error while reading sample-contamination file " + ContaminationFractionFile.getName() + ": " + e.getMessage(), e);
        }
    }

    /**
     * Splits one contamination-file line into its two tab-delimited columns.
     *
     * @param line a non-empty line of the contamination file
     * @return a two-element array: name, contamination text
     * @throws UserException.MalformedFile if the line does not have exactly two columns
     */
    private static String[] splitContaminationLine(String line) {
        StringTokenizer tokenizer = new StringTokenizer(line, "\t");
        String[] fields = new String[2];
        try {
            fields[0] = tokenizer.nextToken();
            fields[1] = tokenizer.nextToken();
        }
        catch (NoSuchElementException e) {
            throw new UserException.MalformedFile("Contamination file must have exactly two, tab-delimited columns. Offending line:\n" + line);
        }
        if (tokenizer.hasMoreTokens()) {
            throw new UserException.MalformedFile("Contamination file must have exactly two, tab-delimited columns. Offending line:\n" + line);
        }
        // StringTokenizer does not hand out empty tokens, so this check is purely defensive.
        if (fields[0].length() == 0 || fields[1].length() == 0) {
            throw new UserException.MalformedFile("Contamination file can not have empty strings in either column. Offending line:\n" + line);
        }
        return fields;
    }

    /**
     * Parses and range-checks the contamination column of one line.
     *
     * @param text the second column of the line
     * @param line the complete line, used in error messages only
     * @return the parsed contamination fraction, guaranteed to lie in {@code [0, 1]}
     * @throws UserException.MalformedFile if {@code text} is not a number or lies outside {@code [0, 1]}
     */
    private static Double parseContaminationFraction(String text, String line) {
        Double contamination;
        try {
            contamination = Double.valueOf(text);
        }
        catch (NumberFormatException e) {
            throw new UserException.MalformedFile("Contamination file contains unparsable double in the second field. Offending line: " + line);
        }
        double value = contamination.doubleValue();
        // Written as a negated "inside the range" test so that NaN is rejected as well.
        if (!(value >= 0.0 && value <= 1.0)) {
            throw new UserException.MalformedFile("Contamination file contains unacceptable contamination value (must be 0<=x<=1): " + line);
        }
        return contamination;
    }

    /**
     * Concrete, non-generic list type so that arrays of per-allele element lists can be
     * created without generic-array warnings.
     */
    private static final class PileupElementList
    extends ArrayList<PileupElement> {
        private PileupElementList() {
        }
    }
}

