/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers;

import java.io.FileReader;
import timeseriesweka.elastic_distance_measures.DTW;
import weka.classifiers.lazy.kNN;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Nearest-neighbour classifier that uses a {@link DTW} distance function.
 *
 * <p>The DTW warping window is passed to {@link DTW#setR(double)} as a ratio.
 * When window optimisation is enabled, {@link #buildClassifier(Instances)}
 * sweeps candidate window sizes and keeps the one with the best
 * leave-one-out 1-NN accuracy on the training data.
 */
public class DTW_kNN
extends kNN {
    /** Whether {@link #buildClassifier(Instances)} searches for the best window. */
    private boolean optimiseWindow = false;
    /** Initial window ratio handed to the DTW measure by the constructors. */
    private double windowSize = 0.1;
    /** Upper bound of the window search, as a fraction of the series length. */
    private double maxWindowSize = 1.0;
    /** Step between candidate windows, measured in attribute positions (not a ratio). */
    private int incrementSize = 10;
    /** Training data captured by the most recent {@link #buildClassifier(Instances)} call. */
    private Instances train;
    /** Number of instances in {@link #train}. */
    private int trainSize;
    /** Best window found by the search, in attribute positions; only written in this class. */
    private int bestWarp;
    /** Distance measure shared by the window search and the underlying kNN. */
    DTW dtw = new DTW();

    /**
     * Creates a classifier with k = 1, the default window ratio and window
     * optimisation switched off.
     */
    public DTW_kNN() {
        this.dtw.setR(this.windowSize);
        this.setDistanceFunction(this.dtw);
        super.setKNN(1);
    }

    /**
     * Enables or disables the window search performed while building.
     *
     * @param b {@code true} to search for the best window in
     *          {@link #buildClassifier(Instances)}
     */
    public void optimiseWindow(boolean b) {
        this.optimiseWindow = b;
    }

    /**
     * Sets the upper bound of the window search.
     *
     * @param r maximum window, as a fraction of the series length
     */
    public void setMaxR(double r) {
        this.maxWindowSize = r;
    }

    /**
     * Creates a classifier with the given neighbourhood size. Unlike the
     * no-argument constructor, this one switches window optimisation on.
     *
     * @param k number of neighbours, forwarded to the superclass constructor
     */
    public DTW_kNN(int k) {
        super(k);
        this.dtw.setR(this.windowSize);
        this.optimiseWindow = true;
        this.setDistanceFunction(this.dtw);
    }

    /**
     * Stores the training data, optionally tunes the DTW window, then
     * delegates to the superclass.
     *
     * <p>The search tries windows {@code 0, incrementSize, 2*incrementSize, ...}
     * strictly below {@code dataLength * maxWindowSize}. Ties keep the
     * smallest window because only a strictly better accuracy replaces the
     * current best; if no candidate scores above zero the window ratio is 0.
     *
     * @param d training instances; the series length is taken as
     *          {@code numAttributes() - 1}, i.e. a single class attribute is assumed
     */
    @Override
    public void buildClassifier(Instances d) {
        this.dist.setInstances(d);
        this.train = d;
        this.trainSize = d.numInstances();
        if (this.optimiseWindow) {
            double bestRatio = 0.0;
            double bestAccuracy = 0.0;
            int dataLength = this.train.numAttributes() - 1;
            int max = (int)((double)dataLength * this.maxWindowSize);
            // Integer counter instead of a floating-point one; the ratios passed on are identical.
            for (int window = 0; window < max; window += this.incrementSize) {
                double ratio = (double)window / (double)dataLength;
                this.dtw.setR(ratio);
                double acc = this.crossValidateAccuracy();
                if (acc > bestAccuracy) {
                    bestRatio = ratio;
                    bestAccuracy = acc;
                }
            }
            this.bestWarp = (int)(bestRatio * (double)dataLength);
            this.dtw.setR(bestRatio);
        }
        super.buildClassifier(d);
    }

    /**
     * Computes leave-one-out 1-NN accuracy on the stored training data using
     * the current DTW settings.
     *
     * <p>An instance for which no other instance yields a distance below
     * {@link Double#MAX_VALUE} has no neighbour and is counted as
     * misclassified (previously index 0 was used, which could be the held-out
     * instance itself).
     *
     * @return fraction of training instances whose nearest other instance has
     *         the same class value; {@code 0.0} for an empty training set
     */
    private double crossValidateAccuracy() {
        if (this.trainSize == 0) {
            // Avoids 0/0 = NaN; callers only test "acc > best", so 0.0 is equivalent.
            return 0.0;
        }
        double correct = 0.0;
        for (int i = 0; i < this.trainSize; ++i) {
            int nearest = -1;
            double minDist = Double.MAX_VALUE;
            Instance inst = this.train.instance(i);
            for (int j = 0; j < this.trainSize; ++j) {
                if (i == j) {
                    continue;
                }
                // minDist is passed as the cutoff so the measure can abandon early.
                double d = this.dtw.distance(inst, this.train.instance(j), minDist);
                if (d < minDist) {
                    nearest = j;
                    minDist = d;
                }
            }
            if (nearest >= 0 && inst.classValue() == this.train.instance(nearest).classValue()) {
                correct += 1.0;
            }
        }
        return correct / (double)this.trainSize;
    }

    /**
     * Ad-hoc driver: loads the Coffee train/test files from a hard-coded
     * local path and builds a classifier on the training split.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        DTW_kNN c = new DTW_kNN();
        String path = "C:\\Research\\Data\\Time Series Data\\Time Series Classification\\";
        // The test split is loaded but not evaluated here.
        Instances test = DTW_kNN.loadData(path + "Coffee\\Coffee_TEST.arff");
        Instances train = DTW_kNN.loadData(path + "Coffee\\Coffee_TRAIN.arff");
        if (train == null) {
            // loadData has already reported the failure; nothing to build on.
            return;
        }
        // loadData already set the class index to the last attribute.
        c.buildClassifier(train);
    }

    /**
     * Reads a data file and marks the last attribute as the class.
     *
     * @param fileName path of the file to read
     * @return the loaded instances, or {@code null} if reading failed (the
     *         error is printed to standard output)
     */
    public static Instances loadData(String fileName) {
        Instances data = null;
        // try-with-resources guarantees the reader is closed on every path.
        try (FileReader r = new FileReader(fileName)) {
            data = new Instances(r);
            data.setClassIndex(data.numAttributes() - 1);
        }
        catch (Exception e) {
            System.out.println(" Error =" + e + " in method loadData");
        }
        return data;
    }
}

