/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers.ensembles.elastic_ensemble;

import timeseriesweka.classifiers.ensembles.elastic_ensemble.Efficient1NN;
import timeseriesweka.elastic_distance_measures.DTW;
import utilities.ClassifierTools;
import weka.classifiers.lazy.kNN;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;

/**
 * One-nearest-neighbour classifier using dynamic time warping (DTW) restricted to a warping band.
 *
 * <p>The band width is given as a fraction {@code r} of the series length and converted to a
 * cell count by {@link #getWindowSize(int)}; {@code r = 1.0} (the default) is the full window.
 * The classifier identifier is "DTW_R1_1NN" for the full window and "DTW_Rn_1NN" when a
 * different window is supplied or when window tuning rewrites it.
 */
public class DTW1NN extends Efficient1NN {
    /** Warping window as a fraction of the first series' length. */
    private double r = 1.0;

    /**
     * Creates a classifier with a fixed warping window; window tuning ({@code allowLoocv}) is
     * switched off.
     *
     * @param r warping window as a fraction of series length
     */
    public DTW1NN(double r) {
        this.allowLoocv = false;
        this.r = r;
        this.classifierIdentifier = r != 1.0 ? "DTW_Rn_1NN" : "DTW_R1_1NN";
    }

    /**
     * Creates a full-window ({@code r = 1.0}) classifier. {@code allowLoocv} keeps whatever value
     * the superclass initialises it to.
     */
    public DTW1NN() {
        this.r = 1.0;
        this.classifierIdentifier = "DTW_R1_1NN";
    }

    /**
     * Sets the warping window fraction.
     *
     * <p>NOTE(review): the classifier identifier is not updated here, so a full-window instance
     * keeps its "R1" identifier after this call — confirm whether callers rely on that.
     *
     * @param w warping window as a fraction of series length
     */
    public void setWindow(double w) {
        this.r = w;
    }

    /** Disables window tuning ({@code allowLoocv = false}). */
    public void turnOffCV() {
        this.allowLoocv = false;
    }

    /** Enables window tuning ({@code allowLoocv = true}). */
    public void turnOnCV() {
        this.allowLoocv = true;
    }

    /**
     * Marks the identifier as tuned-window ("R1" -> "Rn") if tuning is enabled, then delegates to
     * the superclass (presumably leave-one-out cross-validation — see {@code Efficient1NN}).
     */
    @Override
    public double[] loocv(Instances train) throws Exception {
        this.renameIdentifierForTunedWindow();
        return super.loocv(train);
    }

    /**
     * Array variant of {@link #loocv(Instances)}: same identifier handling, then delegates to the
     * superclass.
     */
    @Override
    public double[] loocv(Instances[] train) throws Exception {
        this.renameIdentifierForTunedWindow();
        return super.loocv(train);
    }

    /**
     * Rewrites the identifier from the full-window tag "R1" to the tuned-window tag "Rn" when
     * window tuning is enabled; does nothing otherwise.
     */
    private void renameIdentifierForTunedWindow() {
        if (this.allowLoocv && this.classifierIdentifier.contains("R1")) {
            this.classifierIdentifier = this.classifierIdentifier.replace("R1", "Rn");
        }
    }

    /**
     * Converts the fractional window {@code r} into a cell count for a series of length {@code n}.
     *
     * <p>The result is {@code (int) (r * n)} clamped to at least 1, plus one when that is still
     * below {@code n}. {@link #distance(Instance, Instance, double)} only fills matrix cells with
     * {@code |i - j| < windowSize}.
     *
     * @param n series length (number of non-class attributes)
     * @return window size in cells, at least 1
     */
    public final int getWindowSize(int n) {
        int w = (int) (this.r * (double) n);
        if (w < 1) {
            w = 1;
        } else if (w < n) {
            ++w;
        }
        return w;
    }

