/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers.ensembles.voting.stacking;

import java.util.ArrayList;
import timeseriesweka.classifiers.ensembles.EnsembleModule;
import timeseriesweka.classifiers.ensembles.voting.ModuleVotingScheme;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Base class for stacking-style ensemble voting schemes. Rather than combining module outputs with a
 * fixed rule, the weighted class distributions produced by each ensemble module are turned into a
 * meta-level {@link Instance}, and a second-level Weka {@link Classifier} is trained on those
 * instances and used to produce the final distribution.
 *
 * <p>Subclasses decide the shape of the meta-level data via
 * {@link #setNumOutputAttributes(EnsembleModule[])} and {@link #buildInst(double[][], Double)}.
 */
public abstract class AbstractStacking extends ModuleVotingScheme {
    /** The meta-level classifier trained on the (weighted) module outputs. */
    protected Classifier classifier;
    /**
     * Total number of attributes in a meta-level instance, including the class attribute, which
     * {@link #initInstances()} always places last.
     */
    protected int numOutputAtts;
    /** Empty dataset describing the meta-level attribute structure; created by {@link #initInstances()}. */
    protected Instances instsHeader;

    /**
     * Creates a stacking scheme that will train the given meta-level classifier.
     *
     * @param classifier the meta-level classifier
     */
    public AbstractStacking(Classifier classifier) {
        this.classifier = classifier;
    }

    /**
     * Creates a stacking scheme with an initial class count.
     *
     * @param classifier the meta-level classifier
     * @param numClasses number of classes; note that {@link #trainVotingScheme} overwrites it
     */
    public AbstractStacking(Classifier classifier, int numClasses) {
        this.classifier = classifier;
        this.numClasses = numClasses;
    }

    /**
     * Returns the meta-level classifier.
     *
     * @return the meta-level classifier
     */
    public Classifier getClassifier() {
        return this.classifier;
    }

    /**
     * Builds the meta-level training set from each module's train-set distributions and trains the
     * meta-level classifier on it.
     *
     * @param modules the ensemble members; their {@code trainResults} must be populated, and the
     *     number of training instances is taken from {@code modules[0]}
     * @param numClasses number of classes in the problem
     * @throws Exception if building a meta-level instance or training the classifier fails
     */
    @Override
    public void trainVotingScheme(EnsembleModule[] modules, int numClasses) throws Exception {
        this.numClasses = numClasses;
        this.setNumOutputAttributes(modules);
        int numInsts = modules[0].trainResults.numInstances();
        this.initInstances();
        Instances insts = new Instances(this.instsHeader, numInsts);
        for (int i = 0; i < numInsts; ++i) {
            insts.add(this.buildInst(modules, true, i));
        }
        this.classifier.buildClassifier(insts);
    }

    /**
     * Computes the meta-level attribute count for the given modules. Implementations are expected to
     * set {@link #numOutputAtts} (including the class attribute), because {@link #initInstances()}
     * reads it immediately afterwards.
     *
     * @param modules the ensemble members
     * @throws Exception if the attribute count cannot be determined
     */
    protected abstract void setNumOutputAttributes(EnsembleModule[] modules) throws Exception;

    /**
     * Converts per-module weighted class distributions into one meta-level instance.
     *
     * @param dists {@code dists[m]} is the weighted distribution of module {@code m}
     * @param classVal the true class value for training instances, or {@code null} at prediction time
     * @return the meta-level instance
     * @throws Exception if the instance cannot be built
     */
    protected abstract Instance buildInst(double[][] dists, Double classVal) throws Exception;

    /**
     * Builds the meta-level instance for one train or test index from the modules' stored results.
     *
     * @param modules the ensemble members
     * @param train {@code true} to read {@code trainResults} (and attach the true class value),
     *     {@code false} to read {@code testResults} (class value left {@code null})
     * @param instIndex index of the instance within the chosen results
     * @return the meta-level instance
     * @throws Exception if reading results or building the instance fails
     */
    protected Instance buildInst(EnsembleModule[] modules, boolean train, int instIndex) throws Exception {
        double[][] dists = new double[modules.length][];
        for (int m = 0; m < modules.length; ++m) {
            double[] raw = train
                    ? modules[m].trainResults.getDistributionForInstance(instIndex)
                    : modules[m].testResults.getDistributionForInstance(instIndex);
            // Weight a copy: the results object may hand back its stored array, and mutating it
            // would re-weight the cached distributions on every call.
            dists[m] = this.weightedCopy(modules[m], raw);
        }
        Double classVal = train ? Double.valueOf(modules[0].trainResults.getTrueClassValue(instIndex)) : null;
        return this.buildInst(dists, classVal);
    }

    /**
     * Returns a copy of {@code dist} where each of the first {@code numClasses} entries is scaled by
     * the module's prior weight times its per-class posterior weight. The input array is not modified.
     */
    private double[] weightedCopy(EnsembleModule module, double[] dist) {
        double[] weighted = dist.clone();
        for (int c = 0; c < this.numClasses; ++c) {
            weighted[c] *= module.priorWeight * module.posteriorWeights[c];
        }
        return weighted;
    }

    /**
     * Creates {@link #instsHeader}: {@code numOutputAtts - 1} numeric attributes named "0", "1", ...
     * followed by a nominal class attribute with values "0".."numClasses-1", set as the class index.
     */
    protected void initInstances() {
        ArrayList<Attribute> atts = new ArrayList<Attribute>(this.numOutputAtts);
        for (int i = 0; i < this.numOutputAtts - 1; ++i) {
            atts.add(new Attribute("" + i));
        }
        ArrayList<String> classVals = new ArrayList<String>(this.numClasses);
        for (int i = 0; i < this.numClasses; ++i) {
            classVals.add("" + i);
        }
        atts.add(new Attribute("class", classVals));
        this.instsHeader = new Instances("", atts, 1);
        this.instsHeader.setClassIndex(this.numOutputAtts - 1);
    }

    /**
     * Predicts a distribution for a training instance using the modules' stored train results.
     *
     * @param modules the ensemble members
     * @param trainInstanceIndex index into the modules' train results
     * @return the meta-level classifier's distribution
     * @throws Exception if building the instance or classifying fails
     */
    @Override
    public double[] distributionForTrainInstance(EnsembleModule[] modules, int trainInstanceIndex) throws Exception {
        Instance inst = this.buildInst(modules, true, trainInstanceIndex);
        return this.classifier.distributionForInstance(inst);
    }

    /**
     * Predicts a distribution for a test instance using the modules' stored test results.
     *
     * @param modules the ensemble members
     * @param testInstanceIndex index into the modules' test results
     * @return the meta-level classifier's distribution
     * @throws Exception if building the instance or classifying fails
     */
    @Override
    public double[] distributionForTestInstance(EnsembleModule[] modules, int testInstanceIndex) throws Exception {
        Instance inst = this.buildInst(modules, false, testInstanceIndex);
        return this.classifier.distributionForInstance(inst);
    }

    /**
     * Classifies a new instance: each module's classifier predicts it, the raw prediction is recorded
     * via {@code storeModuleTestResult}, and the weighted predictions are fed to the meta-level
     * classifier.
     *
     * @param modules the ensemble members
     * @param testInstance the instance to classify
     * @return the meta-level classifier's distribution
     * @throws Exception if any classifier fails
     */
    @Override
    public double[] distributionForInstance(EnsembleModule[] modules, Instance testInstance) throws Exception {
        double[][] dists = new double[modules.length][];
        for (int m = 0; m < modules.length; ++m) {
            double[] raw = modules[m].getClassifier().distributionForInstance(testInstance);
            this.storeModuleTestResult(modules[m], raw);
            // Weight a copy so the stored test result keeps the module's unweighted output.
            dists[m] = this.weightedCopy(modules[m], raw);
        }
        Instance inst = this.buildInst(dists, null);
        return this.classifier.distributionForInstance(inst);
    }

    @Override
    public String toString() {
        return super.toString() + "(" + this.classifier.getClass().getSimpleName() + ")";
    }
}

