/*
 * Decompiled with CFR 0.152.
 */
package edu.mit.compbio.flynet;

import edu.mit.compbio.flynet.Edge;
import edu.mit.compbio.flynet.FileExport;
import edu.mit.compbio.flynet.Flynet;
import edu.mit.compbio.flynet.Network;
import java.util.ArrayList;
import java.util.Random;
import weka.classifiers.functions.Logistic;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class SupervisedLearning {
    private ArrayList<Edge> edges_ = null;
    private int numEdges_ = -1;
    private Network redfly_ = null;
    private Instances instances_ = null;
    private int numPositives_ = -1;
    private double posWeight_;
    private double negWeight_;
    private int randomSeed_ = 2;
    private Random random_ = new Random(this.randomSeed_);
    private int folds_ = 10;
    private int numRepeatsCV_ = 1000;
    ArrayList<double[]> coeffsCV_ = null;

    public double[] run(ArrayList<Edge> edges) {
        this.redfly_ = new Network();
        this.redfly_.load("../networks/redfly.txt", false);
        this.redfly_.setName("redfly.txt");
        this.edges_ = edges;
        this.numEdges_ = this.edges_.size();
        this.coeffsCV_ = new ArrayList();
        Flynet.println("Running " + this.numRepeatsCV_ + " iteration(s) of " + this.folds_ + "-fold cross validation ...");
        this.createInstances();
        int i = 0;
        while (i < this.numRepeatsCV_) {
            this.crossValidation(i);
            ++i;
        }
        this.writeCoeffsCV();
        int numCoeffs = this.coeffsCV_.get(0).length;
        double[] avgCoeffs = new double[numCoeffs];
        int c = 0;
        while (c < numCoeffs) {
            avgCoeffs[c] = 0.0;
            ++c;
        }
        int i2 = 0;
        while (i2 < this.coeffsCV_.size()) {
            int c2 = 0;
            while (c2 < numCoeffs) {
                int n = c2;
                avgCoeffs[n] = avgCoeffs[n] + this.coeffsCV_.get(i2)[c2];
                ++c2;
            }
            ++i2;
        }
        c = 0;
        while (c < numCoeffs) {
            int n = c++;
            avgCoeffs[n] = avgCoeffs[n] / (double)this.coeffsCV_.size();
        }
        Flynet.println("Avgerage coefficients from cross validation:");
        Flynet.println("Intercept  \t" + avgCoeffs[0]);
        Flynet.println("Motif      \t" + avgCoeffs[1]);
        Flynet.println("Chip       \t" + avgCoeffs[2]);
        Flynet.println("ChromatinTC\t" + avgCoeffs[3]);
        Flynet.println("ChromatinCL\t" + avgCoeffs[4]);
        Flynet.println("Microarray \t" + avgCoeffs[5]);
        Flynet.println("RnaSeq     \t" + avgCoeffs[6]);
        Flynet.println("");
        return avgCoeffs;
    }

    private void createInstances() {
        Attribute motif = new Attribute("motif");
        Attribute chip = new Attribute("chip");
        Attribute chromatinTC = new Attribute("chromatinTC");
        Attribute chromatinCL = new Attribute("chromatinCL");
        Attribute microarray = new Attribute("microarray");
        Attribute rnaSeq = new Attribute("rnaSeq");
        FastVector classVal = new FastVector(2);
        classVal.addElement("1");
        classVal.addElement("0");
        Attribute classLabel = new Attribute("RedFly", classVal);
        FastVector attributes = new FastVector(7);
        attributes.addElement(motif);
        attributes.addElement(chip);
        attributes.addElement(chromatinTC);
        attributes.addElement(chromatinCL);
        attributes.addElement(microarray);
        attributes.addElement(rnaSeq);
        attributes.addElement(classLabel);
        this.instances_ = new Instances("RedFly", attributes, this.numEdges_);
        this.instances_.setClassIndex(attributes.size() - 1);
        this.numPositives_ = 0;
        int i = 0;
        while (i < this.numEdges_) {
            double[] features = this.edges_.get(i).getFeatures();
            assert (features.length == attributes.size() - 1);
            Instance instance = new Instance(attributes.size());
            int f = 0;
            while (f < features.length) {
                instance.setValue((Attribute)attributes.elementAt(f), features[f]);
                ++f;
            }
            if (this.redfly_.contains(this.edges_.get(i))) {
                instance.setValue((Attribute)attributes.elementAt(attributes.size() - 1), "1");
                ++this.numPositives_;
            } else {
                instance.setValue((Attribute)attributes.elementAt(attributes.size() - 1), "0");
            }
            this.instances_.add(instance);
            ++i;
        }
        assert (this.instances_.numInstances() == this.numEdges_);
        this.posWeight_ = 1.0 - (double)this.numPositives_ / (double)this.numEdges_;
        this.negWeight_ = (double)this.numPositives_ / (double)this.numEdges_;
        i = 0;
        while (i < this.instances_.numInstances()) {
            Instance inst = this.instances_.instance(i);
            if (this.isPositive(inst)) {
                inst.setWeight(this.posWeight_);
            } else {
                inst.setWeight(this.negWeight_);
            }
            ++i;
        }
    }

    private void crossValidation(int iteration) {
        Instances learningSet = new Instances(this.instances_);
        learningSet.randomize(this.random_);
        learningSet.stratify(this.folds_);
        int n = 0;
        while (n < this.folds_) {
            Instances train = learningSet.trainCV(this.folds_, n);
            Instances test = learningSet.testCV(this.folds_, n);
            assert (this.isBalanced(test));
            assert (this.isBalanced(train));
            Logistic classifier = new Logistic();
            classifier.setRidge(0.0);
            try {
                classifier.buildClassifier(train);
            }
            catch (Exception e) {
                Flynet.error(e);
            }
            double[] coeffs = this.coeffsAsVector(classifier.coefficients());
            this.coeffsCV_.add(coeffs);
            ++n;
        }
    }

    private double[] coeffsAsVector(double[][] coeffs) {
        double[] vect = new double[coeffs.length];
        int i = 0;
        while (i < coeffs.length) {
            assert (coeffs[i].length == 1);
            vect[i] = coeffs[i][0];
            ++i;
        }
        return vect;
    }

    private boolean isPositive(Instance instance) {
        String label = instance.stringValue(instance.classIndex());
        return label.equals("1");
    }

    private boolean isBalanced(Instances set) {
        double expectedNumPositives = (double)(set.numInstances() * this.numPositives_) / (double)this.numEdges_;
        int numP = 0;
        double weightP = 0.0;
        double weightN = 0.0;
        int i = 0;
        while (i < set.numInstances()) {
            if (this.isPositive(set.instance(i))) {
                ++numP;
                weightP += set.instance(i).weight();
            } else {
                weightN += set.instance(i).weight();
            }
            ++i;
        }
        double delta = Math.abs(weightP - weightN);
        if (Math.abs(expectedNumPositives - (double)numP) >= 1.0) {
            return false;
        }
        return !(delta >= this.posWeight_);
    }

    private void writeCoeffsCV() {
        FileExport writer = new FileExport("coeffsCV.txt");
        int i = 0;
        while (i < this.coeffsCV_.size()) {
            double[] coeffs = this.coeffsCV_.get(i);
            String line = Double.toString(coeffs[0]);
            int c = 1;
            while (c < coeffs.length) {
                line = String.valueOf(line) + "\t" + Double.toString(coeffs[c]);
                ++c;
            }
            writer.println(line);
            ++i;
        }
        writer.close();
    }

    public void tests() {
        Attribute one = new Attribute("one");
        FastVector fvClassVal = new FastVector(2);
        fvClassVal.addElement("0");
        fvClassVal.addElement("1");
        Attribute classLabel = new Attribute("Class", fvClassVal);
        FastVector attributes = new FastVector();
        attributes.addElement(one);
        attributes.addElement(classLabel);
        Instances learningSet = new Instances("Learning", attributes, 10);
        learningSet.setClassIndex(attributes.size() - 1);
        this.folds_ = 10;
        this.numPositives_ = 10;
        int numInstances = 100;
        double posWeight = 1.0 - (double)this.numPositives_ / (double)numInstances;
        double negWeight = (double)this.numPositives_ / (double)numInstances;
        int i = 0;
        while (i < 90) {
            Instance negInstance = new Instance(attributes.size());
            negInstance.setValue((Attribute)attributes.elementAt(0), (double)(-(i + 1)));
            negInstance.setValue((Attribute)attributes.elementAt(attributes.size() - 1), "0");
            negInstance.setWeight(negWeight);
            learningSet.add(negInstance);
            ++i;
        }
        i = 0;
        while (i < 10) {
            Instance posInstance = new Instance(attributes.size());
            posInstance.setValue((Attribute)attributes.elementAt(0), (double)(i + 1));
            posInstance.setValue((Attribute)attributes.elementAt(attributes.size() - 1), "1");
            posInstance.setWeight(posWeight);
            learningSet.add(posInstance);
            ++i;
        }
        learningSet.randomize(new Random(this.randomSeed_));
        learningSet.stratify(this.folds_);
        Logistic classifier = new Logistic();
        classifier.setRidge(0.0);
        int n = 0;
        while (n < this.folds_) {
            Instances train = learningSet.trainCV(this.folds_, n);
            Instances test = learningSet.testCV(this.folds_, n);
            assert (this.isBalanced(test));
            assert (this.isBalanced(train));
            ++n;
        }
        try {
            classifier.buildClassifier(learningSet);
            double result = classifier.classifyInstance(learningSet.instance(0));
            double[][] coeffs = classifier.coefficients();
            System.out.println(classifier.toString());
            System.out.println();
        }
        catch (Exception e) {
            Flynet.error(e);
        }
    }
}

