/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers.ensembles.elastic_ensemble;

import timeseriesweka.classifiers.ensembles.elastic_ensemble.Efficient1NN;
import timeseriesweka.elastic_distance_measures.TWEDistance;
import utilities.ClassifierTools;
import weka.classifiers.lazy.kNN;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Nearest-neighbour classifier (built on {@link Efficient1NN}) that compares series with the
 * Time Warp Edit Distance (TWED).
 *
 * <p>Two parameters control the distance. {@code nu} weights the timestamp gap charged by every
 * edit operation. {@code lambda} is a constant penalty added to every deletion.
 */
public class TWE1NN extends Efficient1NN {
    /**
     * Exponent of the element-wise cost. The recurrence below hard-codes squared differences, so
     * this constant is not referenced by the computation.
     */
    private static final double DEGREE = 2.0;
    /** Weight applied to timestamp gaps in every edit operation. */
    double nu = 1.0;
    /** Constant penalty added to every deletion operation. */
    double lambda = 1.0;
    /** Candidate {@code nu} values; indexed by the tens digit of a parameter id. */
    protected static double[] twe_nuParams = new double[]{1.0E-5, 1.0E-4, 5.0E-4, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0};
    /** Candidate {@code lambda} values; indexed by the units digit of a parameter id. */
    protected static double[] twe_lamdaParams = new double[]{0.0, 0.011111111, 0.022222222, 0.033333333, 0.044444444, 0.055555556, 0.066666667, 0.077777778, 0.088888889, 0.1};

    /**
     * Creates a classifier with fixed TWED parameters. This constructor also disables
     * leave-one-out cross-validation ({@code allowLoocv = false}).
     *
     * @param nu     weight applied to timestamp gaps
     * @param lambda penalty added to every deletion
     */
    public TWE1NN(double nu, double lambda) {
        this.nu = nu;
        this.lambda = lambda;
        this.classifierIdentifier = "TWE_1NN";
        this.allowLoocv = false;
    }

    /** Creates a classifier with {@code nu = 0.005} and {@code lambda = 0.5}. */
    public TWE1NN() {
        this.nu = 0.005;
        this.lambda = 0.5;
        this.classifierIdentifier = "TWE_1NN";
    }

    /**
     * Returns the TWE distance between two instances.
     *
     * <p>When both instances store their class attribute last, every other attribute is treated
     * as a series value and the distance is computed directly. Each series is sized by its own
     * attribute count, so series of different lengths are supported. Otherwise the computation
     * is delegated to {@link TWEDistance}.
     *
     * @param first  first instance
     * @param second second instance
     * @param cutoff forwarded to {@link TWEDistance} on the fallback path; the direct path always
     *               computes the full distance and does not use it
     * @return the TWE distance between {@code first} and {@code second}
     */
    @Override
    public final double distance(Instance first, Instance second, double cutoff) {
        if (first.classIndex() != first.numAttributes() - 1
                || second.classIndex() != second.numAttributes() - 1) {
            return new TWEDistance(this.nu, this.lambda).distance(first, second, cutoff);
        }
        return twed(seriesValues(first), seriesValues(second), this.nu, this.lambda);
    }

    /** Copies every attribute except the trailing class attribute into a new array. */
    private static double[] seriesValues(Instance instance) {
        double[] values = new double[instance.numAttributes() - 1];
        for (int i = 0; i < values.length; i++) {
            values[i] = instance.value(i);
        }
        return values;
    }

    /**
     * Computes TWED between two univariate series whose 0-based element {@code i} has timestamp
     * {@code i + 1}.
     *
     * <p>Each DP cell takes the cheapest of three operations:
     * <ul>
     *   <li>matching {@code a[i-1]} with {@code b[j-1]};</li>
     *   <li>deleting {@code a[i-1]};</li>
     *   <li>deleting {@code b[j-1]}.</li>
     * </ul>
     * Element costs are squared differences. Only two DP rows are kept, so memory is
     * O({@code b.length}).
     *
     * @param a      first series
     * @param b      second series (may differ in length from {@code a})
     * @param nu     weight applied to timestamp gaps
     * @param lambda penalty added to every deletion
     * @return the TWE distance
     */
    private static double twed(double[] a, double[] b, double nu, double lambda) {
        final int r = a.length;
        final int c = b.length;
        final double[] tsa = unitTimestamps(r);
        final double[] tsb = unitTimestamps(c);
        final double[] delA = deletionCosts(a);
        final double[] delB = deletionCosts(b);

        double[] prev = new double[c + 1];
        double[] curr = new double[c + 1];
        // Row 0: an empty prefix of a aligned against growing prefixes of b.
        prev[0] = 0.0;
        for (int j = 1; j <= c; j++) {
            prev[j] = prev[j - 1] + delB[j];
        }

        for (int i = 1; i <= r; i++) {
            curr[0] = prev[0] + delA[i];
            // Timestamp gap consumed by deleting a[i-1]; the first element measures from 0.
            final double gapA = i > 1 ? tsa[i - 1] - tsa[i - 2] : tsa[i - 1];
            for (int j = 1; j <= c; j++) {
                // Match: compare the current elements and, when both exist, their predecessors.
                double diff = a[i - 1] - b[j - 1];
                double matchCost = diff * diff;
                double htrans = Math.abs(tsa[i - 1] - tsb[j - 1]);
                if (i > 1 && j > 1) {
                    double prevDiff = a[i - 2] - b[j - 2];
                    matchCost += prevDiff * prevDiff;
                    htrans += Math.abs(tsa[i - 2] - tsb[j - 2]);
                }
                double best = prev[j - 1] + nu * htrans + matchCost;

                double deleteA = delA[i] + prev[j] + lambda + nu * gapA;
                if (best > deleteA) {
                    best = deleteA;
                }

                double gapB = j > 1 ? tsb[j - 1] - tsb[j - 2] : tsb[j - 1];
                double deleteB = delB[j] + curr[j - 1] + lambda + nu * gapB;
                if (best > deleteB) {
                    best = deleteB;
                }
                curr[j] = best;
            }
            double[] swap = prev;
            prev = curr;
            curr = swap;
        }
        // After the final swap (or with no rows at all), prev holds the last DP row.
        return prev[c];
    }

