/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers.ensembles.voting;

import timeseriesweka.classifiers.ensembles.EnsembleModule;
import timeseriesweka.classifiers.ensembles.voting.ModuleVotingScheme;
import weka.core.Instance;

/**
 * Voting scheme that combines the hard predictions of ensemble modules with a
 * naive Bayes rule.
 *
 * <p>Training estimates, from each module's train confusion matrix, the
 * probability that the module predicts class {@code pc} when the actual class
 * is {@code ac}. At prediction time the support for each candidate class is
 * the class prior multiplied by, for every module, the estimated probability
 * of that module's observed prediction given the candidate class (scaled by
 * the module's prior and posterior weights). The result is passed through the
 * inherited {@code normalise}.
 *
 * <p>The confusion matrices are indexed as {@code [actualClass][predictedClass]}
 * by this class; that this matches how they are populated elsewhere should be
 * confirmed against {@code trainResults}.
 */
public class NaiveBayesCombiner
extends ModuleVotingScheme {
    /** Estimated P(module predicts pc | actual class ac), indexed {@code [ac][module][pc]}. */
    protected double[][][] postProbs;
    /**
     * Per-class prior estimate, indexed by actual class. With Laplace
     * correction enabled the values are not rescaled to sum to one; they are
     * only used up to the constant factor removed by normalisation.
     */
    protected double[] priorClassProbs;
    /** When true, one is added to every confusion-matrix cell before estimating probabilities. */
    protected boolean laplaceCorrection;

    /** Creates a combiner with Laplace correction enabled. */
    public NaiveBayesCombiner() {
        this(true);
    }

    /**
     * Creates a combiner with Laplace correction enabled and a preset class count.
     *
     * @param numClasses number of classes in the problem
     */
    public NaiveBayesCombiner(int numClasses) {
        this(true, numClasses);
    }

    /**
     * Creates a combiner, choosing whether Laplace correction is applied.
     *
     * @param laplaceCorrection whether to add one to every confusion-matrix cell
     */
    public NaiveBayesCombiner(boolean laplaceCorrection) {
        this.needTrainPreds = true;
        this.laplaceCorrection = laplaceCorrection;
    }

    /**
     * Creates a combiner with an explicit Laplace setting and a preset class count.
     *
     * @param laplaceCorrection whether to add one to every confusion-matrix cell
     * @param numClasses number of classes in the problem
     */
    public NaiveBayesCombiner(boolean laplaceCorrection, int numClasses) {
        this(laplaceCorrection);
        this.numClasses = numClasses;
    }

    /**
     * Estimates the class priors and the per-module conditional prediction
     * probabilities from the modules' train confusion matrices.
     *
     * <p>Class counts are taken from the first module only, on the assumption
     * that every module was evaluated on the same train instances.
     *
     * @param modules ensemble modules whose {@code trainResults} are read; must be non-empty
     * @param numClasses number of classes in the problem
     * @throws Exception declared by the overridden method; not thrown directly here
     */
    @Override
    public void trainVotingScheme(EnsembleModule[] modules, int numClasses) throws Exception {
        this.numClasses = numClasses;
        this.postProbs = new double[numClasses][modules.length][numClasses];
        this.priorClassProbs = new double[numClasses];
        // Java cannot cast a boolean to a number, so translate the flag into
        // the additive (add-one) count explicitly.
        double correction = this.laplaceCorrection ? 1.0 : 0.0;
        for (int ac = 0; ac < numClasses; ++ac) {
            // (Smoothed) number of train instances whose actual class is ac.
            double numInClass = 0.0;
            for (int pc = 0; pc < numClasses; ++pc) {
                numInClass += modules[0].trainResults.confusionMatrix[ac][pc] + correction;
            }
            this.priorClassProbs[ac] = numInClass / (double)modules[0].trainResults.numInstances();
            for (int m = 0; m < modules.length; ++m) {
                for (int pc = 0; pc < numClasses; ++pc) {
                    this.postProbs[ac][m][pc] = (modules[m].trainResults.confusionMatrix[ac][pc] + correction) / numInClass;
                }
            }
        }
    }

    /**
     * Combines the modules' stored train predictions for one train instance.
     *
     * @param modules ensemble modules whose {@code trainResults} hold the predictions
     * @param trainInstanceIndex index of the instance within the train results
     * @return the ensemble's class distribution, as returned by {@code normalise}
     */
    @Override
    public double[] distributionForTrainInstance(EnsembleModule[] modules, int trainInstanceIndex) {
        double[] dist = this.initialSupport();
        for (int m = 0; m < modules.length; ++m) {
            int pred = (int)modules[m].trainResults.getPredClassValue(trainInstanceIndex);
            this.accumulateVote(dist, modules[m], m, pred);
        }
        return this.applyPriorsAndNormalise(dist);
    }

    /**
     * Combines the modules' stored test predictions for one test instance.
     *
     * @param modules ensemble modules whose {@code testResults} hold the predictions
     * @param testInstanceIndex index of the instance within the test results
     * @return the ensemble's class distribution, as returned by {@code normalise}
     */
    @Override
    public double[] distributionForTestInstance(EnsembleModule[] modules, int testInstanceIndex) {
        double[] dist = this.initialSupport();
        for (int m = 0; m < modules.length; ++m) {
            int pred = (int)modules[m].testResults.getPredClassValue(testInstanceIndex);
            this.accumulateVote(dist, modules[m], m, pred);
        }
        return this.applyPriorsAndNormalise(dist);
    }

    /**
     * Classifies a new instance with every module, records each module's
     * distribution via {@code storeModuleTestResult}, and combines the
     * resulting hard predictions.
     *
     * @param modules ensemble modules whose classifiers are queried
     * @param testInstance the instance to classify
     * @return the ensemble's class distribution, as returned by {@code normalise}
     * @throws Exception if a module's classifier fails to produce a distribution
     */
    @Override
    public double[] distributionForInstance(EnsembleModule[] modules, Instance testInstance) throws Exception {
        double[] ensDist = this.initialSupport();
        for (int m = 0; m < modules.length; ++m) {
            double[] mdist = modules[m].getClassifier().distributionForInstance(testInstance);
            this.storeModuleTestResult(modules[m], mdist);
            // Only the module's hard decision (its most supported class) is used.
            int pred = (int)this.indexOfMax(mdist);
            this.accumulateVote(ensDist, modules[m], m, pred);
        }
        return this.applyPriorsAndNormalise(ensDist);
    }

    /** Returns a fresh support vector of length {@code numClasses} filled with the multiplicative identity. */
    private double[] initialSupport() {
        double[] dist = new double[this.numClasses];
        for (int ac = 0; ac < this.numClasses; ++ac) {
            dist[ac] = 1.0;
        }
        return dist;
    }

    /** Multiplies each class's support by the weighted probability that module {@code m} predicts {@code pred} given that class. */
    private void accumulateVote(double[] dist, EnsembleModule module, int m, int pred) {
        for (int ac = 0; ac < this.numClasses; ++ac) {
            dist[ac] *= this.postProbs[ac][m][pred] * module.priorWeight * module.posteriorWeights[pred];
        }
    }

    /**
     * Multiplies each class's support by its prior and normalises the result.
     *
     * <p>The naive Bayes rule is prior times the product of conditionals, so
     * the prior is multiplied in. Every prediction path goes through this
     * helper so that the train, stored-test and live-instance paths cannot
     * disagree on how the prior is applied.
     */
    private double[] applyPriorsAndNormalise(double[] dist) {
        for (int ac = 0; ac < this.numClasses; ++ac) {
            dist[ac] *= this.priorClassProbs[ac];
        }
        return this.normalise(dist);
    }
}

