/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers.ensembles.voting;

import java.util.Arrays;
import timeseriesweka.classifiers.ensembles.EnsembleModule;
import timeseriesweka.classifiers.ensembles.voting.ModuleVotingScheme;
import weka.core.Instance;

/**
 * Ensemble voting scheme in which each module casts one vote for the class it predicts,
 * with that vote weighted by the module's weights and its own confidence.
 *
 * <p>For a module predicting class {@code c} with probability distribution {@code dist},
 * the amount added to class {@code c} is
 * {@code priorWeight * posteriorWeights[c] * dist[c]}. Other classes receive nothing
 * from that module. The accumulated votes are passed through the inherited
 * {@code normalise} helper before being returned.
 *
 * <p>The vote arrays are sized by {@code numClasses}, so it must be set before voting.
 * Set it either through {@link #MajorityVoteByConfidence(int)} or through
 * {@code trainVotingScheme}.
 */
public class MajorityVoteByConfidence
extends ModuleVotingScheme {

    /** Creates a scheme whose class count is supplied later via {@code trainVotingScheme}. */
    public MajorityVoteByConfidence() {
    }

    /**
     * Creates a scheme for a problem with a known number of classes.
     *
     * @param numClasses number of class values; sizes the per-class vote arrays
     */
    public MajorityVoteByConfidence(int numClasses) {
        this.numClasses = numClasses;
    }

    /**
     * Records the number of classes. Nothing else is learned here: the prior and posterior
     * weights are read straight from each module whenever a vote is cast.
     *
     * @param modules    ensemble members (unused by this scheme)
     * @param numClasses number of class values
     */
    @Override
    public void trainVotingScheme(EnsembleModule[] modules, int numClasses) {
        this.numClasses = numClasses;
    }

    /**
     * Combines the modules' stored {@code trainResults} for one training instance into a
     * single weighted-vote distribution. Per-module details and the raw (unweighted) vote
     * counts are also emitted through {@code printlnDebug}.
     *
     * @param modules            ensemble members
     * @param trainInstanceIndex index of the instance within each module's train results
     * @return the normalised weighted-vote distribution over classes
     */
    @Override
    public double[] distributionForTrainInstance(EnsembleModule[] modules, int trainInstanceIndex) {
        double[] preds = new double[this.numClasses];
        // One plain (weight 1) vote per module; only used for the debug output below.
        double[] unweightedPreds = new double[this.numClasses];
        for (EnsembleModule module : modules) {
            int pred = (int) module.trainResults.getPredClassValue(trainInstanceIndex);
            double[] dist = module.trainResults.getDistributionForInstance(trainInstanceIndex);
            double weight = voteWeight(module, dist, pred);
            preds[pred] += weight;
            unweightedPreds[pred] += 1.0;

            this.printlnDebug(module.getModuleName() + " distForInst:  " + Arrays.toString(dist));
            this.printlnDebug(module.getModuleName() + " priorweights: " + module.priorWeight);
            this.printlnDebug(module.getModuleName() + " postweights:  " + Arrays.toString(module.posteriorWeights));
            this.printlnDebug(module.getModuleName() + " voteweight:   " + weight);
        }
        this.printlnDebug("Ensemble Votes: " + Arrays.toString(unweightedPreds));
        this.printlnDebug("Ensemble Dist:  " + Arrays.toString(preds));
        // Normalise once, after printing the raw votes, and reuse the result for the return.
        double[] normalised = this.normalise(preds);
        this.printlnDebug("Normed:         " + Arrays.toString(normalised));
        this.printlnDebug("");
        return normalised;
    }

    /**
     * Combines the modules' stored {@code testResults} for one test instance into a single
     * weighted-vote distribution.
     *
     * @param modules           ensemble members
     * @param testInstanceIndex index of the instance within each module's test results
     * @return the normalised weighted-vote distribution over classes
     */
    @Override
    public double[] distributionForTestInstance(EnsembleModule[] modules, int testInstanceIndex) {
        double[] preds = new double[this.numClasses];
        for (EnsembleModule module : modules) {
            int pred = (int) module.testResults.getPredClassValue(testInstanceIndex);
            double[] dist = module.testResults.getDistributionForInstance(testInstanceIndex);
            preds[pred] += voteWeight(module, dist, pred);
        }
        return this.normalise(preds);
    }

    /**
     * Classifies a new instance. Each module's classifier is queried live, its distribution
     * is handed to {@code storeModuleTestResult}, and the argmax class receives the vote.
     *
     * @param modules      ensemble members
     * @param testInstance instance to classify
     * @return the normalised weighted-vote distribution over classes
     * @throws Exception if a module's classifier throws while producing its distribution
     */
    @Override
    public double[] distributionForInstance(EnsembleModule[] modules, Instance testInstance) throws Exception {
        double[] preds = new double[this.numClasses];
        for (EnsembleModule module : modules) {
            double[] dist = module.getClassifier().distributionForInstance(testInstance);
            this.storeModuleTestResult(module, dist);
            int pred = (int) this.indexOfMax(dist);
            preds[pred] += voteWeight(module, dist, pred);
        }
        return this.normalise(preds);
    }

    /**
     * Returns the vote a module casts for its predicted class. The vote is the module's
     * prior weight times its posterior weight for that class times the module's own
     * probability for that class.
     */
    private static double voteWeight(EnsembleModule module, double[] dist, int pred) {
        return module.priorWeight * module.posteriorWeights[pred] * dist[pred];
    }
}

