/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.lazy;

import development.DataSets;
import java.text.DecimalFormat;
import utilities.ClassifierTools;
import weka.classifiers.lazy.AttributeFilterBridge;
import weka.classifiers.lazy.IBk;
import weka.core.DistanceFunction;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.NormalizableDistance;
import weka.core.neighboursearch.NearestNeighbourSearch;

/**
 * A thin wrapper around Weka's {@link IBk} that lets a {@link DistanceFunction} be plugged in
 * directly, and that can optionally reduce the attribute set with an
 * {@link AttributeFilterBridge} before the model is built.
 *
 * <p>The no-arg and {@code int} constructors use a {@link EuclideanDistance} with normalisation
 * switched off.
 */
public class kNN
extends IBk {
    /** Distance function used by {@link #distance} and handed to the nearest-neighbour search. */
    protected DistanceFunction dist;
    /** Not read or written anywhere in this class; presumably used by package code — TODO confirm. */
    double[][] distMatrix;
    /** Not read or written anywhere in this class; presumably used by package code — TODO confirm. */
    boolean storeDistance;
    /** When true, {@link #buildClassifier} filters attributes before training. */
    boolean filterAttributes = false;
    /** Proportion of attributes passed to {@link AttributeFilterBridge#setProportionToKeep}. */
    double propAtts = 0.5;
    /** Stored by {@link #setNumber}; NOTE(review): not currently used by {@link #filter}. */
    int nosAtts = 0;
    /** Filter fitted during the last filtered build, or {@code null} if the last build was unfiltered. */
    AttributeFilterBridge af;

    /** Creates a 1-NN classifier using un-normalised Euclidean distance. */
    public kNN() {
        super.setKNN(1);
        EuclideanDistance ed = new EuclideanDistance();
        ed.setDontNormalize(true);
        this.setDistanceFunction(ed);
    }

    /**
     * Creates a k-NN classifier using un-normalised Euclidean distance.
     *
     * @param k number of neighbours, passed to {@link IBk#IBk(int)}
     */
    public kNN(int k) {
        super(k);
        EuclideanDistance ed = new EuclideanDistance();
        ed.setDontNormalize(true);
        this.setDistanceFunction(ed);
    }

    /**
     * Creates a classifier that uses the supplied distance function. The number of neighbours is
     * whatever {@link IBk}'s no-arg constructor sets.
     *
     * @param df distance function to use
     */
    public kNN(DistanceFunction df) {
        this.setDistanceFunction(df);
    }

    /**
     * Sets the distance function, both on this wrapper and on the underlying nearest-neighbour
     * search algorithm.
     *
     * @param df distance function to use
     * @throws IllegalStateException if the search algorithm rejects the distance function
     */
    public final void setDistanceFunction(DistanceFunction df) {
        this.dist = df;
        NearestNeighbourSearch s = super.getNearestNeighbourSearchAlgorithm();
        try {
            s.setDistanceFunction(df);
        }
        catch (Exception e) {
            // Previously this called System.exit(0), which reported success to the shell and
            // killed any embedding application; surface the failure to the caller instead.
            throw new IllegalStateException("Exception thrown setting distance function " + df, e);
        }
    }

    /**
     * Computes the distance between two instances with the configured distance function.
     *
     * @param first first instance
     * @param second second instance
     * @return the distance reported by {@link #dist}
     */
    public double distance(Instance first, Instance second) {
        return this.dist.distance(first, second);
    }

    /**
     * Turns attribute normalisation on or off, if the distance function supports it. Otherwise a
     * message is printed and nothing changes.
     *
     * @param v {@code true} to normalise, {@code false} to use raw attribute values
     */
    public void normalise(boolean v) {
        if (this.dist instanceof NormalizableDistance) {
            ((NormalizableDistance)this.dist).setDontNormalize(!v);
        } else {
            System.out.println(" Not normalisable");
        }
    }

    /**
     * Builds the classifier, first filtering attributes if {@link #setFilterAttributes} was
     * enabled.
     *
     * @param d training data
     * @throws IllegalStateException if the underlying {@link IBk} build fails
     */
    @Override
    public void buildClassifier(Instances d) {
        Instances d2;
        if (this.filterAttributes) {
            d2 = this.filter(d);
        } else {
            // Forget any filter fitted by an earlier build; otherwise distributionForInstance
            // would keep filtering test instances that no longer match the training data.
            this.af = null;
            d2 = d;
        }
        this.dist.setInstances(d2);
        try {
            super.buildClassifier(d2);
        }
        catch (Exception e) {
            throw new IllegalStateException("Exception thrown in kNN build Classifier = " + e, e);
        }
    }

    /**
     * Returns the class distribution for an instance. If the last build was filtered, the
     * instance first goes through the same attribute filter.
     *
     * @param instance instance to classify
     * @return class distribution computed by {@link IBk}
     * @throws Exception if {@link IBk#distributionForInstance} fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        Instance toClassify = (this.af != null) ? this.af.filterInstance(instance) : instance;
        return super.distributionForInstance(toClassify);
    }

    /**
     * Classifies every instance in a data set.
     *
     * @param test instances to classify
     * @return predicted class values, in the same order as {@code test}
     * @throws IllegalStateException if classification of any instance fails
     */
    public double[] getPredictions(Instances test) {
        double[] pred = new double[test.numInstances()];
        for (int i = 0; i < test.numInstances(); ++i) {
            try {
                pred[i] = this.classifyInstance(test.instance(i));
            }
            catch (Exception e) {
                throw new IllegalStateException(
                        "Exception thrown in getPredictions in kNN at instance " + i + " = " + e, e);
            }
        }
        return pred;
    }

    /**
     * Sanity check: compares this wrapper (k=1) against {@code IBk(1)} on the UCI data sets and
     * prints every data set where test accuracy differs.
     *
     * @param norm whether to switch normalisation on for this wrapper
     */
    public static void test1NNvsIB1(boolean norm) {
        System.out.println("FIRST BASIC SANITY TEST FOR THIS WRAPPER");
        System.out.print("Compare 1-NN with IB1, normalisation turned");
        System.out.println(norm ? " on" : " off");
        System.out.println("Compare on the UCI data sets");
        System.out.println("If normalisation is off, then there may be differences");
        kNN knn = new kNN(1);
        IBk ib1 = new IBk(1);
        knn.normalise(norm);
        int diff = kNN.reportDifferences(knn, "1-NN", ib1, "ib1");
        System.out.println("Total problems =" + DataSets.uciFileNames.length + " different on " + diff);
    }

    /**
     * Sanity check: compares this wrapper against {@link IBk} with k=100 on the UCI data sets.
     *
     * @param norm whether to switch normalisation on for this wrapper
     * @param crossValidate passed to {@link IBk#setCrossValidate} on both classifiers
     */
    public static void testkNNvsIBk(boolean norm, boolean crossValidate) {
        kNN.testkNNvsIBk(norm, crossValidate, 100);
    }

    /**
     * Sanity check: compares this wrapper against {@link IBk} with the same {@code k} on the UCI
     * data sets, and prints every data set where test accuracy differs.
     *
     * @param norm whether to switch normalisation on for this wrapper
     * @param crossValidate passed to {@link IBk#setCrossValidate} on both classifiers
     * @param k number of neighbours for both classifiers
     */
    public static void testkNNvsIBk(boolean norm, boolean crossValidate, int k) {
        System.out.println("BASIC SANITY TEST FOR THIS WRAPPER");
        System.out.print("Compare kNN with IBk (k=" + k + "), normalisation turned");
        System.out.println(norm ? " on" : " off");
        System.out.print("Cross validation turned");
        System.out.println(crossValidate ? " on" : " off");
        System.out.println("Compare on the UCI data sets");
        System.out.println("If normalisation is off, then there may be differences");
        kNN knn = new kNN(k);
        IBk ibk = new IBk(k);
        knn.normalise(norm);
        knn.setCrossValidate(crossValidate);
        ibk.setCrossValidate(crossValidate);
        int diff = kNN.reportDifferences(knn, "kNN", ibk, "ibk");
        System.out.println("Total problems =" + DataSets.uciFileNames.length + " different on " + diff);
    }

    /**
     * Trains both classifiers on each UCI train split, then prints every data set where their
     * test-split accuracies differ. This is harness code, so on failure it prints the error and
     * exits with a non-zero status.
     *
     * @return number of data sets with differing accuracy
     */
    private static int reportDifferences(IBk first, String firstLabel, IBk second, String secondLabel) {
        int diff = 0;
        DecimalFormat df = new DecimalFormat("####.###");
        for (String s : DataSets.uciFileNames) {
            // NOTE(review): "\\" hard-codes a Windows path separator.
            Instances train = ClassifierTools.loadData(DataSets.uciPath + s + "\\" + s + "-train");
            Instances test = ClassifierTools.loadData(DataSets.uciPath + s + "\\" + s + "-test");
            try {
                first.buildClassifier(train);
                second.buildClassifier(train);
                double a1 = ClassifierTools.accuracy(test, first);
                double a2 = ClassifierTools.accuracy(test, second);
                if (a1 != a2) {
                    ++diff;
                    System.out.println(s + ": " + firstLabel + " =" + df.format(a1)
                            + " " + secondLabel + "=" + df.format(a2));
                }
            }
            catch (Exception e) {
                System.out.println(" Exception building a classifier on " + s + " = " + e);
                e.printStackTrace();
                System.exit(1);
            }
        }
        return diff;
    }

    /** Runs the kNN-vs-IBk sanity check with normalisation and cross validation on. */
    public static void main(String[] args) {
        kNN.testkNNvsIBk(true, true);
    }

    /**
     * Enables or disables attribute filtering in {@link #buildClassifier}.
     *
     * @param f {@code true} to filter attributes before building
     */
    public void setFilterAttributes(boolean f) {
        this.filterAttributes = f;
    }

    /**
     * Sets the proportion of attributes to keep when filtering.
     *
     * @param f proportion passed to {@link AttributeFilterBridge#setProportionToKeep}
     */
    public void setProportion(double f) {
        this.propAtts = f;
    }

    /**
     * Stores a number of attributes to keep. NOTE(review): {@link #filter} does not currently
     * read this value; only the proportion set by {@link #setProportion} is applied.
     *
     * @param n number of attributes
     */
    public void setNumber(int n) {
        this.nosAtts = n;
    }

    /**
     * Fits a new attribute filter on {@code d}, keeps it in {@link #af} for use at prediction
     * time, and returns the filtered data.
     */
    private Instances filter(Instances d) {
        this.af = new AttributeFilterBridge(d);
        this.af.setProportionToKeep(this.propAtts);
        return this.af.filter();
    }
}

