/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.lazy;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.SingleClassifierEnhancer;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.rules.ZeroR;
import weka.classifiers.trees.DecisionStump;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.neighboursearch.LinearNNSearch;
import weka.core.neighboursearch.NearestNeighbourSearch;

/**
 * Locally weighted learning. For every test instance, the training instances
 * are weighted by a distance kernel centred on that instance, and the wrapped
 * base classifier (which must be a {@link WeightedInstancesHandler}) is rebuilt
 * on the weighted neighbours before it makes the prediction.
 */
public class LWL extends SingleClassifierEnhancer
    implements UpdateableClassifier, WeightedInstancesHandler, TechnicalInformationHandler {
    static final long serialVersionUID = 1979797405383665815L;

    /** Copy of the training data (instances with a missing class are removed). */
    protected Instances m_Train;
    /** Number of neighbours that sets the kernel bandwidth; 0 (set via setKNN) means "all". */
    protected int m_kNN = -1;
    /** Selected weighting kernel; one of the kernel constants below. */
    protected int m_WeightKernel = 0;
    /** True when all training instances are used to set the bandwidth. */
    protected boolean m_UseAllK = true;
    /** Neighbour search used to find and rank training instances by distance. */
    protected NearestNeighbourSearch m_NNSearch = new LinearNNSearch();

    /** Linear kernel: weight = 1.0001 - d. */
    public static final int LINEAR = 0;
    /** Epanechnikov kernel: weight = 0.75 * (1.0001 - d^2). */
    public static final int EPANECHNIKOV = 1;
    /** Tricube kernel: weight = (1.0001 - d^3)^3. */
    public static final int TRICUBE = 2;
    /** Inverse-distance kernel: weight = 1 / (1 + d). */
    public static final int INVERSE = 3;
    /** Gaussian kernel: weight = exp(-d^2). */
    public static final int GAUSS = 4;
    /** Constant kernel: every neighbour gets weight 1. */
    public static final int CONSTANT = 5;

    /** Fallback model used when the data contains only the class attribute; null otherwise. */
    protected Classifier m_ZeroR;

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description, including the technical references
     */
    public String globalInfo() {
        return "Locally weighted learning. Uses an instance-based algorithm to assign instance weights which are then used by a specified WeightedInstancesHandler.\nCan do classification (e.g. using naive Bayes) or regression (e.g. using linear regression).\n\nFor more info, see\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic references for this method.
     *
     * @return the technical information (an inproceedings entry with an additional article)
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Eibe Frank and Mark Hall and Bernhard Pfahringer");
        result.setValue(TechnicalInformation.Field.YEAR, "2003");
        result.setValue(TechnicalInformation.Field.TITLE, "Locally Weighted Naive Bayes");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "19th Conference in Uncertainty in Artificial Intelligence");
        result.setValue(TechnicalInformation.Field.PAGES, "249-256");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "Morgan Kaufmann");
        TechnicalInformation additional = result.add(TechnicalInformation.Type.ARTICLE);
        additional.setValue(TechnicalInformation.Field.AUTHOR, "C. Atkeson and A. Moore and S. Schaal");
        additional.setValue(TechnicalInformation.Field.YEAR, "1996");
        additional.setValue(TechnicalInformation.Field.TITLE, "Locally weighted learning");
        additional.setValue(TechnicalInformation.Field.JOURNAL, "AI Review");
        return result;
    }

    /** Creates an LWL learner that wraps a DecisionStump by default. */
    public LWL() {
        this.m_Classifier = new DecisionStump();
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return the default classifier class name
     */
    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    /**
     * Returns the additional measures, as reported by the neighbour search.
     *
     * @return an enumeration of the measure names
     */
    public Enumeration enumerateMeasures() {
        return this.m_NNSearch.enumerateMeasures();
    }

    /**
     * Returns the value of the named measure, as reported by the neighbour search.
     *
     * @param additionalMeasureName the name of the measure
     * @return the measure's value
     */
    public double getMeasure(String additionalMeasureName) {
        return this.m_NNSearch.getMeasure(additionalMeasureName);
    }

    /**
     * Lists the command-line options: this class's own options first, followed
     * by those of the superclass.
     *
     * @return an enumeration of {@link Option}s
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>(3);
        // -A takes one argument: the search class name plus its own options.
        newVector.addElement(new Option("\tThe nearest neighbour search algorithm to use (default: weka.core.neighboursearch.LinearNNSearch).\n", "A", 1, "-A <search algorithm specification>"));
        newVector.addElement(new Option("\tSet the number of neighbours used to set the kernel bandwidth.\n\t(default all)", "K", 1, "-K <number of neighbours>"));
        newVector.addElement(new Option("\tSet the weighting kernel shape to use. 0=Linear, 1=Epanechnikov,\n\t2=Tricube, 3=Inverse, 4=Gaussian, 5=Constant.\n\t(default 0 = Linear)", "U", 1, "-U <number of weighting method>"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement((Option) enu.nextElement());
        }
        return newVector.elements();
    }

    /**
     * Parses the options -K (neighbours), -U (kernel) and -A (search algorithm),
     * then passes the remaining options on to the superclass. An absent option
     * resets that setting to its default.
     *
     * @param options the option array; recognised options are consumed
     * @throws Exception if an option value is invalid
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String knnString = Utils.getOption('K', options);
        if (knnString.length() != 0) {
            this.setKNN(Integer.parseInt(knnString));
        } else {
            this.setKNN(-1);
        }
        String weightString = Utils.getOption('U', options);
        if (weightString.length() != 0) {
            this.setWeightingKernel(Integer.parseInt(weightString));
        } else {
            this.setWeightingKernel(LINEAR);
        }
        String nnSearchClass = Utils.getOption('A', options);
        if (nnSearchClass.length() != 0) {
            String[] nnSearchClassSpec = Utils.splitOptions(nnSearchClass);
            if (nnSearchClassSpec.length == 0) {
                throw new Exception("Invalid NearestNeighbourSearch algorithm specification string.");
            }
            String className = nnSearchClassSpec[0];
            // Blank out the class name so that only its options reach the instance.
            nnSearchClassSpec[0] = "";
            this.setNearestNeighbourSearchAlgorithm((NearestNeighbourSearch) Utils.forName(NearestNeighbourSearch.class, className, nnSearchClassSpec));
        } else {
            this.setNearestNeighbourSearchAlgorithm(new LinearNNSearch());
        }
        super.setOptions(options);
    }

    /**
     * Returns the current settings as an option array: -U, -K and -A, followed
     * by the superclass options.
     *
     * @return the current options
     */
    @Override
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        String[] options = new String[superOptions.length + 6];
        int current = 0;
        options[current++] = "-U";
        options[current++] = "" + this.getWeightingKernel();
        options[current++] = "-K";
        // "-1" is the canonical encoding of "use all neighbours".
        options[current++] = (this.getKNN() == 0 && this.m_UseAllK) ? "-1" : "" + this.getKNN();
        options[current++] = "-A";
        options[current++] = this.m_NNSearch.getClass().getName() + " " + Utils.joinOptions(this.m_NNSearch.getOptions());
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        return options;
    }

    /**
     * Returns the GUI tip text for the KNN property.
     *
     * @return the tip text
     */
    public String KNNTipText() {
        return "How many neighbours are used to determine the width of the weighting function (<= 0 means all neighbours).";
    }

    /**
     * Sets how many neighbours determine the kernel bandwidth.
     *
     * @param knn the number of neighbours; any value {@code <= 0} means all
     *            neighbours and is stored as 0
     */
    public void setKNN(int knn) {
        this.m_kNN = knn;
        if (knn <= 0) {
            this.m_kNN = 0;
            this.m_UseAllK = true;
        } else {
            this.m_UseAllK = false;
        }
    }

    /**
     * Returns the number of neighbours that determine the kernel bandwidth.
     *
     * @return the number of neighbours, where 0 means all neighbours once setKNN has been called
     */
    public int getKNN() {
        return this.m_kNN;
    }

    /**
     * Returns the GUI tip text for the weighting-kernel property.
     *
     * @return the tip text
     */
    public String weightingKernelTipText() {
        return "Determines weighting function. [0 = Linear, 1 = Epanechnikov, 2 = Tricube, 3 = Inverse, 4 = Gaussian and 5 = Constant. (default 0 = Linear)].";
    }

    /**
     * Sets the weighting kernel. Values outside {@link #LINEAR}..{@link #CONSTANT}
     * are silently ignored, and the previous kernel is kept.
     *
     * @param kernel the kernel identifier
     */
    public void setWeightingKernel(int kernel) {
        if (kernel < LINEAR || kernel > CONSTANT) {
            return;
        }
        this.m_WeightKernel = kernel;
    }

    /**
     * Returns the selected weighting kernel.
     *
     * @return the kernel identifier
     */
    public int getWeightingKernel() {
        return this.m_WeightKernel;
    }

    /**
     * Returns the GUI tip text for the nearest-neighbour search property.
     *
     * @return the tip text
     */
    public String nearestNeighbourSearchAlgorithmTipText() {
        return "The nearest neighbour search algorithm to use (Default: LinearNN).";
    }

    /**
     * Returns the neighbour search algorithm in use.
     *
     * @return the search algorithm
     */
    public NearestNeighbourSearch getNearestNeighbourSearchAlgorithm() {
        return this.m_NNSearch;
    }

    /**
     * Sets the neighbour search algorithm.
     *
     * @param nearestNeighbourSearchAlgorithm the search algorithm to use
     */
    public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch nearestNeighbourSearchAlgorithm) {
        this.m_NNSearch = nearestNeighbourSearchAlgorithm;
    }

    /**
     * Returns the base classifier's capabilities with every capability enabled
     * as a dependency and no minimum number of instances.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = this.m_Classifier != null ? this.m_Classifier.getCapabilities() : super.getCapabilities();
        result.setMinimumNumberInstances(0);
        for (Capabilities.Capability cap : Capabilities.Capability.values()) {
            result.enableDependency(cap);
        }
        return result;
    }

    /**
     * Stores a copy of the training data and initialises the neighbour search.
     * The base classifier itself is built lazily, once per test instance. If
     * the data has only the class attribute, a ZeroR model is built instead.
     *
     * @param instances the training data
     * @throws IllegalArgumentException if the base classifier is not a WeightedInstancesHandler
     * @throws Exception if the data is not supported or building fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        if (!(this.m_Classifier instanceof WeightedInstancesHandler)) {
            throw new IllegalArgumentException("Classifier must be a WeightedInstancesHandler!");
        }
        this.getCapabilities().testWithFail(instances);
        instances = new Instances(instances);
        instances.deleteWithMissingClass();
        if (instances.numAttributes() == 1) {
            System.err.println("Cannot build model (only class attribute present in data!), using ZeroR model instead!");
            this.m_ZeroR = new ZeroR();
            this.m_ZeroR.buildClassifier(instances);
            // Drop training data left over from an earlier build; it no longer applies.
            this.m_Train = null;
            return;
        }
        this.m_ZeroR = null;
        this.m_Train = new Instances(instances, 0, instances.numInstances());
        this.m_NNSearch.setInstances(this.m_Train);
    }

    /**
     * Adds one training instance. Instances with a missing class are ignored.
     *
     * @param instance the new training instance
     * @throws Exception if no training structure exists or the header differs
     */
    @Override
    public void updateClassifier(Instance instance) throws Exception {
        if (this.m_Train == null) {
            throw new Exception("No training instance structure set!");
        }
        if (!this.m_Train.equalHeaders(instance.dataset())) {
            throw new Exception("Incompatible instance types\n" + this.m_Train.equalHeadersMsg(instance.dataset()));
        }
        if (!instance.classIsMissing()) {
            this.m_NNSearch.update(instance);
            this.m_Train.add(instance);
        }
    }

    /**
     * Predicts a class distribution for the instance. The method finds the
     * neighbours, weights them with the selected kernel, rebuilds the base
     * classifier on the weighted neighbours and queries it.
     *
     * @param instance the instance to classify
     * @return the base classifier's distribution for the instance
     * @throws Exception if no training data is available or classification fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_ZeroR != null) {
            return this.m_ZeroR.distributionForInstance(instance);
        }
        // Also covers calls before buildClassifier, which previously caused an NPE.
        if (this.m_Train == null || this.m_Train.numInstances() == 0) {
            throw new Exception("No training instances!");
        }
        this.m_NNSearch.addInstanceInfo(instance);
        int k = this.m_Train.numInstances();
        if (!this.m_UseAllK && this.m_kNN < k) {
            k = this.m_kNN;
        }
        Instances neighbours = this.m_NNSearch.kNearestNeighbours(instance, k);
        double[] distances = this.m_NNSearch.getDistances();
        if (this.m_Debug) {
            System.out.println("Test Instance: " + instance);
            System.out.println("For " + k + " kept " + neighbours.numInstances() + " out of " + this.m_Train.numInstances() + " instances.");
        }
        if (k > distances.length) {
            k = distances.length;
        }
        if (this.m_Debug) {
            System.out.println("Instance Distances");
            printValues(distances);
        }
        scaleByBandwidth(distances, k);
        for (int i = 0; i < distances.length; ++i) {
            distances[i] = this.kernelWeight(distances[i]);
        }
        if (this.m_Debug) {
            System.out.println("Instance Weights");
            printValues(distances);
        }
        reweightNeighbours(neighbours, distances);
        this.m_Classifier.buildClassifier(neighbours);
        if (this.m_Debug) {
            System.out.println("Classifying test instance: " + instance);
            System.out.println("Built base classifier:\n" + this.m_Classifier.toString());
        }
        return this.m_Classifier.distributionForInstance(instance);
    }

    /** Prints each value on its own line to standard output (debug helper). */
    private static void printValues(double[] values) {
        for (double value : values) {
            System.out.println("" + value);
        }
    }

    /**
     * Divides every distance, in place, by the k-th distance (the bandwidth).
     * If the bandwidth is not positive, every distance becomes 1 so that all
     * neighbours are weighted equally.
     */
    private static void scaleByBandwidth(double[] distances, int k) {
        double bandwidth = distances[k - 1];
        for (int i = 0; i < distances.length; ++i) {
            distances[i] = bandwidth <= 0.0 ? 1.0 : distances[i] / bandwidth;
        }
    }

    /**
     * Maps a bandwidth-scaled distance to a weight using the selected kernel.
     * The 1.0001 offset keeps the weight of the bandwidth-defining neighbour
     * (scaled distance 1) slightly above zero. An unknown kernel id leaves the
     * scaled distance unchanged.
     */
    private double kernelWeight(double d) {
        switch (this.m_WeightKernel) {
            case LINEAR:
                return 1.0001 - d;
            case EPANECHNIKOV:
                return 0.75 * (1.0001 - d * d);
            case TRICUBE:
                return Math.pow(1.0001 - Math.pow(d, 3.0), 3.0);
            case INVERSE:
                return 1.0 / (1.0 + d);
            case GAUSS:
                return Math.exp(-d * d);
            case CONSTANT:
                return 1.0;
            default:
                return d;
        }
    }

    /**
     * Multiplies each neighbour's weight by its kernel weight. The weights are
     * then rescaled so that their total equals the original total weight. The
     * rescaling is skipped when the new total is zero, which would otherwise
     * produce NaN or infinite weights.
     */
    private static void reweightNeighbours(Instances neighbours, double[] kernelWeights) {
        double sumOfWeights = 0.0;
        double newSumOfWeights = 0.0;
        for (int i = 0; i < kernelWeights.length; ++i) {
            Instance inst = neighbours.instance(i);
            sumOfWeights += inst.weight();
            newSumOfWeights += inst.weight() * kernelWeights[i];
            inst.setWeight(inst.weight() * kernelWeights[i]);
        }
        if (newSumOfWeights == 0.0) {
            return;
        }
        for (int i = 0; i < neighbours.numInstances(); ++i) {
            Instance inst = neighbours.instance(i);
            inst.setWeight(inst.weight() * sumOfWeights / newSumOfWeights);
        }
    }

    /**
     * Returns a textual description of the model.
     *
     * @return the description, or a note that no model has been built yet
     */
    @Override
    public String toString() {
        if (this.m_ZeroR != null) {
            StringBuffer buf = new StringBuffer();
            buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");
            buf.append(this.getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n");
            buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
            buf.append(this.m_ZeroR.toString());
            return buf.toString();
        }
        if (this.m_Train == null) {
            return "Locally weighted learning: No model built yet.";
        }
        String result = "Locally weighted learning\n===========================\n";
        result = result + "Using classifier: " + this.m_Classifier.getClass().getName() + "\n";
        switch (this.m_WeightKernel) {
            case LINEAR:
                result = result + "Using linear weighting kernels\n";
                break;
            case EPANECHNIKOV:
                result = result + "Using epanechnikov weighting kernels\n";
                break;
            case TRICUBE:
                result = result + "Using tricube weighting kernels\n";
                break;
            case INVERSE:
                result = result + "Using inverse-distance weighting kernels\n";
                break;
            case GAUSS:
                result = result + "Using gaussian weighting kernels\n";
                break;
            case CONSTANT:
                result = result + "Using constant weighting kernels\n";
                break;
            default:
                break;
        }
        result = result + "Using " + (this.m_UseAllK ? "all" : "" + this.m_kNN) + " neighbours";
        return result;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8034 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        LWL.runClassifier(new LWL(), argv);
    }
}

