/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers.ensembles.elastic_ensemble;

import timeseriesweka.classifiers.ensembles.elastic_ensemble.Efficient1NN;
import timeseriesweka.elastic_distance_measures.ERPDistance;
import utilities.ClassifierTools;
import weka.classifiers.lazy.kNN;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;

/**
 * One-nearest-neighbour classifier using the Edit Distance with Real Penalty (ERP).
 *
 * <p>Two parameters control the measure:
 * <ul>
 *   <li>{@code g}: the value a point is compared against when it is aligned with a gap
 *       (i.e. when the other series does not advance).</li>
 *   <li>{@code bandSize}: the warping window as a fraction of the series length. The band
 *       width in points is {@code ceil(seriesLength * bandSize)}.</li>
 * </ul>
 */
public class ERP1NN extends Efficient1NN {
    /** Gap reference value used by the ERP cost. */
    private double g;
    /** Warping window, expressed as a fraction of the series length. */
    private double bandSize;
    /** Candidate g values for the parameter search; filled lazily by setParamsFromParamId. */
    private double[] gValues;
    /** Candidate band sizes for the parameter search; filled lazily by setParamsFromParamId. */
    private double[] windowSizes;
    /** Whether gValues/windowSizes have been computed since the last buildClassifier call. */
    private boolean gAndWindowsRefreshed = false;

    /**
     * Creates an ERP 1-NN classifier with explicit parameters.
     *
     * @param g        gap reference value
     * @param bandSize warping window as a fraction of the series length
     */
    public ERP1NN(double g, double bandSize) {
        this.g = g;
        this.bandSize = bandSize;
        this.gAndWindowsRefreshed = false;
        this.classifierIdentifier = "ERP_1NN";
        this.allowLoocv = false;
    }

    /**
     * Creates an ERP 1-NN classifier with placeholder parameters (g = 0.5, bandSize = 5.0).
     * The parameters are expected to be replaced, e.g. via {@link #setParamsFromParamId}.
     */
    public ERP1NN() {
        this.g = 0.5;
        // A fraction of 5.0 produces a band wider than any series, i.e. an unconstrained window.
        this.bandSize = 5.0;
        this.gAndWindowsRefreshed = false;
        this.classifierIdentifier = "ERP_1NN";
        // NOTE(review): unlike the two-argument constructor, allowLoocv is not set here, so it
        // keeps whatever default Efficient1NN assigns -- confirm this difference is intended.
    }

    /**
     * Builds the classifier and invalidates the cached parameter grid so that it is recomputed
     * from the new training data on the next {@link #setParamsFromParamId} call.
     */
    @Override
    public void buildClassifier(Instances train) throws Exception {
        super.buildClassifier(train);
        this.gAndWindowsRefreshed = false;
    }