    /**
     * Computes the banded DTW cost between two series, i.e. the minimum sum of squared pointwise
     * differences along a warping path (no square root is taken).
     *
     * <p>When the class attribute is not the last attribute of either instance, the computation is
     * delegated to {@link DTW} with the same window. Otherwise the values at attribute indices
     * {@code 0 .. numAttributes() - 2} form the series.
     *
     * @param first  first series
     * @param second second series
     * @param cutoff early-abandon threshold: if every cell of some row (after the first) is at
     *               least {@code cutoff}, the computation stops
     * @return the DTW cost, or {@link Double#MAX_VALUE} when the computation was abandoned or no
     *         warping path exists inside the band
     */
    @Override
    public final double distance(Instance first, Instance second, double cutoff) {
        if (first.classIndex() != first.numAttributes() - 1 || second.classIndex() != second.numAttributes() - 1) {
            DTW temp = new DTW();
            temp.setR(this.r);
            return temp.distance(first, second, cutoff);
        }
        int n = first.numAttributes() - 1;
        int m = second.numAttributes() - 1;
        if (n == 0 || m == 0) {
            // No series values: two empty series align trivially, otherwise no path exists.
            return n == m ? 0.0 : Double.MAX_VALUE;
        }
        int windowSize = this.getWindowSize(n);
        // Only cells with |i - j| < windowSize are computed. If the end cell is outside the band
        // no admissible path exists; without this check the never-written (zero) end cell could
        // be returned, reporting a perfect match.
        if (Math.abs(n - m) >= windowSize) {
            return Double.MAX_VALUE;
        }

        double[][] matrixD = new double[n][m];
        // Pre-fill the band plus a one-cell border with MAX_VALUE so that neighbours read just
        // outside the band behave as +infinity in the recurrence.
        for (int i = 0; i < n; ++i) {
            int fillStart = Math.max(i - windowSize, 0);
            int fillEnd = Math.min(i + windowSize + 1, m);
            for (int j = fillStart; j < fillEnd; ++j) {
                matrixD[i][j] = Double.MAX_VALUE;
            }
        }

        double firstHead = first.value(0);
        double secondHead = second.value(0);
        double d = firstHead - secondHead;
        matrixD[0][0] = d * d;
        // First row and first column can only be reached by a straight run from (0, 0).
        for (int j = 1; j < windowSize && j < m; ++j) {
            d = firstHead - second.value(j);
            matrixD[0][j] = matrixD[0][j - 1] + d * d;
        }
        for (int i = 1; i < windowSize && i < n; ++i) {
            d = first.value(i) - secondHead;
            matrixD[i][0] = matrixD[i - 1][0] + d * d;
        }

        for (int i = 1; i < n; ++i) {
            double a = first.value(i);
            int start = Math.max(i - windowSize + 1, 1);
            int end = Math.min(i + windowSize, m);
            boolean rowAtOrAboveCutoff = true;
            for (int j = start; j < end; ++j) {
                double minDist = matrixD[i][j - 1];
                if (matrixD[i - 1][j] < minDist) {
                    minDist = matrixD[i - 1][j];
                }
                if (matrixD[i - 1][j - 1] < minDist) {
                    minDist = matrixD[i - 1][j - 1];
                }
                d = a - second.value(j);
                matrixD[i][j] = minDist + d * d;
                if (matrixD[i][j] < cutoff) {
                    rowAtOrAboveCutoff = false;
                }
            }
            // Costs never decrease along a path, so if the whole row is already at or above the
            // cutoff the final cost will be too.
            if (rowAtOrAboveCutoff) {
                return Double.MAX_VALUE;
            }
        }
        return matrixD[n - 1][m - 1];
    }

    /**
     * Returns a one-hot distribution with 1.0 at the class predicted by
     * {@code classifyInstance}.
     *
     * @param instance instance to classify
     * @return array of length {@code instance.numClasses()}
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] res = new double[instance.numClasses()];
        int predicted = (int) this.classifyInstance(instance);
        res[predicted] = 1.0;
        return res;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Capabilities getCapabilities() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Runs {@link #runComparison(String, String, double)} with the original hard-coded settings
     * (GunPoint, window 0.1, author-specific data directory).
     */
    public static void runComparison() throws Exception {
        runComparison("C:/users/sjx07ngu/Dropbox/TSC Problems/", "GunPoint", 0.1);
    }

    /**
     * Compares accuracy and wall-clock test time of a weka {@code kNN} using the reference
     * {@link DTW} distance against this class, printing the results to stdout.
     *
     * @param tscProbDir  directory containing one sub-directory per dataset
     * @param datasetName dataset name; files are {@code <dir>/<name>/<name>_TRAIN} and
     *                    {@code _TEST}
     * @param r           warping window fraction used by both classifiers
     */
    public static void runComparison(String tscProbDir, String datasetName, double r) throws Exception {
        Instances train = ClassifierTools.loadData(tscProbDir + datasetName + "/" + datasetName + "_TRAIN");
        Instances test = ClassifierTools.loadData(tscProbDir + datasetName + "/" + datasetName + "_TEST");

        kNN knn = new kNN();
        DTW oldDtw = new DTW();
        oldDtw.setR(r);
        knn.setDistanceFunction(oldDtw);
        knn.buildClassifier(train);

        DTW1NN dtwNew = new DTW1NN(r);
        dtwNew.buildClassifier(train);

        long start = System.nanoTime();
        int correctOld = 0;
        for (int i = 0; i < test.numInstances(); ++i) {
            if (knn.classifyInstance(test.instance(i)) == test.instance(i).classValue()) {
                ++correctOld;
            }
        }
        long oldTime = System.nanoTime() - start;

        start = System.nanoTime();
        int correctNew = 0;
        for (int i = 0; i < test.numInstances(); ++i) {
            if (dtwNew.classifyInstance(test.instance(i)) == test.instance(i).classValue()) {
                ++correctNew;
            }
        }
        long newTime = System.nanoTime() - start;

        System.out.println("Comparison of DTW: " + datasetName);
        System.out.println("==========================================");
        System.out.println("Old acc:    " + (double) correctOld / (double) test.numInstances());
        System.out.println("New acc:    " + (double) correctNew / (double) test.numInstances());
        System.out.println("Old timing: " + oldTime);
        System.out.println("New timing: " + newTime);
        System.out.println("Relative Performance: " + (double) newTime / (double) oldTime);
    }

    /** Runs the default comparison ten times. */
    public static void main(String[] args) throws Exception {
        for (int i = 0; i < 10; ++i) {
            DTW1NN.runComparison();
        }
    }

    /**
     * Sets the window fraction to {@code paramId / 100.0} (presumably {@code paramId} ranges over
     * 0..100 during tuning — confirm against the caller) and marks the identifier as tuned.
     *
     * @throws IllegalStateException if window tuning is disabled
     */
    @Override
    public void setParamsFromParamId(Instances train, int paramId) {
        if (!this.allowLoocv) {
            throw new IllegalStateException("Warning: trying to set parameters of a fixed window DTW");
        }
        this.renameIdentifierForTunedWindow();
        this.r = (double) paramId / 100.0;
    }

    /** Returns the current window fraction {@code r} as a string. */
    @Override
    public String getParamInformationString() {
        return this.r + "";
    }
}

