/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.Arrays;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 * Winnow and Balanced Winnow (Littlestone): mistake-driven linear threshold learners for
 * binary classes.
 *
 * <p>Data are passed through {@link ReplaceMissingValues} and then {@link NominalToBinary}. Only
 * non-class attributes whose filtered value is exactly 1.0 contribute to a prediction or have
 * their weights updated after a mistake. The model can be refined incrementally through
 * {@link #updateClassifier(Instance)}.
 */
public class Winnow
extends AbstractClassifier
implements UpdateableClassifier,
TechnicalInformationHandler {
    static final long serialVersionUID = 3543770107994321324L;

    /** Whether the balanced variant (separate positive and negative weight vectors) is used. */
    protected boolean m_Balanced;
    /** Number of passes over the training data in {@link #buildClassifier(Instances)}. */
    protected int m_numIterations = 1;
    /** Multiplier for active weights after predicting 0 when the class value is not 0. */
    protected double m_Alpha = 2.0;
    /** Multiplier for active weights after predicting 1 when the class value is not 1. */
    protected double m_Beta = 0.5;
    /** User-supplied prediction threshold; any negative value means "derive from the data". */
    protected double m_Threshold = -1.0;
    /** Seed used to shuffle the training data; -1 disables shuffling. */
    protected int m_Seed = 1;
    /** Number of mistakes made since the last call to buildClassifier. */
    protected int m_Mistakes;
    /** Initial value of every weight. */
    protected double m_defaultWeight = 2.0;
    /** Positive weights, indexed by filtered attribute index (null until a model is built). */
    private double[] m_predPosVector = null;
    /** Negative weights for the balanced variant (null otherwise). */
    private double[] m_predNegVector = null;
    /** Threshold actually used for predictions (resolved from m_Threshold at build time). */
    private double m_actualThreshold;
    /** Filtered training data while building; only its header is kept afterwards. */
    private Instances m_Train = null;
    private NominalToBinary m_NominalToBinary;
    private ReplaceMissingValues m_ReplaceMissingValues;

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the global information string, including the bibliographic references
     */
    public String globalInfo() {
        return "Implements Winnow and Balanced Winnow algorithms by Littlestone.\n\nFor more information, see\n\n"
            + getTechnicalInformation().toString()
            + "\n\n"
            + "Does classification for problems with nominal attributes "
            + "(which it converts into binary attributes).";
    }

    /**
     * Returns the bibliographic references for the algorithm.
     *
     * @return the journal article plus the follow-up technical report
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "N. Littlestone");
        result.setValue(TechnicalInformation.Field.YEAR, "1988");
        // Published title is "...when irrelevant attributes abound..." (no "are").
        result.setValue(TechnicalInformation.Field.TITLE, "Learning quickly when irrelevant attributes abound: A new linear-threshold algorithm");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Machine Learning");
        result.setValue(TechnicalInformation.Field.VOLUME, "2");
        result.setValue(TechnicalInformation.Field.PAGES, "285-318");
        TechnicalInformation additional = result.add(TechnicalInformation.Type.TECHREPORT);
        additional.setValue(TechnicalInformation.Field.AUTHOR, "N. Littlestone");
        additional.setValue(TechnicalInformation.Field.YEAR, "1989");
        additional.setValue(TechnicalInformation.Field.TITLE, "Mistake bounds and logarithmic linear-threshold learning algorithms");
        additional.setValue(TechnicalInformation.Field.INSTITUTION, "University of California");
        additional.setValue(TechnicalInformation.Field.ADDRESS, "University of California, Santa Cruz");
        additional.setValue(TechnicalInformation.Field.NOTE, "Technical Report UCSC-CRL-89-11");
        return result;
    }

    /**
     * Lists the command-line options understood by {@link #setOptions(String[])}.
     *
     * @return an enumeration of the available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>(7);
        newVector.addElement(new Option("\tUse the baLanced version\n\t(default false)", "L", 0, "-L"));
        newVector.addElement(new Option("\tThe number of iterations to be performed.\n\t(default 1)", "I", 1, "-I <int>"));
        newVector.addElement(new Option("\tPromotion coefficient alpha.\n\t(default 2.0)", "A", 1, "-A <double>"));
        newVector.addElement(new Option("\tDemotion coefficient beta.\n\t(default 0.5)", "B", 1, "-B <double>"));
        newVector.addElement(new Option("\tPrediction threshold.\n\t(default -1.0 == number of attributes)", "H", 1, "-H <double>"));
        newVector.addElement(new Option("\tStarting weights.\n\t(default 2.0)", "W", 1, "-W <double>"));
        newVector.addElement(new Option("\tDefault random seed.\n\t(default 1)", "S", 1, "-S <int>"));
        return newVector.elements();
    }

    /**
     * Parses the options -L, -I, -A, -B, -H, -W and -S. Options that are absent leave the
     * corresponding setting unchanged (apart from -L, which is a flag).
     *
     * @param options the option array; recognised entries are consumed by {@link Utils}
     * @throws Exception if an option value cannot be parsed
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        m_Balanced = Utils.getFlag('L', options);

        String iterationsString = Utils.getOption('I', options);
        if (iterationsString.length() != 0) {
            m_numIterations = Integer.parseInt(iterationsString);
        }
        String alphaString = Utils.getOption('A', options);
        if (alphaString.length() != 0) {
            m_Alpha = Double.parseDouble(alphaString);
        }
        String betaString = Utils.getOption('B', options);
        if (betaString.length() != 0) {
            m_Beta = Double.parseDouble(betaString);
        }
        String thresholdString = Utils.getOption('H', options);
        if (thresholdString.length() != 0) {
            m_Threshold = Double.parseDouble(thresholdString);
        }
        String weightString = Utils.getOption('W', options);
        if (weightString.length() != 0) {
            m_defaultWeight = Double.parseDouble(weightString);
        }
        String seedString = Utils.getOption('S', options);
        if (seedString.length() != 0) {
            m_Seed = Integer.parseInt(seedString);
        }
    }

    /**
     * Returns the current settings as a command-line option array.
     *
     * @return the options, without trailing empty-string padding
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        if (m_Balanced) {
            options.add("-L");
        }
        options.add("-I");
        options.add(String.valueOf(m_numIterations));
        options.add("-A");
        options.add(String.valueOf(m_Alpha));
        options.add("-B");
        options.add(String.valueOf(m_Beta));
        options.add("-H");
        options.add(String.valueOf(m_Threshold));
        options.add("-W");
        options.add(String.valueOf(m_defaultWeight));
        options.add("-S");
        options.add(String.valueOf(m_Seed));
        return options.toArray(new String[0]);
    }

    /**
     * Returns the capabilities of this classifier.
     *
     * @return nominal attributes (possibly missing) with a binary, possibly missing, class
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.BINARY_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.setMinimumNumberInstances(0);
        return result;
    }

    /**
     * Builds the model. Instances with a missing class are dropped, the rest are filtered,
     * optionally shuffled, and presented m_numIterations times to the mistake-driven update rule.
     *
     * @param insts the training data (not modified)
     * @throws Exception if the data are not supported or filtering fails
     */
    @Override
    public void buildClassifier(Instances insts) throws Exception {
        getCapabilities().testWithFail(insts);

        Instances data = new Instances(insts);
        data.deleteWithMissingClass();

        m_ReplaceMissingValues = new ReplaceMissingValues();
        m_ReplaceMissingValues.setInputFormat(data);
        data = Filter.useFilter(data, m_ReplaceMissingValues);

        m_NominalToBinary = new NominalToBinary();
        m_NominalToBinary.setInputFormat(data);
        data = Filter.useFilter(data, m_NominalToBinary);

        if (m_Seed != -1) {
            data.randomize(new Random(m_Seed));
        }

        // The update/prediction helpers read the class index from m_Train.
        m_Train = data;
        int numAttributes = data.numAttributes();

        m_predPosVector = new double[numAttributes];
        Arrays.fill(m_predPosVector, m_defaultWeight);
        if (m_Balanced) {
            m_predNegVector = new double[numAttributes];
            Arrays.fill(m_predNegVector, m_defaultWeight);
        } else {
            // Discard negative weights left over from an earlier balanced build.
            m_predNegVector = null;
        }

        // numAttributes includes the class attribute, so the default threshold is the
        // number of input attributes.
        m_actualThreshold = m_Threshold < 0.0 ? numAttributes - 1.0 : m_Threshold;

        m_Mistakes = 0;
        for (int it = 0; it < m_numIterations; ++it) {
            for (int i = 0; i < data.numInstances(); ++i) {
                actualUpdateClassifier(data.instance(i));
            }
        }

        // Later code only needs the header (class index, attribute count); release the rows.
        m_Train = new Instances(data, 0);
    }

    /**
     * Updates the model with one more training instance. Instances whose class is missing
     * are ignored.
     *
     * @param instance the raw (unfiltered) instance
     * @throws Exception if no model has been built or filtering fails
     */
    @Override
    public void updateClassifier(Instance instance) throws Exception {
        actualUpdateClassifier(filterInstance(instance));
    }

    /**
     * Runs one instance through the filters fitted in buildClassifier.
     *
     * @param inst the raw instance
     * @return the filtered instance
     * @throws IllegalStateException if no model has been built yet
     * @throws Exception if filtering fails
     */
    private Instance filterInstance(Instance inst) throws Exception {
        if (m_ReplaceMissingValues == null || m_NominalToBinary == null) {
            throw new IllegalStateException("Winnow: No model built yet.");
        }
        m_ReplaceMissingValues.input(inst);
        m_ReplaceMissingValues.batchFinished();
        Instance filtered = m_ReplaceMissingValues.output();
        m_NominalToBinary.input(filtered);
        m_NominalToBinary.batchFinished();
        return m_NominalToBinary.output();
    }

    /**
     * Applies the (balanced) Winnow update rule to one filtered instance. On a mistake, every
     * active attribute (value exactly 1.0, not the class) has its positive weight multiplied
     * by alpha when 0 was predicted, or by beta when 1 was predicted. In the balanced variant
     * the negative weight is multiplied by the other coefficient. Instances with a missing
     * class are skipped silently.
     *
     * @param inst a filtered instance
     */
    private void actualUpdateClassifier(Instance inst) throws Exception {
        if (inst.classIsMissing()) {
            return;
        }
        double prediction = makePrediction(inst);
        if (prediction == inst.classValue()) {
            return;
        }
        ++m_Mistakes;
        boolean predictedZero = prediction == 0.0;
        double posMultiplier = predictedZero ? m_Alpha : m_Beta;
        double negMultiplier = predictedZero ? m_Beta : m_Alpha;
        int classIndex = m_Train.classIndex();
        int numValues = inst.numValues();
        for (int l = 0; l < numValues; ++l) {
            int index = inst.index(l);
            if (index == classIndex || inst.valueSparse(l) != 1.0) {
                continue;
            }
            m_predPosVector[index] *= posMultiplier;
            if (m_Balanced) {
                m_predNegVector[index] *= negMultiplier;
            }
        }
    }

    /**
     * Classifies an instance.
     *
     * @param inst the raw (unfiltered) instance
     * @return 1.0 if the weighted sum exceeds the threshold, otherwise 0.0
     * @throws Exception if no model has been built or filtering fails
     */
    @Override
    public double classifyInstance(Instance inst) throws Exception {
        return makePrediction(filterInstance(inst));
    }

    /**
     * Sums the weights of the active attributes of a filtered instance and compares the sum
     * with the threshold. In the balanced variant each term is (positive - negative) weight.
     *
     * @param inst a filtered instance
     * @return 1.0 if the sum is strictly greater than the threshold, otherwise 0.0
     */
    private double makePrediction(Instance inst) throws Exception {
        double total = 0.0;
        int classIndex = m_Train.classIndex();
        int numValues = inst.numValues();
        for (int i = 0; i < numValues; ++i) {
            int index = inst.index(i);
            if (index == classIndex || inst.valueSparse(i) != 1.0) {
                continue;
            }
            total += m_Balanced
                ? m_predPosVector[index] - m_predNegVector[index]
                : m_predPosVector[index];
        }
        return total > m_actualThreshold ? 1.0 : 0.0;
    }

    /**
     * Describes the learned weights and the cumulative mistake count.
     *
     * @return a textual description of the model
     */
    public String toString() {
        if (m_predPosVector == null) {
            return "Winnow: No model built yet.";
        }
        StringBuilder result = new StringBuilder("Winnow\n\nAttribute weights\n\n");
        int classIndex = m_Train.classIndex();
        for (int i = 0; i < m_Train.numAttributes(); ++i) {
            if (i == classIndex) {
                continue;
            }
            if (m_Balanced) {
                double wdiff = m_predPosVector[i] - m_predNegVector[i];
                result.append("w").append(i)
                    .append(" p ").append(m_predPosVector[i])
                    .append(" n ").append(m_predNegVector[i])
                    .append(" d ").append(wdiff).append("\n");
            } else {
                result.append("w").append(i).append(" ").append(m_predPosVector[i]).append("\n");
            }
        }
        result.append("\nCumulated mistake count: ").append(m_Mistakes).append("\n\n");
        return result.toString();
    }

    /** @return tip text for the balanced property */
    public String balancedTipText() {
        return "Whether to use the balanced version of the algorithm.";
    }

    /** @return whether the balanced variant is used */
    public boolean getBalanced() {
        return m_Balanced;
    }

    /** @param b whether to use the balanced variant (takes effect at the next build) */
    public void setBalanced(boolean b) {
        m_Balanced = b;
    }

    /** @return tip text for the alpha property */
    public String alphaTipText() {
        return "Promotion coefficient alpha.";
    }

    /** @return the promotion coefficient */
    public double getAlpha() {
        return m_Alpha;
    }

    /** @param a the promotion coefficient */
    public void setAlpha(double a) {
        m_Alpha = a;
    }

    /** @return tip text for the beta property */
    public String betaTipText() {
        return "Demotion coefficient beta.";
    }

    /** @return the demotion coefficient */
    public double getBeta() {
        return m_Beta;
    }

    /** @param b the demotion coefficient */
    public void setBeta(double b) {
        m_Beta = b;
    }

    /** @return tip text for the threshold property */
    public String thresholdTipText() {
        return "Prediction threshold (-1 means: set to number of attributes).";
    }

    /** @return the user-specified threshold (negative means derived from the data) */
    public double getThreshold() {
        return m_Threshold;
    }

    /** @param t the prediction threshold; a negative value derives it from the data */
    public void setThreshold(double t) {
        m_Threshold = t;
    }

    /** @return tip text for the defaultWeight property */
    public String defaultWeightTipText() {
        return "Initial value of weights/coefficients.";
    }

    /** @return the initial weight value */
    public double getDefaultWeight() {
        return m_defaultWeight;
    }

    /** @param w the initial weight value */
    public void setDefaultWeight(double w) {
        m_defaultWeight = w;
    }

    /** @return tip text for the numIterations property */
    public String numIterationsTipText() {
        return "The number of iterations to be performed.";
    }

    /** @return the number of passes over the training data */
    public int getNumIterations() {
        return m_numIterations;
    }

    /** @param v the number of passes over the training data */
    public void setNumIterations(int v) {
        m_numIterations = v;
    }

    /** @return tip text for the seed property */
    public String seedTipText() {
        return "Random number seed used for data shuffling (-1 means no randomization).";
    }

    /** @return the shuffling seed (-1 disables shuffling) */
    public int getSeed() {
        return m_Seed;
    }

    /** @param v the shuffling seed (-1 disables shuffling) */
    public void setSeed(int v) {
        m_Seed = v;
    }

    /** @return the revision string */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5523 $");
    }

    /**
     * Command-line entry point.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        Winnow.runClassifier(new Winnow(), argv);
    }
}

