/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.io.BufferedInputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableMultipleClassifiersCombiner;
import weka.classifiers.rules.ZeroR;
import weka.core.Aggregateable;
import weka.core.Capabilities;
import weka.core.Environment;
import weka.core.EnvironmentHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

public class Vote
extends RandomizableMultipleClassifiersCombiner
implements TechnicalInformationHandler,
EnvironmentHandler,
Aggregateable<Classifier> {
    static final long serialVersionUID = -637891196294399624L;
    public static final int AVERAGE_RULE = 1;
    public static final int PRODUCT_RULE = 2;
    public static final int MAJORITY_VOTING_RULE = 3;
    public static final int MIN_RULE = 4;
    public static final int MAX_RULE = 5;
    public static final int MEDIAN_RULE = 6;
    public static final Tag[] TAGS_RULES = new Tag[]{new Tag(1, "AVG", "Average of Probabilities"), new Tag(2, "PROD", "Product of Probabilities"), new Tag(3, "MAJ", "Majority Voting"), new Tag(4, "MIN", "Minimum Probability"), new Tag(5, "MAX", "Maximum Probability"), new Tag(6, "MED", "Median")};
    protected int m_CombinationRule = 1;
    protected Random m_Random;
    protected List<String> m_classifiersToLoad = new ArrayList<String>();
    protected List<Classifier> m_preBuiltClassifiers = new ArrayList<Classifier>();
    protected transient Environment m_env = Environment.getSystemWide();
    protected Instances m_structure;

    public String globalInfo() {
        return "Class for combining classifiers. Different combinations of probability estimates for classification are available.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString();
    }

    @Override
    public Enumeration listOptions() {
        Vector result = new Vector();
        Enumeration enm = super.listOptions();
        while (enm.hasMoreElements()) {
            result.addElement(enm.nextElement());
        }
        result.addElement(new Option("\tFull path to serialized classifier to include.\n\tMay be specified multiple times to include\n\tmultiple serialized classifiers. Note: it does\n\tnot make sense to use pre-built classifiers in\n\ta cross-validation.", "P", 1, "-P <path to serialized classifier>"));
        result.addElement(new Option("\tThe combination rule to use\n\t(default: AVG)", "R", 1, "-R " + Tag.toOptionList(TAGS_RULES)));
        return result.elements();
    }

    @Override
    public String[] getOptions() {
        int i;
        Vector<String> result = new Vector<String>();
        String[] options = super.getOptions();
        for (i = 0; i < options.length; ++i) {
            result.add(options[i]);
        }
        result.add("-R");
        result.add("" + this.getCombinationRule());
        for (i = 0; i < this.m_classifiersToLoad.size(); ++i) {
            result.add("-P");
            result.add(this.m_classifiersToLoad.get(i));
        }
        return result.toArray(new String[result.size()]);
    }

    @Override
    public void setOptions(String[] options) throws Exception {
        String loadString;
        String tmpStr = Utils.getOption('R', options);
        if (tmpStr.length() != 0) {
            this.setCombinationRule(new SelectedTag(tmpStr, TAGS_RULES));
        } else {
            this.setCombinationRule(new SelectedTag(1, TAGS_RULES));
        }
        this.m_classifiersToLoad.clear();
        while ((loadString = Utils.getOption('P', options)).length() != 0) {
            this.m_classifiersToLoad.add(loadString);
        }
        super.setOptions(options);
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.BOOK);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Ludmila I. Kuncheva");
        result.setValue(TechnicalInformation.Field.TITLE, "Combining Pattern Classifiers: Methods and Algorithms");
        result.setValue(TechnicalInformation.Field.YEAR, "2004");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "John Wiley and Sons, Inc.");
        TechnicalInformation additional = result.add(TechnicalInformation.Type.ARTICLE);
        additional.setValue(TechnicalInformation.Field.AUTHOR, "J. Kittler and M. Hatef and Robert P.W. Duin and J. Matas");
        additional.setValue(TechnicalInformation.Field.YEAR, "1998");
        additional.setValue(TechnicalInformation.Field.TITLE, "On combining classifiers");
        additional.setValue(TechnicalInformation.Field.JOURNAL, "IEEE Transactions on Pattern Analysis and Machine Intelligence");
        additional.setValue(TechnicalInformation.Field.VOLUME, "20");
        additional.setValue(TechnicalInformation.Field.NUMBER, "3");
        additional.setValue(TechnicalInformation.Field.PAGES, "226-239");
        return result;
    }

    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        if (this.m_preBuiltClassifiers.size() == 0 && this.m_classifiersToLoad.size() > 0) {
            try {
                this.loadClassifiers(null);
            }
            catch (Exception e) {
                e.printStackTrace();
            }
        }
        if (this.m_preBuiltClassifiers.size() > 0) {
            if (this.m_Classifiers.length == 0) {
                result = (Capabilities)this.m_preBuiltClassifiers.get(0).getCapabilities().clone();
            }
            for (int i = 1; i < this.m_preBuiltClassifiers.size(); ++i) {
                result.and(this.m_preBuiltClassifiers.get(i).getCapabilities());
            }
            for (Capabilities.Capability cap : Capabilities.Capability.values()) {
                result.enableDependency(cap);
            }
        }
        if (this.m_CombinationRule == 2 || this.m_CombinationRule == 3) {
            result.disableAllClasses();
            result.disableAllClassDependencies();
            result.enable(Capabilities.Capability.NOMINAL_CLASS);
            result.enableDependency(Capabilities.Capability.NOMINAL_CLASS);
        } else if (this.m_CombinationRule == 6) {
            result.disableAllClasses();
            result.disableAllClassDependencies();
            result.enable(Capabilities.Capability.NUMERIC_CLASS);
            result.enableDependency(Capabilities.Capability.NUMERIC_CLASS);
        }
        return result;
    }

    @Override
    public void buildClassifier(Instances data) throws Exception {
        Instances newData = new Instances(data);
        newData.deleteWithMissingClass();
        this.m_structure = new Instances(newData, 0);
        this.m_Random = new Random(this.getSeed());
        if (this.m_classifiersToLoad.size() > 0) {
            this.m_preBuiltClassifiers.clear();
            this.loadClassifiers(data);
            boolean index = false;
            if (this.m_Classifiers.length == 1 && this.m_Classifiers[0] instanceof ZeroR) {
                this.m_Classifiers = new Classifier[0];
            }
        }
        this.getCapabilities().testWithFail(data);
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            this.getClassifier(i).buildClassifier(newData);
        }
    }

    private void loadClassifiers(Instances data) throws Exception {
        for (String path : this.m_classifiersToLoad) {
            File toLoad;
            if (Environment.containsEnvVariables(path)) {
                try {
                    path = this.m_env.substitute(path);
                }
                catch (Exception exception) {
                    // empty catch block
                }
            }
            if (!(toLoad = new File(path)).isFile()) {
                throw new Exception("\"" + path + "\" does not seem to be a valid file!");
            }
            ObjectInputStream is = new ObjectInputStream(new BufferedInputStream(new FileInputStream(toLoad)));
            Object c = is.readObject();
            if (!(c instanceof Classifier)) {
                throw new Exception("\"" + path + "\" does not contain a classifier!");
            }
            Object header = null;
            header = is.readObject();
            if (header instanceof Instances && data != null && !data.equalHeaders((Instances)header)) {
                throw new Exception("\"" + path + "\" was trained with data that is " + "of a differnet structure than the incoming training data");
            }
            if (header == null) {
                System.out.println("[Vote] warning: no header instances for \"" + path + "\"");
            }
            this.addPreBuiltClassifier((Classifier)c);
        }
    }

    public void addPreBuiltClassifier(Classifier c) {
        this.m_preBuiltClassifiers.add(c);
    }

    public void removePreBuiltClassifier(Classifier c) {
        this.m_preBuiltClassifiers.remove(c);
    }

    @Override
    public double classifyInstance(Instance instance) throws Exception {
        double result;
        switch (this.m_CombinationRule) {
            case 1: 
            case 2: 
            case 3: 
            case 4: 
            case 5: {
                double[] dist = this.distributionForInstance(instance);
                if (instance.classAttribute().isNominal()) {
                    int index = Utils.maxIndex(dist);
                    if (dist[index] == 0.0) {
                        result = Utils.missingValue();
                        break;
                    }
                    result = index;
                    break;
                }
                if (instance.classAttribute().isNumeric()) {
                    result = dist[0];
                    break;
                }
                result = Utils.missingValue();
                break;
            }
            case 6: {
                result = this.classifyInstanceMedian(instance);
                break;
            }
            default: {
                throw new IllegalStateException("Unknown combination rule '" + this.m_CombinationRule + "'!");
            }
        }
        return result;
    }

    protected double classifyInstanceMedian(Instance instance) throws Exception {
        int i;
        double[] results = new double[this.m_Classifiers.length + this.m_preBuiltClassifiers.size()];
        for (i = 0; i < this.m_Classifiers.length; ++i) {
            results[i] = this.m_Classifiers[i].classifyInstance(instance);
        }
        for (i = 0; i < this.m_preBuiltClassifiers.size(); ++i) {
            results[i + this.m_Classifiers.length] = this.m_preBuiltClassifiers.get(i).classifyInstance(instance);
        }
        double result = results.length == 0 ? 0.0 : (results.length == 1 ? results[0] : Utils.kthSmallestValue(results, results.length / 2));
        return result;
    }

    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] result = new double[instance.numClasses()];
        switch (this.m_CombinationRule) {
            case 1: {
                result = this.distributionForInstanceAverage(instance);
                break;
            }
            case 2: {
                result = this.distributionForInstanceProduct(instance);
                break;
            }
            case 3: {
                result = this.distributionForInstanceMajorityVoting(instance);
                break;
            }
            case 4: {
                result = this.distributionForInstanceMin(instance);
                break;
            }
            case 5: {
                result = this.distributionForInstanceMax(instance);
                break;
            }
            case 6: {
                result[0] = this.classifyInstance(instance);
                break;
            }
            default: {
                throw new IllegalStateException("Unknown combination rule '" + this.m_CombinationRule + "'!");
            }
        }
        if (!instance.classAttribute().isNumeric() && Utils.sum(result) > 0.0) {
            Utils.normalize(result);
        }
        return result;
    }

    protected double[] distributionForInstanceAverage(Instance instance) throws Exception {
        int index;
        double[] probs = this.m_Classifiers.length > 0 ? this.getClassifier(0).distributionForInstance(instance) : this.m_preBuiltClassifiers.get(0).distributionForInstance(instance);
        probs = (double[])probs.clone();
        for (int i = 1; i < this.m_Classifiers.length; ++i) {
            double[] dist = this.getClassifier(i).distributionForInstance(instance);
            for (int j = 0; j < dist.length; ++j) {
                int n = j;
                probs[n] = probs[n] + dist[j];
            }
        }
        for (int i = index = this.m_Classifiers.length > 0 ? 0 : 1; i < this.m_preBuiltClassifiers.size(); ++i) {
            double[] dist = this.m_preBuiltClassifiers.get(i).distributionForInstance(instance);
            for (int j = 0; j < dist.length; ++j) {
                int n = j;
                probs[n] = probs[n] + dist[j];
            }
        }
        int j = 0;
        while (j < probs.length) {
            int n = j++;
            probs[n] = probs[n] / (double)(this.m_Classifiers.length + this.m_preBuiltClassifiers.size());
        }
        return probs;
    }

    protected double[] distributionForInstanceProduct(Instance instance) throws Exception {
        int index;
        double[] probs = this.m_Classifiers.length > 0 ? this.getClassifier(0).distributionForInstance(instance) : this.m_preBuiltClassifiers.get(0).distributionForInstance(instance);
        probs = (double[])probs.clone();
        for (int i = 1; i < this.m_Classifiers.length; ++i) {
            double[] dist = this.getClassifier(i).distributionForInstance(instance);
            for (int j = 0; j < dist.length; ++j) {
                int n = j;
                probs[n] = probs[n] * dist[j];
            }
        }
        for (int i = index = this.m_Classifiers.length > 0 ? 0 : 1; i < this.m_preBuiltClassifiers.size(); ++i) {
            double[] dist = this.m_preBuiltClassifiers.get(i).distributionForInstance(instance);
            for (int j = 0; j < dist.length; ++j) {
                int n = j;
                probs[n] = probs[n] * dist[j];
            }
        }
        return probs;
    }

    protected double[] distributionForInstanceMajorityVoting(Instance instance) throws Exception {
        int j;
        int maxIndex;
        int i;
        double[] probs = new double[instance.classAttribute().numValues()];
        double[] votes = new double[probs.length];
        for (i = 0; i < this.m_Classifiers.length; ++i) {
            probs = this.getClassifier(i).distributionForInstance(instance);
            maxIndex = 0;
            for (j = 0; j < probs.length; ++j) {
                if (!(probs[j] > probs[maxIndex])) continue;
                maxIndex = j;
            }
            for (j = 0; j < probs.length; ++j) {
                if (probs[j] != probs[maxIndex]) continue;
                int n = j;
                votes[n] = votes[n] + 1.0;
            }
        }
        for (i = 0; i < this.m_preBuiltClassifiers.size(); ++i) {
            probs = this.m_preBuiltClassifiers.get(i).distributionForInstance(instance);
            maxIndex = 0;
            for (j = 0; j < probs.length; ++j) {
                if (!(probs[j] > probs[maxIndex])) continue;
                maxIndex = j;
            }
            for (j = 0; j < probs.length; ++j) {
                if (probs[j] != probs[maxIndex]) continue;
                int n = j;
                votes[n] = votes[n] + 1.0;
            }
        }
        int tmpMajorityIndex = 0;
        for (int k = 1; k < votes.length; ++k) {
            if (!(votes[k] > votes[tmpMajorityIndex])) continue;
            tmpMajorityIndex = k;
        }
        Vector<Integer> majorityIndexes = new Vector<Integer>();
        for (int k = 0; k < votes.length; ++k) {
            if (votes[k] != votes[tmpMajorityIndex]) continue;
            majorityIndexes.add(k);
        }
        int majorityIndex = (Integer)majorityIndexes.get(this.m_Random.nextInt(majorityIndexes.size()));
        probs = new double[probs.length];
        probs[majorityIndex] = 1.0;
        return probs;
    }

    protected double[] distributionForInstanceMax(Instance instance) throws Exception {
        int index;
        double[] max = this.m_Classifiers.length > 0 ? this.getClassifier(0).distributionForInstance(instance) : this.m_preBuiltClassifiers.get(0).distributionForInstance(instance);
        max = (double[])max.clone();
        for (int i = 1; i < this.m_Classifiers.length; ++i) {
            double[] dist = this.getClassifier(i).distributionForInstance(instance);
            for (int j = 0; j < dist.length; ++j) {
                if (!(max[j] < dist[j])) continue;
                max[j] = dist[j];
            }
        }
        for (int i = index = this.m_Classifiers.length > 0 ? 0 : 1; i < this.m_preBuiltClassifiers.size(); ++i) {
            double[] dist = this.m_preBuiltClassifiers.get(i).distributionForInstance(instance);
            for (int j = 0; j < dist.length; ++j) {
                if (!(max[j] < dist[j])) continue;
                max[j] = dist[j];
            }
        }
        return max;
    }

    protected double[] distributionForInstanceMin(Instance instance) throws Exception {
        int index;
        double[] min = this.m_Classifiers.length > 0 ? this.getClassifier(0).distributionForInstance(instance) : this.m_preBuiltClassifiers.get(0).distributionForInstance(instance);
        min = (double[])min.clone();
        for (int i = 1; i < this.m_Classifiers.length; ++i) {
            double[] dist = this.getClassifier(i).distributionForInstance(instance);
            for (int j = 0; j < dist.length; ++j) {
                if (!(dist[j] < min[j])) continue;
                min[j] = dist[j];
            }
        }
        for (int i = index = this.m_Classifiers.length > 0 ? 0 : 1; i < this.m_preBuiltClassifiers.size(); ++i) {
            double[] dist = this.m_preBuiltClassifiers.get(i).distributionForInstance(instance);
            for (int j = 0; j < dist.length; ++j) {
                if (!(dist[j] < min[j])) continue;
                min[j] = dist[j];
            }
        }
        return min;
    }

    public String combinationRuleTipText() {
        return "The combination rule used.";
    }

    public SelectedTag getCombinationRule() {
        return new SelectedTag(this.m_CombinationRule, TAGS_RULES);
    }

    public void setCombinationRule(SelectedTag newRule) {
        if (newRule.getTags() == TAGS_RULES) {
            this.m_CombinationRule = newRule.getSelectedTag().getID();
        }
    }

    public String preBuiltClassifiersTipText() {
        return "The pre-built serialized classifiers to include. Multiple serialized classifiers can be included alongside those that are built from scratch when this classifier runs. Note that it does not make sense to include pre-built classifiers in a cross-validation since they are static and their models do not change from fold to fold.";
    }

    public void setPreBuiltClassifiers(File[] preBuilt) {
        this.m_classifiersToLoad.clear();
        if (preBuilt != null && preBuilt.length > 0) {
            for (int i = 0; i < preBuilt.length; ++i) {
                String path = preBuilt[i].toString();
                this.m_classifiersToLoad.add(path);
            }
        }
    }

    public File[] getPreBuiltClassifiers() {
        File[] result = new File[this.m_classifiersToLoad.size()];
        for (int i = 0; i < this.m_classifiersToLoad.size(); ++i) {
            result[i] = new File(this.m_classifiersToLoad.get(i));
        }
        return result;
    }

    public String toString() {
        if (this.m_Classifiers == null) {
            return "Vote: No model built yet.";
        }
        String result = "Vote combines";
        result = result + " the probability distributions of these base learners:\n";
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            result = result + '\t' + this.getClassifierSpec(i) + '\n';
        }
        for (Classifier c : this.m_preBuiltClassifiers) {
            result = result + "\t" + c.getClass().getName() + Utils.joinOptions(((OptionHandler)((Object)c)).getOptions()) + "\n";
        }
        result = result + "using the '";
        switch (this.m_CombinationRule) {
            case 1: {
                result = result + "Average of Probabilities";
                break;
            }
            case 2: {
                result = result + "Product of Probabilities";
                break;
            }
            case 3: {
                result = result + "Majority Voting";
                break;
            }
            case 4: {
                result = result + "Minimum Probability";
                break;
            }
            case 5: {
                result = result + "Maximum Probability";
                break;
            }
            case 6: {
                result = result + "Median Probability";
                break;
            }
            default: {
                throw new IllegalStateException("Unknown combination rule '" + this.m_CombinationRule + "'!");
            }
        }
        result = result + "' combination rule \n";
        return result;
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 9785 $");
    }

    @Override
    public void setEnvironment(Environment env) {
        this.m_env = env;
    }

    @Override
    public Classifier aggregate(Classifier toAggregate) throws Exception {
        if (this.m_structure == null && this.m_Classifiers.length == 1 && this.m_Classifiers[0] instanceof ZeroR) {
            this.setClassifiers(new Classifier[0]);
        }
        this.addPreBuiltClassifier(toAggregate);
        return this;
    }

    @Override
    public void finalizeAggregation() throws Exception {
    }

    public static void main(String[] argv) {
        Vote.runClassifier(new Vote(), argv);
    }
}