    /** Returns timestamps 1, 2, ..., length. */
    private static double[] unitTimestamps(int length) {
        double[] timestamps = new double[length];
        for (int i = 0; i < length; i++) {
            timestamps[i] = i + 1;
        }
        return timestamps;
    }

    /**
     * Returns 1-indexed deletion costs (index 0 is unused and stays 0). Each cost is the squared
     * step from the previous element, or the element's square for the first one.
     */
    private static double[] deletionCosts(double[] series) {
        double[] costs = new double[series.length + 1];
        for (int i = 1; i <= series.length; i++) {
            if (i > 1) {
                double step = series[i - 2] - series[i - 1];
                costs[i] = step * step;
            } else {
                costs[i] = series[i - 1] * series[i - 1];
            }
        }
        return costs;
    }

    /** Not implemented; always throws {@link UnsupportedOperationException}. */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /** Not implemented; always throws {@link UnsupportedOperationException}. */
    @Override
    public Capabilities getCapabilities() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Compares accuracy and timing against a Weka {@link kNN} using {@link TWEDistance}. It uses
     * the SonyAiboRobotSurface1 problem under a hard-coded local directory.
     */
    public static void runComparison() throws Exception {
        runComparison("C:/users/sjx07ngu/Dropbox/TSC Problems/", "SonyAiboRobotSurface1");
    }

    /**
     * Compares accuracy and timing against a Weka {@link kNN} using {@link TWEDistance}. Both
     * classifiers use {@code nu = 0.001} and {@code lambda = 0.5}.
     *
     * @param problemDir  directory containing one sub-directory per dataset (with trailing slash)
     * @param datasetName dataset whose {@code _TRAIN} and {@code _TEST} files are loaded
     * @throws Exception if loading, training or classification fails
     */
    public static void runComparison(String problemDir, String datasetName) throws Exception {
        Instances train = ClassifierTools.loadData(problemDir + datasetName + "/" + datasetName + "_TRAIN");
        Instances test = ClassifierTools.loadData(problemDir + datasetName + "/" + datasetName + "_TEST");

        kNN reference = new kNN();
        reference.setDistanceFunction(new TWEDistance(0.001, 0.5));
        reference.buildClassifier(train);

        TWE1NN efficient = new TWE1NN(0.001, 0.5);
        efficient.buildClassifier(train);

        long start = System.nanoTime();
        int correctOld = 0;
        for (int i = 0; i < test.numInstances(); ++i) {
            if (reference.classifyInstance(test.instance(i)) == test.instance(i).classValue()) {
                ++correctOld;
            }
        }
        long oldTime = System.nanoTime() - start;

        start = System.nanoTime();
        int correctNew = 0;
        for (int i = 0; i < test.numInstances(); ++i) {
            if (efficient.classifyInstance(test.instance(i)) == test.instance(i).classValue()) {
                ++correctNew;
            }
        }
        long newTime = System.nanoTime() - start;

        System.out.println("Comparison of TWE: " + datasetName);
        System.out.println("==========================================");
        System.out.println("Old acc:    " + (double)correctOld / (double)test.numInstances());
        System.out.println("New acc:    " + (double)correctNew / (double)test.numInstances());
        System.out.println("Old timing: " + oldTime);
        System.out.println("New timing: " + newTime);
        System.out.println("Relative Performance: " + (double)newTime / (double)oldTime);
    }

    /** Runs {@link #runComparison()} ten times. */
    public static void main(String[] args) throws Exception {
        for (int i = 0; i < 10; ++i) {
            TWE1NN.runComparison();
        }
    }

    /**
     * Selects {@code nu} and {@code lambda} from the parameter grids. {@code paramId} is expected
     * to be in [0, 99]. Its tens digit indexes {@link #twe_nuParams} and its units digit indexes
     * {@link #twe_lamdaParams}. {@code train} is not used.
     */
    @Override
    public void setParamsFromParamId(Instances train, int paramId) {
        this.nu = twe_nuParams[paramId / 10];
        this.lambda = twe_lamdaParams[paramId % 10];
    }

    /** Returns the current parameters formatted as {@code "nu,lambda"}. */
    @Override
    public String getParamInformationString() {
        return this.nu + "," + this.lambda;
    }
}

