/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.modelselection.splitters;

import com.datumbox.framework.common.dataobjects.FlatDataList;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.utilities.PHPMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelselection.AbstractSplitter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Random;

/**
 * Splitter that produces {@code splits} independent random train/test partitions of a
 * {@link Dataframe}.
 *
 * <p>For every split the record ids of the dataset are shuffled with the inherited
 * {@code random} source. The first {@code (int)(n * proportion)} ids form the training
 * subset and the remaining ids form the test subset. Splits are drawn independently, so a
 * record may appear in the test subset of several splits.
 */
public class ShuffleSplitter
extends AbstractSplitter {
    /** Fraction of records assigned to the training subset; must lie strictly in (0, 1). */
    private final double proportion;
    /** Number of splits the returned iterator yields. */
    private final int splits;

    /**
     * Creates a splitter that uses the random source chosen by the no-arg
     * {@link AbstractSplitter} constructor.
     *
     * @param proportion fraction of records placed in the training subset (validated in {@link #split})
     * @param splits number of splits to produce
     */
    public ShuffleSplitter(double proportion, int splits) {
        this.proportion = proportion;
        this.splits = splits;
    }

    /**
     * Creates a splitter that shuffles with the supplied random source.
     *
     * @param proportion fraction of records placed in the training subset (validated in {@link #split})
     * @param splits number of splits to produce
     * @param random random source passed to {@link AbstractSplitter}
     */
    public ShuffleSplitter(double proportion, int splits, Random random) {
        super(random);
        this.proportion = proportion;
        this.splits = splits;
    }

    /**
     * Returns a lazy iterator over {@code splits} random train/test partitions of the dataset.
     *
     * <p>The training subset size is {@code (int)(dataset.size() * proportion)}, truncated
     * toward zero. The test subset contains every remaining record.
     *
     * @param dataset the data to partition
     * @return an iterator yielding exactly {@code splits} splits (none if {@code splits <= 0})
     * @throws IllegalArgumentException if {@code proportion} is not strictly between 0.0 and 1.0
     *         (including NaN)
     */
    @Override
    public Iterator<AbstractSplitter.Split> split(final Dataframe dataset) {
        final int n = dataset.size();
        // Negated conjunction so that NaN (for which every comparison is false) is rejected too.
        if (!(this.proportion > 0.0 && this.proportion < 1.0)) {
            throw new IllegalArgumentException("The train size should be between 0.0 and 1.0.");
        }
        final int trainSize = (int) (n * this.proportion);
        return new Iterator<AbstractSplitter.Split>() {
            /** Number of splits produced so far. */
            private int counter = 0;

            @Override
            public boolean hasNext() {
                return this.counter < ShuffleSplitter.this.splits;
            }

            @Override
            public AbstractSplitter.Split next() {
                // Iterator contract: signal exhaustion instead of producing extra splits.
                if (!hasNext()) {
                    throw new NoSuchElementException(
                            "All " + ShuffleSplitter.this.splits + " splits have already been produced.");
                }
                ShuffleSplitter.this.logger.info("Split {}", this.counter);

                // Collect the ids in index order, then shuffle them in place with the shared RNG.
                Integer[] ids = new Integer[n];
                int j = 0;
                for (Integer rId : dataset.index()) {
                    ids[j++] = rId;
                }
                PHPMethods.shuffle(ids, ShuffleSplitter.this.random);

                // The prefix becomes the training subset and the suffix the test subset.
                FlatDataList trainIds = new FlatDataList(new ArrayList<>(trainSize));
                for (int i = 0; i < trainSize; ++i) {
                    trainIds.add(ids[i]);
                }
                FlatDataList testIds = new FlatDataList(new ArrayList<>(n - trainSize));
                for (int i = trainSize; i < n; ++i) {
                    testIds.add(ids[i]);
                }

                ++this.counter;
                return new AbstractSplitter.Split(dataset.getSubset(trainIds), dataset.getSubset(testIds));
            }
        };
    }
}