    /**
     * Computes the ERP distance between two instances.
     *
     * <p>When both instances store their class attribute last, a banded dynamic program over
     * two rolling rows is used. Otherwise the computation is delegated to {@link ERPDistance}.
     * The banded path does not use {@code cutoff} for early abandoning.
     *
     * @param first  first series
     * @param second second series; on the banded path it must have the same length as first
     * @param cutoff forwarded to {@link ERPDistance} on the fallback path, ignored otherwise
     * @return the square root of the accumulated ERP cost (0.0 for two empty series)
     * @throws IllegalArgumentException on the banded path, if the series lengths differ
     */
    @Override
    public final double distance(Instance first, Instance second, double cutoff) {
        if (first.classIndex() != first.numAttributes() - 1
                || second.classIndex() != second.numAttributes() - 1) {
            return new ERPDistance(this.g, this.bandSize).distance(first, second, cutoff);
        }
        final int m = first.numAttributes() - 1;
        final int n = second.numAttributes() - 1;
        if (m != n) {
            // Without this check, a shorter second series would read its class value (or an
            // out-of-range index) as data, and a longer one would have its tail ignored.
            throw new IllegalArgumentException(
                    "ERP1NN requires equal-length series, got " + m + " and " + n);
        }
        if (m == 0) {
            // Nothing to align; previously this fell through to curr[-1].
            return 0.0;
        }

        final int band = (int) Math.ceil((double) m * this.bandSize);
        final double gValue = this.g;
        double[] curr = new double[m];
        double[] prev = new double[m];
        for (int i = 0; i < m; ++i) {
            // Roll the rows: last iteration's row becomes prev, the older row is overwritten.
            double[] temp = prev;
            prev = curr;
            curr = temp;

            // Invariant for the whole row: cost of aligning first[i] with a gap.
            final double firstValue = first.value(i);
            final double firstGapDiff = firstValue - gValue;
            final double dist1 = firstGapDiff * firstGapDiff;

            // One extra column on each side of the band is set to infinity, so in-band cells
            // never read stale values left over from two rows ago.
            final int l = Math.max(0, i - (band + 1));
            final int r = Math.min(m - 1, i + (band + 1));
            for (int j = l; j <= r; ++j) {
                if (Math.abs(i - j) > band) {
                    curr[j] = Double.POSITIVE_INFINITY;
                    continue;
                }
                final double secondValue = second.value(j);
                final double secondGapDiff = gValue - secondValue;
                final double dist2 = secondGapDiff * secondGapDiff;
                final double matchDiff = firstValue - secondValue;
                final double dist12 = matchDiff * matchDiff;

                final double cost;
                if (i + j == 0) {
                    cost = 0.0;
                } else if (i == 0 || (j != 0
                        && prev[j - 1] + dist12 > curr[j - 1] + dist2
                        && curr[j - 1] + dist2 < prev[j] + dist1)) {
                    // second[j] aligned with a gap (advance j only)
                    cost = curr[j - 1] + dist2;
                } else if (j == 0 || (i != 0
                        && prev[j - 1] + dist12 > prev[j] + dist1
                        && prev[j] + dist1 < curr[j - 1] + dist2)) {
                    // first[i] aligned with a gap (advance i only)
                    cost = prev[j] + dist1;
                } else {
                    // first[i] matched with second[j]
                    cost = prev[j - 1] + dist12;
                }
                curr[j] = cost;
            }
        }
        return Math.sqrt(curr[m - 1]);
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Capabilities getCapabilities() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Runs {@link #runComparison(String, String, double, double)} on the original hard-coded
     * dataset (SonyAiboRobotSurface1) with g = 0.1 and bandSize = 0.1.
     */
    public static void runComparison() throws Exception {
        runComparison("C:/users/sjx07ngu/Dropbox/TSC Problems/", "SonyAiboRobotSurface1", 0.1, 0.1);
    }

    /**
     * Benchmarks this classifier against Weka's {@code kNN} configured with {@link ERPDistance}.
     * It loads {@code <dir>/<name>/<name>_TRAIN} and {@code _TEST}, classifies every test
     * instance with both classifiers, and prints the accuracies and nanosecond timings.
     *
     * @param tscProbDir  root directory of the problems, including the trailing separator
     * @param datasetName dataset sub-directory and file prefix
     * @param g           gap reference value used by both classifiers
     * @param bandSize    warping window fraction used by both classifiers
     */
    public static void runComparison(String tscProbDir, String datasetName, double g, double bandSize)
            throws Exception {
        String prefix = tscProbDir + datasetName + "/" + datasetName;
        Instances train = ClassifierTools.loadData(prefix + "_TRAIN");
        Instances test = ClassifierTools.loadData(prefix + "_TEST");

        kNN referenceKnn = new kNN();
        referenceKnn.setDistanceFunction(new ERPDistance(g, bandSize));
        referenceKnn.buildClassifier(train);
        ERP1NN efficientKnn = new ERP1NN(g, bandSize);
        efficientKnn.buildClassifier(train);

        long start = System.nanoTime();
        int correctOld = 0;
        for (int i = 0; i < test.numInstances(); ++i) {
            if (referenceKnn.classifyInstance(test.instance(i)) == test.instance(i).classValue()) {
                ++correctOld;
            }
        }
        long oldTime = System.nanoTime() - start;

        start = System.nanoTime();
        int correctNew = 0;
        for (int i = 0; i < test.numInstances(); ++i) {
            if (efficientKnn.classifyInstance(test.instance(i)) == test.instance(i).classValue()) {
                ++correctNew;
            }
        }
        long newTime = System.nanoTime() - start;

        System.out.println("Comparison of ERP: " + datasetName);
        System.out.println("==========================================");
        System.out.println("Old acc:    " + (double) correctOld / (double) test.numInstances());
        System.out.println("New acc:    " + (double) correctNew / (double) test.numInstances());
        System.out.println("Old timing: " + oldTime);
        System.out.println("New timing: " + newTime);
        System.out.println("Relative Performance: " + (double) newTime / (double) oldTime);
    }

    /** Repeats the default comparison ten times. */
    public static void main(String[] args) throws Exception {
        for (int i = 0; i < 10; ++i) {
            ERP1NN.runComparison();
        }
    }

    /**
     * Selects g and bandSize from a grid of candidate values.
     *
     * <p>The grid is built from the training data the first time this is called after
     * {@link #buildClassifier}:
     * <ul>
     *   <li>candidate band sizes come from {@code ERPDistance.getInclusive10(0.0, 0.25)};</li>
     *   <li>candidate g values come from {@code getInclusive10(0.2 * stdv, stdv)}, where
     *       {@code stdv = ERPDistance.stdv_p(train)} (presumably a population standard
     *       deviation of the training data -- confirm).</li>
     * </ul>
     * The grid is then cached until the next {@code buildClassifier} call.
     *
     * @param train   training data used to build the grid
     * @param paramId row-major index into the grid: {@code paramId / windowCount} selects g,
     *                {@code paramId % windowCount} selects the band size
     * @throws IllegalArgumentException if paramId is outside the grid
     */
    @Override
    public void setParamsFromParamId(Instances train, int paramId) {
        if (!this.gAndWindowsRefreshed) {
            double stdv = ERPDistance.stdv_p(train);
            this.windowSizes = ERPDistance.getInclusive10(0.0, 0.25);
            this.gValues = ERPDistance.getInclusive10(0.2 * stdv, stdv);
            this.gAndWindowsRefreshed = true;
        }
        final int windowCount = this.windowSizes.length;
        final int gridSize = this.gValues.length * windowCount;
        if (paramId < 0 || paramId >= gridSize) {
            throw new IllegalArgumentException(
                    "paramId " + paramId + " out of range [0, " + gridSize + ")");
        }
        this.g = this.gValues[paramId / windowCount];
        this.bandSize = this.windowSizes[paramId % windowCount];
    }

    /** Returns the current parameters formatted as {@code "g,bandSize"}. */
    @Override
    public String getParamInformationString() {
        return this.g + "," + this.bandSize;
    }
}

