/*
 * Decompiled with CFR 0.152.
 */
package multivariate_timeseriesweka.ensembles;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import timeseriesweka.classifiers.ensembles.EnsembleModule;
import timeseriesweka.classifiers.ensembles.voting.MajorityVote;
import timeseriesweka.classifiers.ensembles.voting.ModuleVotingScheme;
import timeseriesweka.classifiers.ensembles.weightings.EqualWeighting;
import timeseriesweka.classifiers.ensembles.weightings.ModuleWeightingScheme;
import utilities.MultivariateInstanceTools;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Ensemble for multivariate data that trains an independent copy of a base classifier on
 * each dimension (channel) and combines the per-channel predictions with a weighted vote.
 *
 * <p>Training data is split into channels with
 * {@link MultivariateInstanceTools#splitMultivariateInstances(Instances)}, and one copy of
 * the base classifier is built per channel. At prediction time each channel's classifier
 * votes for the class at the maximum of its distribution (as reported by the voting
 * scheme's {@code indexOfMax}). That vote is worth {@code priorWeight * posteriorWeights[pred]}
 * of the channel's {@link EnsembleModule}. The accumulated votes are passed to the voting
 * scheme's {@code normalise}.
 *
 * <p>Weights are defined by {@link #weightingScheme} (default {@link EqualWeighting}).
 * {@link #votingScheme} (default {@link MajorityVote}) is used only for storing module
 * results, the argmax, and normalisation. The summation itself is done here.
 */
public class IndependentDimensionEnsemble extends AbstractClassifier {
    /** Assigns posterior weights to the modules; applied before the modules are trained. */
    protected ModuleWeightingScheme weightingScheme = new EqualWeighting();
    /** Supplies argmax/normalisation helpers and records each module's test distribution. */
    protected ModuleVotingScheme votingScheme = new MajorityVote();
    /** One module per channel, created in {@link #initialiseModules()}. */
    protected EnsembleModule[] modules;
    /** Last seed passed to {@link #setSeed(long)}. */
    long seed;
    /** Number of classes in the training data. */
    int numClasses;
    /** Number of channels found in the training data. */
    int numChannels;
    /** Training data as passed to {@link #buildClassifier(Instances)}. */
    Instances train;
    /** Training data split into one univariate dataset per channel. */
    Instances[] channels;
    /** Per-channel copies of {@link #original_model}. */
    Classifier[] classifiers;
    /** Per-channel module names: base classifier simple name plus "_" plus channel index. */
    String[] classifierNames;
    /** Base classifier; only copies of it are ever trained. */
    Classifier original_model;
    /** Optional per-channel prior weights; when null the modules keep their default prior weight. */
    double[] priorWeights;
    /** Not used within this class. */
    Instances[] convertedTest = null;

    /**
     * Creates an ensemble around a base classifier. The supplied model is never trained
     * itself; {@link #buildClassifier(Instances)} trains one copy of it per channel.
     *
     * @param cla the base classifier to replicate for every channel
     */
    public IndependentDimensionEnsemble(Classifier cla) {
        this.original_model = cla;
    }

    /**
     * Records {@code sd} and forwards it to the base classifier.
     *
     * <p>If the base classifier is a {@link RandomizableIteratedSingleClassifierEnhancer},
     * its {@code setSeed} is called with the seed truncated to {@code int}. Otherwise every
     * public single-argument method taking an {@code int} or {@code long} whose name
     * (case-insensitively) contains "random" or "seed" is invoked with the seed. Failures
     * of those reflective calls are reported on standard error and otherwise ignored.
     *
     * <p>Only the base classifier is updated. The per-channel copies are made during
     * {@link #buildClassifier(Instances)}, so the seed should be set before building.
     *
     * @param sd the random seed
     */
    public void setSeed(long sd) {
        this.seed = sd;
        if (this.original_model instanceof RandomizableIteratedSingleClassifierEnhancer) {
            ((RandomizableIteratedSingleClassifierEnhancer) this.original_model).setSeed((int) this.seed);
            return;
        }
        for (Method method : this.original_model.getClass().getMethods()) {
            if (!isSeedSetter(method)) {
                continue;
            }
            try {
                if (method.getParameterTypes()[0] == Integer.TYPE) {
                    method.invoke(this.original_model, (int) this.seed);
                } else {
                    method.invoke(this.original_model, this.seed);
                }
            } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException ex) {
                // Best effort: one failing setter must not stop configuration, but report why.
                // InvocationTargetException carries no detail itself; print what the setter threw.
                Throwable reported = (ex instanceof InvocationTargetException && ex.getCause() != null)
                        ? ex.getCause() : ex;
                System.err.println(reported);
                System.err.println("Tried to set the seed method name: " + method.getName());
            }
        }
    }

    /**
     * Checks whether a method looks like a seed setter: exactly one {@code int} or
     * {@code long} parameter, and a name containing "random" or "seed" (case-insensitive).
     */
    private static boolean isSeedSetter(Method method) {
        Class<?>[] params = method.getParameterTypes();
        if (params.length != 1 || (params[0] != Integer.TYPE && params[0] != Long.TYPE)) {
            return false;
        }
        String name = method.getName().toLowerCase();
        return name.contains("random") || name.contains("seed");
    }

    /**
     * Sets per-channel prior weights, indexed by channel. The array is copied, so later
     * changes to the caller's array do not affect the ensemble.
     *
     * @param weights one prior weight per channel, or {@code null} to keep module defaults
     */
    public void setPriorWeights(double[] weights) {
        this.priorWeights = (weights == null) ? null : weights.clone();
    }

    /**
     * Creates one copy of the base classifier and one {@link EnsembleModule} per channel.
     * Applies any prior weights, then lets the weighting and voting schemes initialise
     * themselves on the modules.
     *
     * <p>The weightings are defined here, before the per-channel classifiers are trained.
     * NOTE(review): a weighting scheme that relies on training results may not be usable
     * here. Confirm against the scheme in use.
     *
     * @throws IllegalStateException if prior weights are set but there are fewer of them
     *     than channels
     * @throws Exception if copying the base classifier or initialising a scheme fails
     */
    protected void initialiseModules() throws Exception {
        if (this.priorWeights != null && this.priorWeights.length < this.numChannels) {
            throw new IllegalStateException("Expected one prior weight per channel (" + this.numChannels
                    + ") but only " + this.priorWeights.length + " were supplied");
        }
        this.classifiers = AbstractClassifier.makeCopies(this.original_model, this.numChannels);
        this.classifierNames = new String[this.numChannels];
        this.modules = new EnsembleModule[this.numChannels];
        for (int m = 0; m < this.numChannels; ++m) {
            this.classifierNames[m] = this.classifiers[m].getClass().getSimpleName() + "_" + m;
            this.modules[m] = new EnsembleModule(this.classifierNames[m], this.classifiers[m], "");
            if (this.priorWeights != null) {
                this.modules[m].priorWeight = this.priorWeights[m];
            }
        }
        this.weightingScheme.defineWeightings(this.modules, this.numClasses);
        this.votingScheme.trainVotingScheme(this.modules, this.numClasses);
    }

    /**
     * Splits {@code data} into channels, creates the per-channel modules, and trains each
     * channel's classifier on its own channel.
     *
     * @param data multivariate training data
     * @throws Exception if splitting, module initialisation, or any channel's training fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.train = data;
        this.numClasses = data.numClasses();
        this.channels = MultivariateInstanceTools.splitMultivariateInstances(data);
        this.numChannels = this.channels.length;
        this.initialiseModules();
        for (int i = 0; i < this.numChannels; ++i) {
            this.modules[i].getClassifier().buildClassifier(this.channels[i]);
        }
    }

    /**
     * Splits a multivariate instance into per-channel instances and returns the
     * normalised weighted vote of the channel classifiers.
     *
     * @param instance a multivariate instance
     * @return the class distribution produced by the voting scheme's {@code normalise}
     * @throws IllegalStateException if the ensemble has not been built
     * @throws IllegalArgumentException if the instance splits into fewer channels than were trained
     * @throws Exception if a channel classifier fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        return this.distributionForInstance(this.votingScheme, this.modules,
                MultivariateInstanceTools.splitMultivariateInstanceWithClassVal(instance));
    }

    /**
     * Combines the channel classifiers' predictions by weighted voting.
     *
     * <p>For each channel m: the module's distribution for {@code testInstance[m]} is
     * computed and recorded with {@code vs.storeModuleTestResult}. The predicted class
     * {@code pred = vs.indexOfMax(dist)} then receives
     * {@code priorWeight * posteriorWeights[pred]}.
     */
    private double[] distributionForInstance(ModuleVotingScheme vs, EnsembleModule[] modules, Instance[] testInstance) throws Exception {
        // Before buildClassifier numChannels/numClasses are 0, which would otherwise
        // silently yield an empty distribution.
        if (modules == null) {
            throw new IllegalStateException("buildClassifier must be called before classifying instances");
        }
        if (testInstance.length < this.numChannels) {
            throw new IllegalArgumentException("Test instance has " + testInstance.length
                    + " channels but the ensemble was trained on " + this.numChannels);
        }
        double[] votes = new double[this.numClasses];
        for (int m = 0; m < this.numChannels; ++m) {
            EnsembleModule module = modules[m];
            double[] dist = module.getClassifier().distributionForInstance(testInstance[m]);
            vs.storeModuleTestResult(module, dist);
            int pred = (int) vs.indexOfMax(dist);
            votes[pred] += module.priorWeight * module.posteriorWeights[pred];
        }
        return vs.normalise(votes);
    }
}

