/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers.ensembles.voting;

import timeseriesweka.classifiers.ensembles.EnsembleModule;
import timeseriesweka.classifiers.ensembles.voting.ModuleVotingScheme;
import weka.core.Instance;

/**
 * Voting scheme that combines ensemble members by averaging their weighted class confidences.
 *
 * <p>For every class {@code c} the combined score is
 * {@code sum over m of (priorWeight_m * posteriorWeights_m[c] * p_m[c]) / modules.length},
 * where {@code p_m} is module {@code m}'s predicted class distribution. The resulting score
 * vector is then passed through the inherited {@code normalise} method before being returned.
 */
public class AverageOfConfidences
extends ModuleVotingScheme {
    /** Creates a scheme whose class count is expected to be set later via {@link #trainVotingScheme}. */
    public AverageOfConfidences() {
    }

    /**
     * Creates a scheme for a known number of classes.
     *
     * @param numClasses number of classes the combined distribution will contain
     */
    public AverageOfConfidences(int numClasses) {
        this.numClasses = numClasses;
    }

    /**
     * Records the number of classes. The modules are not inspected: this scheme has nothing
     * else to fit.
     *
     * @param modules the ensemble members (unused)
     * @param numClasses number of classes the combined distribution will contain
     */
    @Override
    public void trainVotingScheme(EnsembleModule[] modules, int numClasses) {
        this.numClasses = numClasses;
    }

    /**
     * Combines the modules' stored train-set predictions for one training instance.
     *
     * @param modules the ensemble members, each carrying its {@code trainResults}
     * @param trainInstanceIndex index of the instance within each module's train results
     * @return the weighted average of the modules' confidences, after {@code normalise}
     */
    @Override
    public double[] distributionForTrainInstance(EnsembleModule[] modules, int trainInstanceIndex) {
        // Fetch each module's distribution once, rather than once per class.
        double[][] dists = new double[modules.length][];
        for (int m = 0; m < modules.length; ++m) {
            dists[m] = modules[m].trainResults.getDistributionForInstance(trainInstanceIndex);
        }
        return this.averageWeightedConfidences(modules, dists);
    }

    /**
     * Combines the modules' stored test-set predictions for one test instance.
     *
     * @param modules the ensemble members, each carrying its {@code testResults}
     * @param testInstanceIndex index of the instance within each module's test results
     * @return the weighted average of the modules' confidences, after {@code normalise}
     */
    @Override
    public double[] distributionForTestInstance(EnsembleModule[] modules, int testInstanceIndex) {
        // Fetch each module's distribution once, rather than once per class.
        double[][] dists = new double[modules.length][];
        for (int m = 0; m < modules.length; ++m) {
            dists[m] = modules[m].testResults.getDistributionForInstance(testInstanceIndex);
        }
        return this.averageWeightedConfidences(modules, dists);
    }

    /**
     * Classifies a new instance with every module and combines their distributions.
     *
     * <p>Each module's raw distribution is also handed to the inherited
     * {@code storeModuleTestResult} before combining.
     *
     * @param modules the ensemble members whose classifiers are queried
     * @param testInstance the instance to classify
     * @return the weighted average of the modules' confidences, after {@code normalise}
     * @throws Exception if any module's classifier fails to produce a distribution
     */
    @Override
    public double[] distributionForInstance(EnsembleModule[] modules, Instance testInstance) throws Exception {
        double[][] dists = new double[modules.length][];
        for (int m = 0; m < modules.length; ++m) {
            dists[m] = modules[m].getClassifier().distributionForInstance(testInstance);
            this.storeModuleTestResult(modules[m], dists[m]);
        }
        return this.averageWeightedConfidences(modules, dists);
    }

    /**
     * Computes the per-class weighted mean of the given module distributions and normalises it.
     *
     * <p>NOTE(review): if {@code modules} is empty, every score is {@code 0.0 / 0.0 = NaN}
     * before normalisation. Confirm callers never pass an empty ensemble.
     *
     * @param modules the ensemble members supplying {@code priorWeight} and {@code posteriorWeights}
     * @param dists {@code dists[m]} is module {@code m}'s class distribution; presumably each has
     *     at least {@code numClasses} entries (TODO confirm)
     * @return the combined distribution, after {@code normalise}
     */
    private double[] averageWeightedConfidences(EnsembleModule[] modules, double[][] dists) {
        double[] preds = new double[this.numClasses];
        for (int c = 0; c < this.numClasses; ++c) {
            double sum = 0.0;
            for (int m = 0; m < modules.length; ++m) {
                sum += modules[m].priorWeight * modules[m].posteriorWeights[c] * dists[m][c];
            }
            preds[c] = sum / (double) modules.length;
        }
        return this.normalise(preds);
    }
}

