/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers.ensembles.voting;

import timeseriesweka.classifiers.ensembles.EnsembleModule;
import utilities.ClassifierResults;
import utilities.DebugPrinting;
import weka.core.Instance;

/**
 * Base class for schemes that combine the class-probability distributions produced by the
 * modules of an ensemble into a single prediction.
 *
 * <p>Subclasses supply the three {@code distributionFor*} methods. Each matching
 * {@code classify*} method is derived from its distribution by taking the index of the
 * largest entry (see {@link #indexOfMax(double[])}).
 */
public abstract class ModuleVotingScheme implements DebugPrinting {
    /** Number of classes, recorded by {@link #trainVotingScheme(EnsembleModule[], int)}. */
    protected int numClasses;

    /**
     * Public flag defaulting to {@code false}; it is not read anywhere in this class.
     * NOTE(review): presumably subclasses set this to {@code true} when they need the modules'
     * train predictions to be available before training the scheme — confirm against callers.
     */
    public boolean needTrainPreds = false;

    /**
     * Prepares the scheme for use. The base implementation only records the number of classes;
     * subclasses may override to do more, e.g. derive per-module weights.
     *
     * @param modules    the ensemble's modules (unused by the base implementation)
     * @param numClasses number of classes in the problem
     * @throws Exception if a subclass fails to train
     */
    public void trainVotingScheme(EnsembleModule[] modules, int numClasses) throws Exception {
        this.numClasses = numClasses;
    }

    /**
     * Returns the combined class distribution for the training instance at the given index.
     *
     * @param modules            the ensemble's modules
     * @param trainInstanceIndex index of the training instance
     * @return the combined class distribution
     * @throws Exception if the distribution cannot be computed
     */
    public abstract double[] distributionForTrainInstance(EnsembleModule[] modules, int trainInstanceIndex) throws Exception;

    /**
     * Predicts the class of the training instance at the given index.
     *
     * @param modules            the ensemble's modules
     * @param trainInstanceIndex index of the training instance
     * @return index of the most probable class, as a {@code double}
     * @throws Exception if the distribution cannot be computed
     */
    public double classifyTrainInstance(EnsembleModule[] modules, int trainInstanceIndex) throws Exception {
        double[] dist = this.distributionForTrainInstance(modules, trainInstanceIndex);
        return this.indexOfMax(dist);
    }

    /**
     * Returns the combined class distribution for the test instance at the given index.
     *
     * @param modules           the ensemble's modules
     * @param testInstanceIndex index of the test instance
     * @return the combined class distribution
     * @throws Exception if the distribution cannot be computed
     */
    public abstract double[] distributionForTestInstance(EnsembleModule[] modules, int testInstanceIndex) throws Exception;

    /**
     * Predicts the class of the test instance at the given index.
     *
     * @param modules           the ensemble's modules
     * @param testInstanceIndex index of the test instance
     * @return index of the most probable class, as a {@code double}
     * @throws Exception if the distribution cannot be computed
     */
    public double classifyTestInstance(EnsembleModule[] modules, int testInstanceIndex) throws Exception {
        double[] dist = this.distributionForTestInstance(modules, testInstanceIndex);
        return this.indexOfMax(dist);
    }

    /**
     * Returns the combined class distribution for an arbitrary instance.
     *
     * @param modules      the ensemble's modules
     * @param testInstance the instance to classify
     * @return the combined class distribution
     * @throws Exception if the distribution cannot be computed
     */
    public abstract double[] distributionForInstance(EnsembleModule[] modules, Instance testInstance) throws Exception;

    /**
     * Predicts the class of an arbitrary instance.
     *
     * @param modules      the ensemble's modules
     * @param testInstance the instance to classify
     * @return index of the most probable class, as a {@code double}
     * @throws Exception if the distribution cannot be computed
     */
    public double classifyInstance(EnsembleModule[] modules, Instance testInstance) throws Exception {
        double[] dist = this.distributionForInstance(modules, testInstance);
        return this.indexOfMax(dist);
    }

    /**
     * Returns the index of the largest value in {@code dist}.
     *
     * <p>Ties resolve to the lowest index. NaN entries are never selected as the maximum. If no
     * entry is greater than negative infinity (for example, all NaN), index 0 is returned.
     *
     * @param dist a non-empty distribution
     * @return the index of the maximum, as a {@code double}
     * @throws IllegalArgumentException if {@code dist} is empty
     * @throws NullPointerException     if {@code dist} is null
     */
    public double indexOfMax(double[] dist) {
        if (dist.length == 0) {
            throw new IllegalArgumentException("dist must contain at least one value");
        }
        // Seeding with -Infinity instead of dist[0] means a NaN in slot 0 cannot mask every
        // later entry (all comparisons against NaN are false).
        double max = Double.NEGATIVE_INFINITY;
        int maxInd = 0;
        for (int i = 0; i < dist.length; ++i) {
            // Strict '>' keeps the first index on ties and skips NaN entries.
            if (dist[i] > max) {
                max = dist[i];
                maxInd = i;
            }
        }
        return maxInd;
    }

    /**
     * Normalises {@code dist} in place so that its entries sum to 1, and returns the same array.
     *
     * <p>If the entries sum to exactly zero, every entry is set to {@code 1 / dist.length}
     * (uniform). An empty array is returned unchanged. If the sum is NaN or infinite, the
     * division is still performed, so the entries become NaN or zero.
     *
     * @param dist the values to normalise; modified in place
     * @return {@code dist}
     * @throws NullPointerException if {@code dist} is null
     */
    public double[] normalise(double[] dist) {
        double sum = 0.0;
        for (double value : dist) {
            sum += value;
        }
        if (sum == 0.0) {
            for (int i = 0; i < dist.length; ++i) {
                dist[i] = 1.0 / (double) dist.length;
            }
        } else {
            for (int i = 0; i < dist.length; ++i) {
                dist[i] /= sum;
            }
        }
        return dist;
    }

    /**
     * Appends {@code dist} to the module's test results, creating the
     * {@code ClassifierResults} container on first use.
     *
     * @param module the module whose test results are updated
     * @param dist   the distribution to store
     */
    public void storeModuleTestResult(EnsembleModule module, double[] dist) {
        if (module.testResults == null) {
            module.testResults = new ClassifierResults();
        }
        module.testResults.storeSingleResult(dist);
    }

    /**
     * Returns the simple name of the concrete scheme class.
     *
     * @return the runtime class's simple name
     */
    @Override
    public String toString() {
        return this.getClass().getSimpleName();
    }
}

