/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers.ensembles.elastic_ensemble;

import timeseriesweka.classifiers.ensembles.elastic_ensemble.Efficient1NN;
import timeseriesweka.elastic_distance_measures.WeightedDTW;
import utilities.ClassifierTools;
import weka.classifiers.lazy.kNN;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;

/**
 * One-nearest-neighbour classifier using Weighted Dynamic Time Warping (WDTW).
 *
 * <p>Each squared point difference {@code (a_i - b_j)^2} on the warping path is multiplied by a
 * weight that depends on the phase difference {@code |i - j|}. The weight is
 * {@code WEIGHT_MAX / (1 + exp(-g * (k - L/2)))}, where {@code k = |i - j|} and {@code L} is the
 * length of the cached weight vector. The parameter {@code g} controls the steepness of that
 * logistic curve.
 *
 * <p>The weight vector is cached between {@link #distance} calls. It is rebuilt when {@code g}
 * changes (via {@link #setParamsFromParamId}) or when the required length changes. Instances are
 * not thread-safe for concurrent {@code distance} calls because of this mutable cache.
 */
public class WDTW1NN
extends Efficient1NN {
    /** Curvature of the logistic weight function; larger values penalise phase shifts more sharply. */
    private double g = 0.0;
    /** Cached weights indexed by phase difference |i - j|; rebuilt lazily in {@link #distance}. */
    private double[] weightVector;
    private static final double WEIGHT_MAX = 1.0;
    /** Set whenever {@code g} changes so the next distance call recomputes {@link #weightVector}. */
    private boolean refreshWeights = true;

    /**
     * Creates a classifier with a fixed weight parameter.
     *
     * @param g curvature of the logistic weight function
     */
    public WDTW1NN(double g) {
        this.g = g;
        this.classifierIdentifier = "WDTW_1NN";
        // NOTE(review): disables the inherited LOOCV option for a fixed-g classifier; the exact
        // semantics of allowLoocv live in Efficient1NN — confirm there.
        this.allowLoocv = false;
    }

    /** Creates a classifier with {@code g = 0}; normally configured later via {@link #setParamsFromParamId}. */
    public WDTW1NN() {
        this.g = 0.0;
        this.classifierIdentifier = "WDTW_1NN";
    }

    /**
     * Recomputes the logistic weight vector for the current {@code g}.
     *
     * @param seriesLength number of weights to compute (entries 0 .. seriesLength-1)
     */
    private void initWeights(int seriesLength) {
        this.weightVector = new double[seriesLength];
        final double halfLength = seriesLength / 2.0;
        for (int i = 0; i < seriesLength; ++i) {
            this.weightVector[i] = WEIGHT_MAX / (1.0 + Math.exp(-this.g * (i - halfLength)));
        }
        this.refreshWeights = false;
    }

    /**
     * Computes the WDTW distance between two instances, abandoning early once it cannot be
     * below {@code cutoff}.
     *
     * <p>The fast path assumes the class attribute is the last attribute, so attributes
     * {@code 0 .. numAttributes()-2} form the series. Otherwise the call is delegated to
     * {@link WeightedDTW}.
     *
     * @param first  first instance
     * @param second second instance
     * @param cutoff early-abandon threshold
     * @return the WDTW distance, or {@link Double#MAX_VALUE} if the computation was abandoned
     */
    @Override
    public final double distance(Instance first, Instance second, double cutoff) {
        if (first.classIndex() != first.numAttributes() - 1
                || second.classIndex() != second.numAttributes() - 1) {
            return new WeightedDTW(this.g).distance(first, second, cutoff);
        }
        final int m = first.numAttributes() - 1;
        final int n = second.numAttributes() - 1;

        // Weights are indexed by j on the top row and by |i - j| elsewhere, both of which reach
        // max(m, n) - 1. Size the cache accordingly, and rebuild it if g or the length changed.
        final int weightLength = Math.max(m, n);
        if (this.refreshWeights || this.weightVector == null || this.weightVector.length != weightLength) {
            this.initWeights(weightLength);
        }
        final double[] weights = this.weightVector;

        final double[][] distances = new double[m][n];
        final double firstStart = first.value(0);
        final double secondStart = second.value(0);
        double diff = firstStart - secondStart;
        distances[0][0] = weights[0] * diff * diff;
        if (distances[0][0] > cutoff) {
            return Double.MAX_VALUE;
        }
        // Top row: first[0] aligned against second[0..j].
        for (int j = 1; j < n; ++j) {
            diff = firstStart - second.value(j);
            distances[0][j] = distances[0][j - 1] + weights[j] * diff * diff;
        }
        // Left column: first[0..i] aligned against second[0].
        for (int i = 1; i < m; ++i) {
            diff = first.value(i) - secondStart;
            distances[i][0] = distances[i - 1][0] + weights[i] * diff * diff;
        }
        for (int i = 1; i < m; ++i) {
            final double firstValue = first.value(i);
            // Every warping path visits each row. Weights and squared differences are non-negative,
            // so cumulative cost never decreases along a path. If every cell of this row (column 0
            // included) has reached the cutoff, the final distance must too.
            boolean rowExceedsCutoff = !(distances[i][0] < cutoff);
            for (int j = 1; j < n; ++j) {
                final double best = Math.min(distances[i][j - 1], Math.min(distances[i - 1][j], distances[i - 1][j - 1]));
                diff = firstValue - second.value(j);
                distances[i][j] = best + weights[Math.abs(i - j)] * diff * diff;
                if (distances[i][j] < cutoff) {
                    rowExceedsCutoff = false;
                }
            }
            if (rowExceedsCutoff) {
                return Double.MAX_VALUE;
            }
        }
        return distances[m - 1][n - 1];
    }

    /** Not implemented for this classifier. */
    @Override
    public Capabilities getCapabilities() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Developer harness comparing accuracy and timing of a kNN over {@link WeightedDTW} against
     * this class on a hard-coded local dataset.
     *
     * @throws Exception if loading data or building/using either classifier fails
     */
    public static void runComparison() throws Exception {
        // NOTE(review): machine-specific path; only usable on the original author's machine.
        String tscProbDir = "C:/users/sjx07ngu/Dropbox/TSC Problems/";
        String datasetName = "GunPoint";
        double g = 0.1;
        Instances train = ClassifierTools.loadData(tscProbDir + datasetName + "/" + datasetName + "_TRAIN");
        Instances test = ClassifierTools.loadData(tscProbDir + datasetName + "/" + datasetName + "_TEST");

        kNN knn = new kNN();
        WeightedDTW oldDtw = new WeightedDTW(g);
        knn.setDistanceFunction(oldDtw);
        knn.buildClassifier(train);
        WDTW1NN dtwNew = new WDTW1NN(g);
        dtwNew.buildClassifier(train);

        long start = System.nanoTime();
        int correctOld = 0;
        for (int i = 0; i < test.numInstances(); ++i) {
            double pred = knn.classifyInstance(test.instance(i));
            if (pred == test.instance(i).classValue()) {
                ++correctOld;
            }
        }
        long oldTime = System.nanoTime() - start;

        start = System.nanoTime();
        int correctNew = 0;
        for (int i = 0; i < test.numInstances(); ++i) {
            double pred = dtwNew.classifyInstance(test.instance(i));
            if (pred == test.instance(i).classValue()) {
                ++correctNew;
            }
        }
        long newTime = System.nanoTime() - start;

        System.out.println("Comparison of WDTW: " + datasetName);
        System.out.println("==========================================");
        System.out.println("Old acc:    " + (double)correctOld / (double)test.numInstances());
        System.out.println("New acc:    " + (double)correctNew / (double)test.numInstances());
        System.out.println("Old timing: " + oldTime);
        System.out.println("New timing: " + newTime);
        System.out.println("Relative Performance: " + (double)newTime / (double)oldTime);
    }

    /**
     * Developer harness: for paramId 0..99 prints the {@link WeightedDTW} distance and this class's
     * distance between the first and last training instances, tab-separated, one line per paramId.
     *
     * @param args unused
     * @throws Exception if the dataset cannot be loaded
     */
    public static void main(String[] args) throws Exception {
        // NOTE(review): machine-specific path; only usable on the original author's machine.
        Instances train = ClassifierTools.loadData("C:/users/sjx07ngu/dropbox/tsc problems/SonyAiboRobotSurface1/SonyAiboRobotSurface1_TRAIN");
        Instance one = train.firstInstance();
        Instance two = train.lastInstance();
        WDTW1NN wnn = new WDTW1NN();
        for (int paramId = 0; paramId < 100; ++paramId) {
            double g = (double)paramId / 100.0;
            WeightedDTW wdtw = new WeightedDTW(g);
            wnn.setParamsFromParamId(train, paramId);
            System.out.print(wdtw.distance(one, two) + "\t");
            System.out.println(wnn.distance(one, two, Double.MAX_VALUE));
        }
    }

    /**
     * Sets {@code g = paramId / 100} and marks the cached weights for recomputation.
     *
     * @param train   unused
     * @param paramId parameter index; 0..99 maps to g in [0, 0.99]
     */
    @Override
    public void setParamsFromParamId(Instances train, int paramId) {
        this.g = (double)paramId / 100.0;
        this.refreshWeights = true;
    }

    /** @return the current {@code g} followed by a comma */
    @Override
    public String getParamInformationString() {
        return this.g + ",";
    }

    @Override
    public String toString() {
        return "this weight: " + this.g;
    }
}

