/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransformProcess;
import jsat.exceptions.UntrainedModelException;
import jsat.math.OnLineStatistics;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.regression.evaluation.RegressionScore;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * Evaluates a {@link Regressor} on a {@link RegressionDataSet}, either by cross validation or
 * against a separate test set.
 *
 * <p>Each evaluation collects the following:
 * <ul>
 *   <li>statistics of the weighted squared error of every test prediction;</li>
 *   <li>the time spent training and predicting;</li>
 *   <li>one value per fold for every scorer registered with {@link #addScorer}.</li>
 * </ul>
 */
public class RegressionModelEvaluation {
    /** Template regressor; every evaluation trains a clone of it, never this instance. */
    private final Regressor regressor;
    /** Data set that is split for cross validation, or trained on by {@link #evaluateTestSet}. */
    private final RegressionDataSet dataSet;
    /** Passed to training and used to decide whether test predictions run in parallel. */
    private final boolean parallel;
    /** Weighted squared errors of all test predictions of the last evaluation; null until one runs. */
    private OnLineStatistics sqrdErrorStats;
    /** Elapsed training time of the last evaluation, in nanoseconds. */
    private long totalTrainingTimeNanos = 0L;
    /** Elapsed prediction time of the last evaluation (summed over all threads), in nanoseconds. */
    private long totalClassificationTimeNanos = 0L;
    /** Transform template; each fold learns a fresh clone of it from its own training data. */
    private DataTransformProcess dtp;
    /** Registered scorers (insertion ordered) and the per-fold statistics of their scores. */
    private final Map<RegressionScore, OnLineStatistics> scoreMap;
    private boolean keepModels = false;
    private Regressor[] keptModels;
    private Regressor[] warmModels;

    /**
     * Creates a new evaluator.
     *
     * @param regressor the regressor to evaluate; it is cloned before each training run
     * @param dataSet the data set used for cross validation, or as the training set in
     *     {@link #evaluateTestSet}
     * @param parallel whether training and prediction may use multiple threads
     */
    public RegressionModelEvaluation(Regressor regressor, RegressionDataSet dataSet, boolean parallel) {
        this.regressor = regressor;
        this.dataSet = dataSet;
        this.parallel = parallel;
        this.dtp = new DataTransformProcess();
        this.scoreMap = new LinkedHashMap<>();
    }

    /**
     * Creates a new single-threaded evaluator.
     *
     * @param regressor the regressor to evaluate
     * @param dataSet the data set used for evaluation
     */
    public RegressionModelEvaluation(Regressor regressor, RegressionDataSet dataSet) {
        this(regressor, dataSet, false);
    }

    /**
     * Sets whether the models trained by subsequent evaluations are retained.
     *
     * @param keepModels {@code true} to make the trained models available from
     *     {@link #getKeptModels()} after the next evaluation
     */
    public void setKeepModels(boolean keepModels) {
        this.keepModels = keepModels;
    }

    /**
     * Returns whether trained models will be retained by subsequent evaluations.
     *
     * @return {@code true} if models are kept
     */
    public boolean isKeepModels() {
        return this.keepModels;
    }

    /**
     * Returns the models trained by the last evaluation. Element {@code i} is the model trained
     * for fold {@code i}; {@link #evaluateTestSet} produces a single model.
     *
     * @return the trained models, or {@code null} if {@link #isKeepModels()} was {@code false}
     *     when the last evaluation started (or no evaluation has run)
     */
    public Regressor[] getKeptModels() {
        return this.keptModels;
    }

    /**
     * Supplies warm-start models. They are only used when the regressor is a {@link WarmRegressor};
     * in that case {@code warmModels[i]} is passed to the training of fold {@code i}.
     *
     * @param warmModels one warm-start model per fold
     */
    public void setWarmModels(Regressor... warmModels) {
        this.warmModels = warmModels;
    }

    /**
     * Sets the transform process to apply. A clone is stored. For every fold a further clone
     * learns its transforms from that fold's training data and is then applied to the test points.
     *
     * @param dtp the transform process to use
     */
    public void setDataTransformProcess(DataTransformProcess dtp) {
        this.dtp = dtp.clone();
    }

    /**
     * Performs cross validation on the data set, using the source of randomness from
     * {@link RandomUtil#getRandom()}.
     *
     * @param folds the number of folds, must be at least 2
     */
    public void evaluateCrossValidation(int folds) {
        this.evaluateCrossValidation(folds, RandomUtil.getRandom());
    }

    /**
     * Performs cross validation on the data set.
     *
     * @param folds the number of folds, must be at least 2
     * @param rand the source of randomness used to split the data set
     * @throws UntrainedModelException if {@code folds < 2}
     */
    public void evaluateCrossValidation(int folds, Random rand) {
        if (folds < 2) {
            throw new UntrainedModelException("Model could not be evaluated because " + folds + " is < 2, and not valid for cross validation");
        }
        List<RegressionDataSet> lcds = this.dataSet.cvSet(folds, rand);
        this.evaluateCrossValidation(lcds);
    }

    /**
     * Performs cross validation over pre-split folds. Fold {@code i} is the test set, and all the
     * other folds combined form its training set.
     *
     * @param lcds the folds
     */
    public void evaluateCrossValidation(List<RegressionDataSet> lcds) {
        List<RegressionDataSet> trainCombinations = new ArrayList<>(lcds.size());
        for (int i = 0; i < lcds.size(); ++i) {
            trainCombinations.add(RegressionDataSet.comineAllBut(lcds, i));
        }
        this.evaluateCrossValidation(lcds, trainCombinations);
    }

    /**
     * Performs cross validation with explicitly provided train/test splits.
     * {@code trainCombinations.get(i)} is used to train the model that is tested on
     * {@code lcds.get(i)}.
     *
     * @param lcds the test set of each fold
     * @param trainCombinations the training set of each fold
     * @throws IllegalArgumentException if fewer training sets than test sets are given, or if too
     *     few warm models were supplied for a {@link WarmRegressor}
     */
    public void evaluateCrossValidation(List<RegressionDataSet> lcds, List<RegressionDataSet> trainCombinations) {
        if (trainCombinations.size() < lcds.size()) {
            throw new IllegalArgumentException("Expected a training set for each of the " + lcds.size()
                    + " test folds, but only " + trainCombinations.size() + " were given");
        }
        this.resetForEvaluation(lcds.size());
        for (int i = 0; i < lcds.size(); ++i) {
            this.evaluationWork(trainCombinations.get(i), lcds.get(i), i);
        }
    }

    /**
     * Trains on the whole data set given at construction and evaluates on {@code testSet}.
     *
     * @param testSet the data to evaluate the trained model on
     */
    public void evaluateTestSet(RegressionDataSet testSet) {
        this.resetForEvaluation(1);
        this.evaluationWork(this.dataSet, testSet, 0);
    }

    /**
     * Clears the error and timing statistics of a previous run, and allocates storage for
     * {@code numModels} kept models when {@link #isKeepModels()} is set. Per-scorer statistics
     * are intentionally not reset, because callers may hold references obtained from
     * {@link #getScoreStats}.
     */
    private void resetForEvaluation(int numModels) {
        this.sqrdErrorStats = new OnLineStatistics();
        this.totalTrainingTimeNanos = 0L;
        this.totalClassificationTimeNanos = 0L;
        this.keptModels = this.keepModels ? new Regressor[numModels] : null;
    }

    /**
     * Trains one model on {@code trainSet}, evaluates it on {@code testSet}, and accumulates the
     * errors, timings and scores into this evaluator's statistics.
     *
     * @param trainSet the training data for this fold
     * @param testSet the test data for this fold
     * @param index the fold index, used for warm models and kept models
     */
    private void evaluationWork(RegressionDataSet trainSet, RegressionDataSet testSet, int index) {
        // Transforms are learned and applied on a shallow clone rather than on trainSet itself.
        RegressionDataSet transformedTrain = trainSet.shallowClone();
        DataTransformProcess foldTransform = this.dtp.clone();
        foldTransform.learnApplyTransforms(transformedTrain);

        long trainStart = System.nanoTime();
        Regressor model = this.regressor.clone();
        if (this.warmModels != null && model instanceof WarmRegressor) {
            if (index >= this.warmModels.length) {
                throw new IllegalArgumentException("No warm model was supplied for fold " + index
                        + "; only " + this.warmModels.length + " warm models were set");
            }
            ((WarmRegressor) model).train(transformedTrain, this.warmModels[index], this.parallel);
        } else {
            model.train(transformedTrain, this.parallel);
        }
        this.totalTrainingTimeNanos += System.nanoTime() - trainStart;
        if (this.keptModels != null) {
            this.keptModels[index] = model;
        }

        // The registered scorers, in map order, are never mutated; they serve as templates.
        // accumulators.get(s) gathers the results for templates.get(s).
        List<RegressionScore> templates = new ArrayList<>(this.scoreMap.keySet());
        List<RegressionScore> accumulators = new ArrayList<>(templates.size());
        for (RegressionScore template : templates) {
            RegressionScore accumulator = template.clone();
            accumulator.prepare();
            accumulators.add(accumulator);
        }

        Object lock = new Object();
        ParallelUtils.run(this.parallel, testSet.getSampleSize(), (start, end) -> {
            // Local scorers are cloned from the untouched templates, never from the shared
            // accumulators, which other chunks may be updating concurrently under the lock.
            List<RegressionScore> localScores = new ArrayList<>(templates.size());
            for (RegressionScore template : templates) {
                RegressionScore local = template.clone();
                local.prepare();
                localScores.add(local);
            }
            OnLineStatistics localSqrdErrors = new OnLineStatistics();
            long localPredictionNanos = 0L;
            for (int i = start; i < end; ++i) {
                DataPoint dp = testSet.getDataPoint(i);
                double trueVal = testSet.getTargetValue(i);
                DataPoint transformed = foldTransform.transform(dp);
                long predictStart = System.nanoTime();
                double predVal = model.regress(transformed);
                localPredictionNanos += System.nanoTime() - predictStart;
                double error = trueVal - predVal;
                for (RegressionScore score : localScores) {
                    score.addResult(predVal, trueVal, dp.getWeight());
                }
                localSqrdErrors.add(error * error, dp.getWeight());
            }
            synchronized (lock) {
                this.sqrdErrorStats.add(localSqrdErrors);
                this.totalClassificationTimeNanos += localPredictionNanos;
                for (int s = 0; s < localScores.size(); ++s) {
                    accumulators.get(s).addResults(localScores.get(s));
                }
            }
        });

        // scoreMap.values() iterates in the same (insertion) order as the templates list.
        int s = 0;
        for (OnLineStatistics stats : this.scoreMap.values()) {
            stats.add(accumulators.get(s++).getScore());
        }
    }

    /**
     * Registers a scorer. After each subsequent fold, the scorer's value for that fold is added to
     * its statistics.
     *
     * @param scorer the score to compute
     */
    public void addScorer(RegressionScore scorer) {
        this.scoreMap.put(scorer, new OnLineStatistics());
    }

    /**
     * Returns the per-fold statistics of a registered scorer. The statistics are not reset between
     * evaluation calls, so they accumulate over every evaluation run since the scorer was added.
     *
     * @param score the scorer
     * @return its statistics, or {@code null} if the scorer was never added
     */
    public OnLineStatistics getScoreStats(RegressionScore score) {
        return this.scoreMap.get(score);
    }

    /** Prints each registered scorer's name, mean and (standard deviation) to standard output. */
    public void prettyPrintRegressionScores() {
        int nameLength = 10;
        for (RegressionScore score : this.scoreMap.keySet()) {
            nameLength = Math.max(nameLength, score.getName().length() + 2);
        }
        String format = "%-" + nameLength + "s %-5f (%-5f)%n";
        for (Map.Entry<RegressionScore, OnLineStatistics> entry : this.scoreMap.entrySet()) {
            System.out.printf(format, entry.getKey().getName(), entry.getValue().getMean(), entry.getValue().getStandardDeviation());
        }
    }

    /** Returns the squared-error statistics, failing clearly if no evaluation has run yet. */
    private OnLineStatistics errorStats() {
        if (this.sqrdErrorStats == null) {
            throw new IllegalStateException("No evaluation has been performed yet");
        }
        return this.sqrdErrorStats;
    }

    /**
     * @return the smallest squared error of any test prediction in the last evaluation
     * @throws IllegalStateException if no evaluation has been performed
     */
    public double getMinError() {
        return this.errorStats().getMin();
    }

    /**
     * @return the largest squared error of any test prediction in the last evaluation
     * @throws IllegalStateException if no evaluation has been performed
     */
    public double getMaxError() {
        return this.errorStats().getMax();
    }

    /**
     * @return the weighted mean squared error of the last evaluation
     * @throws IllegalStateException if no evaluation has been performed
     */
    public double getMeanError() {
        return this.errorStats().getMean();
    }

    /**
     * @return the standard deviation of the squared errors of the last evaluation
     * @throws IllegalStateException if no evaluation has been performed
     */
    public double getErrorStndDev() {
        return this.errorStats().getStandardDeviation();
    }

    /**
     * @return the total time spent training in the last evaluation, in milliseconds
     */
    public long getTotalTrainingTime() {
        return TimeUnit.NANOSECONDS.toMillis(this.totalTrainingTimeNanos);
    }

    /**
     * @return the total time spent predicting in the last evaluation, summed over all threads,
     *     in milliseconds
     */
    public long getTotalClassificationTime() {
        return TimeUnit.NANOSECONDS.toMillis(this.totalClassificationTimeNanos);
    }

    /**
     * @return the template regressor given at construction; this instance is never trained by
     *     the evaluator (each evaluation trains a clone)
     */
    public Regressor getRegressor() {
        return this.regressor;
    }
}

