/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers;

import java.text.DecimalFormat;
import timeseriesweka.elastic_distance_measures.DTW_DistanceBasic;
import utilities.ClassifierResults;
import utilities.ClassifierTools;
import utilities.InstanceTools;
import utilities.SaveParameterInfo;
import weka.classifiers.lazy.kNN;
import weka.core.DenseInstance;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.neighboursearch.PerformanceStats;
import weka.filters.SimpleBatchFilter;

/**
 * One-nearest-neighbour classifier using the derivative-weighted distance of Gorecki and Luczak:
 * {@code a * D(x, y) + b * D(x', y')}, where {@code D} is Euclidean distance or DTW and
 * {@code x'} is the first-difference series of {@code x}.
 *
 * <p>Unless {@link #setAandB(double, double)} is called beforehand, the weights are chosen by
 * leave-one-out cross-validation on the training data the first time
 * {@link #buildClassifier(Instances)} runs; later builds reuse them.
 *
 * <p>NOTE(review): the distance functions and the derivative filter treat the LAST attribute as
 * the class attribute — confirm this holds for any new data source.
 */
public class DD_DTW extends kNN implements SaveParameterInfo {
    /** Holds timing information; {@code buildTime} is filled in by {@link #buildClassifier}. */
    protected ClassifierResults res = new ClassifierResults();
    /** Root directory of the UCR-style problems used by {@link #main} and the results table. */
    public static final String DATA_DIR = "C:/Temp/Dropbox/TSC Problems/";
    /**
     * Candidate angles (radians) searched by
     * {@link GoreckiDerivativesEuclideanDistance#crossValidateForAlpha}; each yields weights
     * {@code a = cos(alpha)}, {@code b = sin(alpha)}.
     */
    public static final double[] ALPHAS = new double[]{1.0, 1.01, 1.02, 1.03, 1.04, 1.05, 1.06, 1.07, 1.08, 1.09, 1.1, 1.11, 1.12, 1.13, 1.14, 1.15, 1.16, 1.17, 1.18, 1.19, 1.2, 1.21, 1.22, 1.23, 1.24, 1.25, 1.26, 1.27, 1.28, 1.29, 1.3, 1.31, 1.32, 1.33, 1.34, 1.35, 1.36, 1.37, 1.38, 1.39, 1.4, 1.41, 1.42, 1.43, 1.44, 1.45, 1.46, 1.47, 1.48, 1.49, 1.5, 1.51, 1.52, 1.53, 1.54, 1.55, 1.56, 1.57};
    /** Datasets reported in the original paper, used by {@link #recreateResultsTable(int)}. */
    public static final String[] GORECKI_DATASETS = new String[]{"fiftywords", "Adiac", "Beef", "CBF", "Coffee", "FaceAll", "FaceFour", "fish", "GunPoint", "Lightning2", "Lightning7", "OliveOil", "OSULeaf", "SwedishLeaf", "SyntheticControl", "Trace", "TwoPatterns", "wafer", "yoga"};
    /** The weighted distance used by the underlying 1-NN search. */
    protected GoreckiDerivativesEuclideanDistance distanceFunction;
    /** True once the weights are fixed; cross-validation is skipped while this is set. */
    protected boolean paramsSet;
    // NOTE(review): sampleForCV and prop are stored by sampleForCV(boolean, double) but are not
    // read anywhere in this file.
    protected boolean sampleForCV = false;
    protected double prop;

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "T. G\u00f3recki and M. \u0141uczak");
        result.setValue(TechnicalInformation.Field.TITLE, "Using derivatives in time series classification");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Data Mining and Knowledge Discovery");
        result.setValue(TechnicalInformation.Field.VOLUME, "26");
        result.setValue(TechnicalInformation.Field.NUMBER, "2");
        result.setValue(TechnicalInformation.Field.PAGES, "310-331");
        // Volume 26, issue 2 of DMKD appeared in 2013 (previously recorded as 2015).
        result.setValue(TechnicalInformation.Field.YEAR, "2013");
        return result;
    }

    /**
     * Records a cross-validation sampling flag and proportion.
     *
     * <p>NOTE(review): the stored values are not read anywhere in this file, so calling this
     * currently does not change how cross-validation runs.
     *
     * @param b whether to sample the training data for cross-validation
     * @param p the sampling proportion
     */
    public void sampleForCV(boolean b, double p) {
        this.sampleForCV = b;
        this.prop = p;
    }

    /** Creates a classifier using the DTW-based derivative distance with weights (1, 0). */
    public DD_DTW() {
        this.distanceFunction = new GoreckiDerivativesDTW();
        this.paramsSet = false;
    }

    /**
     * Creates a classifier with the given base distance.
     *
     * @param distType {@link DistanceType#EUCLIDEAN} for the Euclidean variant; any other value
     *     selects the DTW variant
     */
    public DD_DTW(DistanceType distType) {
        this.distanceFunction = distType == DistanceType.EUCLIDEAN
                ? new GoreckiDerivativesEuclideanDistance()
                : new GoreckiDerivativesDTW();
        this.paramsSet = false;
    }

    /**
     * Fixes the weights so that {@link #buildClassifier(Instances)} skips cross-validation.
     *
     * @param a weight on the raw-series distance
     * @param b weight on the derivative-series distance
     */
    public void setAandB(double a, double b) {
        this.distanceFunction.a = a;
        this.distanceFunction.b = b;
        this.paramsSet = true;
    }

    /**
     * Cross-validates the weights if they are not yet set, then builds the 1-NN model.
     * The elapsed wall-clock time in milliseconds is stored in {@code res.buildTime}.
     *
     * @param train the training data
     */
    @Override
    public void buildClassifier(Instances train) {
        this.res.buildTime = System.currentTimeMillis();
        if (!this.paramsSet) {
            this.distanceFunction.crossValidateForAandB(train);
            this.paramsSet = true;
        }
        this.setDistanceFunction(this.distanceFunction);
        super.buildClassifier(train);
        this.res.buildTime = System.currentTimeMillis() - this.res.buildTime;
    }

    /** Returns the build time and selected weights as comma-separated key,value pairs. */
    @Override
    public String getParameters() {
        return "BuildTime," + this.res.buildTime + ",a," + this.distanceFunction.a + ",b," + this.distanceFunction.b;
    }

    /** Prints the results table on the default train/test splits (equivalent to seed 0). */
    public static void recreateResultsTable() throws Exception {
        DD_DTW.recreateResultsTable(0);
    }

    /**
     * Prints, as CSV on standard output, the test error (in percent) of six 1-NN variants on each
     * of {@link #GORECKI_DATASETS}: ED, DED (Euclidean on derivatives), DD_ED, DTW, DDTW (DTW on
     * derivatives) and DD_DTW.
     *
     * @param seed 0 to use the supplied splits; any other value resamples train/test with it
     * @throws Exception if loading, filtering or classification fails
     */
    public static void recreateResultsTable(int seed) throws Exception {
        DecimalFormat df = new DecimalFormat("##.##");
        GoreckiDerivativeFilter derFilter = new GoreckiDerivativeFilter();
        System.out.println("Dataset,ED,DED,DD_ED,DTW,DDTW,DD_DTW");
        for (String dataset : GORECKI_DATASETS) {
            System.out.print(dataset + ",");
            Instances train = ClassifierTools.loadData(DATA_DIR + dataset + "/" + dataset + "_TRAIN");
            Instances test = ClassifierTools.loadData(DATA_DIR + dataset + "/" + dataset + "_TEST");
            if (seed != 0) {
                Instances[] temp = InstanceTools.resampleTrainAndTestInstances(train, test, seed);
                train = temp[0];
                test = temp[1];
            }
            Instances dTrain = derFilter.process(train);
            Instances dTest = derFilter.process(test);

            // ED: Euclidean 1-NN on the raw series, without attribute normalisation.
            EuclideanDistance ed = new EuclideanDistance();
            ed.setDontNormalize(true);
            System.out.print(percentError(df, DD_DTW.getCorrect(new kNN(ed), train, test), test) + ",");

            // DED: Euclidean 1-NN on the derivative series; normalisation disabled to match ED.
            EuclideanDistance dEd = new EuclideanDistance();
            dEd.setDontNormalize(true);
            System.out.print(percentError(df, DD_DTW.getCorrect(new kNN(dEd), dTrain, dTest), dTest) + ",");

            // DD_ED: weighted raw + derivative Euclidean distance, weights cross-validated.
            DD_DTW ddEd = new DD_DTW(DistanceType.EUCLIDEAN);
            System.out.print(percentError(df, DD_DTW.getCorrect(ddEd, train, test), test) + ",");

            // DTW on the raw series.
            DTW_DistanceBasic dtw = new DTW_DistanceBasic();
            System.out.print(percentError(df, DD_DTW.getCorrect(new kNN(dtw), train, test), test) + ",");

            // DDTW: DTW on the derivative series.
            DTW_DistanceBasic dDtw = new DTW_DistanceBasic();
            System.out.print(percentError(df, DD_DTW.getCorrect(new kNN(dDtw), dTrain, dTest), dTest) + ",");

            // DD_DTW: weighted raw + derivative DTW distance, weights cross-validated.
            DD_DTW ddDtw = new DD_DTW(DistanceType.DTW);
            System.out.println(percentError(df, DD_DTW.getCorrect(ddDtw, train, test), test));
        }
    }

    /** Formats {@code (1 - correct / |evaluated|) * 100} with {@code df}. */
    private static String percentError(DecimalFormat df, int correct, Instances evaluated) {
        double acc = (double) correct / (double) evaluated.numInstances();
        return df.format((1.0 - acc) * 100.0);
    }

    /** Demo entry point: option 1 prints DD_DTW accuracy on ItalyPowerDemand, option 2 the table. */
    public static void main(String[] args) {
        int option = 1;
        try {
            if (option == 1) {
                String dataName = "ItalyPowerDemand";
                Instances train = ClassifierTools.loadData(DATA_DIR + dataName + "/" + dataName + "_TRAIN");
                Instances test = ClassifierTools.loadData(DATA_DIR + dataName + "/" + dataName + "_TEST");
                DD_DTW nndw = new DD_DTW(DistanceType.DTW);
                int correct = DD_DTW.getCorrect(nndw, train, test);
                System.out.println(dataName + ":\t" + new DecimalFormat("#.###").format((double) correct / (double) test.numInstances() * 100.0) + "%");
            } else if (option == 2) {
                DD_DTW.recreateResultsTable();
            }
        } catch (Exception e) {
            // Script boundary: report and exit.
            e.printStackTrace();
        }
    }

    /**
     * Builds {@code knn} on {@code train} and counts correctly classified instances of {@code test}.
     *
     * @return the number of test instances whose predicted class equals the true class value
     * @throws Exception if building or classifying fails
     */
    protected static int getCorrect(kNN knn, Instances train, Instances test) throws Exception {
        knn.buildClassifier(train);
        int correct = 0;
        for (int i = 0; i < test.numInstances(); ++i) {
            if (test.instance(i).classValue() == knn.classifyInstance(test.instance(i))) {
                ++correct;
            }
        }
        return correct;
    }

    /**
     * Replaces each series with its first differences {@code value[a] - value[a-1]}. The first
     * attribute is dropped (n values give n-1 differences). The class value is copied to the
     * last attribute.
     *
     * <p>NOTE(review): assumes the class is the last attribute of the input.
     */
    private static class GoreckiDerivativeFilter extends SimpleBatchFilter {
        private GoreckiDerivativeFilter() {
        }

        @Override
        public String globalInfo() {
            // Previously threw UnsupportedOperationException, which breaks any caller asking for
            // the filter description.
            return "Replaces each time series with its first differences; the class attribute is kept last.";
        }

        /** Copies the header, drops attribute 0 and renames the remaining series attributes. */
        @Override
        protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
            Instances output = new Instances(inputFormat, 0);
            output.deleteAttributeAt(0);
            output.setRelationName("goreckiDerivative_" + output.relationName());
            // The last attribute is assumed to be the class and keeps its name.
            for (int a = 0; a < output.numAttributes() - 1; ++a) {
                output.renameAttribute(a, "derivative_" + a);
            }
            return output;
        }

        /** Returns a new dataset holding the first-difference series of every input instance. */
        @Override
        public Instances process(Instances instances) throws Exception {
            Instances output = this.determineOutputFormat(instances);
            for (int i = 0; i < instances.numInstances(); ++i) {
                Instance thisInstance = instances.get(i);
                DenseInstance toAdd = new DenseInstance(output.numAttributes());
                // Stop before the last input attribute, which is assumed to be the class.
                for (int a = 1; a < instances.numAttributes() - 1; ++a) {
                    double der = thisInstance.value(a) - thisInstance.value(a - 1);
                    toAdd.setValue(a - 1, der);
                }
                toAdd.setValue(output.numAttributes() - 1, thisInstance.classValue());
                output.add(toAdd);
            }
            return output;
        }
    }

    /**
     * DTW variant: {@code a * sqrt(DTW(x, y)) + b * sqrt(DTW(x', y'))}, with derivatives taken by
     * {@link GoreckiDerivativeFilter}.
     */
    public static class GoreckiDerivativesDTW extends GoreckiDerivativesEuclideanDistance {
        /** Weights (1, 0): plain DTW until parameters are set. */
        public GoreckiDerivativesDTW() {
        }

        /** Chooses weights by leave-one-out cross-validation on {@code train}. */
        public GoreckiDerivativesDTW(Instances train) {
            super(train);
        }

        /** Uses weights {@code (cos(alpha), sin(alpha))}. */
        public GoreckiDerivativesDTW(double alpha) {
            super(alpha);
        }

        /** Uses the given weights directly. */
        public GoreckiDerivativesDTW(double a, double b) {
            super(a, b);
        }

        @Override
        public double distance(Instance one, Instance two) {
            return this.distance(one, two, Double.MAX_VALUE);
        }

        @Override
        public double distance(Instance one, Instance two, double cutoff, PerformanceStats stats) {
            return this.distance(one, two, cutoff);
        }

        /** Weighted distance; {@code cutoff} is not used for early abandoning. */
        @Override
        public double distance(Instance first, Instance second, double cutoff) {
            double[] distances = this.getNonScaledDistances(first, second);
            return this.a * distances[0] + this.b * distances[1];
        }

        /**
         * Returns {@code {sqrt(DTW(first, second)), sqrt(DTW(first', second'))}}.
         *
         * <p>NOTE(review): taking the square root presumably assumes DTW_DistanceBasic returns
         * an un-rooted sum of squared differences — confirm.
         *
         * @throws IllegalStateException if the derivative series cannot be computed
         */
        @Override
        public double[] getNonScaledDistances(Instance first, Instance second) {
            DTW_DistanceBasic dtw = new DTW_DistanceBasic();
            GoreckiDerivativeFilter filter = new GoreckiDerivativeFilter();
            // Put both instances in a two-row dataset so the filter can derive them together.
            Instances pair = new Instances(first.dataset(), 0);
            pair.add(first);
            pair.add(second);
            Instances derivatives;
            try {
                derivatives = filter.process(pair);
            } catch (Exception e) {
                // Previously printed and returned null, which surfaced as an NPE in the caller.
                throw new IllegalStateException("Could not compute derivative series for DTW distance", e);
            }
            double dist = dtw.distance(first, second);
            double derDist = dtw.distance(derivatives.get(0), derivatives.get(1), Double.MAX_VALUE);
            return new double[]{Math.sqrt(dist), Math.sqrt(derDist)};
        }
    }

    /**
     * Euclidean variant: {@code a * ||x - y|| + b * ||x' - y'||}, where {@code x'} holds the
     * differences of consecutive series values.
     */
    public static class GoreckiDerivativesEuclideanDistance extends EuclideanDistance {
        /** Angle the weights were derived from, or -1 when a and b were set directly. */
        protected double alpha;
        /** Weight on the raw-series distance. */
        protected double a;
        /** Weight on the derivative-series distance. */
        protected double b;
        // NOTE(review): the subsample this flag triggers is currently discarded (see
        // crossValidateForAlpha / crossValidateForAandB), so it has no effect on the result.
        public boolean sampleTrain = true;

        /** Weights (1, 0): plain Euclidean distance until parameters are set. */
        public GoreckiDerivativesEuclideanDistance() {
            this.a = 1.0;
            this.b = 0.0;
            this.alpha = -1.0;
        }

        /** Chooses weights by leave-one-out cross-validation on {@code train}. */
        public GoreckiDerivativesEuclideanDistance(Instances train) {
            this.crossValidateForAandB(train);
        }

        /** Uses weights {@code (cos(alpha), sin(alpha))}. */
        public GoreckiDerivativesEuclideanDistance(double alpha) {
            this.alpha = alpha;
            this.a = Math.cos(alpha);
            this.b = Math.sin(alpha);
        }

        /**
         * Uses the given weights directly. Previously both arguments were ignored and the
         * weights were always (1, 0).
         */
        public GoreckiDerivativesEuclideanDistance(double a, double b) {
            this.alpha = -1.0;
            this.a = a;
            this.b = b;
        }

        @Override
        public double distance(Instance one, Instance two) {
            return this.distance(one, two, Double.MAX_VALUE);
        }

        @Override
        public double distance(Instance one, Instance two, double cutoff, PerformanceStats stats) {
            return this.distance(one, two, cutoff);
        }

        /** Weighted distance; {@code cutoff} is not used for early abandoning. */
        @Override
        public double distance(Instance first, Instance second, double cutoff) {
            double[] parts = euclideanParts(first, second);
            return this.a * parts[0] + this.b * parts[1];
        }

        /** Returns {@code {||first - second||, ||first' - second'||}} without weighting. */
        public double[] getNonScaledDistances(Instance first, Instance second) {
            return euclideanParts(first, second);
        }

        /**
         * Computes the raw and derivative Euclidean distances in one pass. When a class index
         * greater than 0 is set, the last attribute is excluded. NOTE(review): this assumes the
         * class is the last attribute.
         */
        private static double[] euclideanParts(Instance first, Instance second) {
            int numValues = first.numAttributes() - (first.classIndex() > 0 ? 1 : 0);
            double dist = 0.0;
            double derDist = 0.0;
            for (int i = 0; i < numValues; ++i) {
                double diff = first.value(i) - second.value(i);
                dist += diff * diff;
                if (i < numValues - 1) {
                    double firstDer = first.value(i + 1) - first.value(i);
                    double secondDer = second.value(i + 1) - second.value(i);
                    double derDiff = firstDer - secondDer;
                    derDist += derDiff * derDiff;
                }
            }
            return new double[]{Math.sqrt(dist), Math.sqrt(derDist)};
        }

        /**
         * Picks the angle from {@link DD_DTW#ALPHAS} with the fewest leave-one-out 1-NN errors
         * on {@code tr} (ties go to the earliest angle). It then sets {@code alpha},
         * {@code a = cos(alpha)} and {@code b = sin(alpha)}.
         *
         * @return the leave-one-out accuracy of the chosen angle (NaN for an empty dataset)
         */
        public double crossValidateForAlpha(Instances tr) {
            Instances train = tr;
            if (this.sampleTrain) {
                // NOTE(review): the subsample is assigned to the parameter and never read, so
                // cross-validation always uses the full training set. Confirm the intent before
                // switching to the subsample.
                tr = InstanceTools.subSample(tr, tr.numInstances() / 10, 0);
            }
            double[] labels = classLabels(train);
            int k = ALPHAS.length;
            double[] weightA = new double[k];
            double[] weightB = new double[k];
            for (int c = 0; c < k; ++c) {
                weightA[c] = Math.cos(ALPHAS[c]);
                weightB[c] = Math.sin(ALPHAS[c]);
            }
            double[][] predictions = this.leaveOneOutPredictions(train, labels, weightA, weightB);
            int[] mistakes = countMistakes(predictions, labels, k);
            int best = indexOfFewest(mistakes);
            this.alpha = ALPHAS[best];
            this.a = Math.cos(this.alpha);
            this.b = Math.sin(this.alpha);
            return (double) (train.numInstances() - mistakes[best]) / (double) train.numInstances();
        }

        /**
         * Tries 101 weightings from {@code (a, b) = (1, 0)} to {@code (0, 1)} in steps of 0.01.
         * It keeps the one with the fewest leave-one-out 1-NN errors on {@code tr} (ties go to the
         * earliest) and sets {@code alpha = -1}.
         *
         * @return the leave-one-out predicted label of each training instance under the chosen
         *     weights
         */
        public double[] crossValidateForAandB(Instances tr) {
            Instances train = tr;
            if (this.sampleTrain) {
                // NOTE(review): the subsample is assigned to the parameter and never read, so
                // cross-validation always uses the full training set. The returned array has
                // one entry per full training instance.
                tr = InstanceTools.subSample(tr, tr.numInstances() / 10, 0);
            }
            double[] labels = classLabels(train);
            double[] weightA = new double[101];
            double[] weightB = new double[101];
            for (int step = 0; step <= 100; ++step) {
                weightA[step] = (double) (100 - step) / 100.0;
                weightB[step] = (double) step / 100.0;
            }
            double[][] predictions = this.leaveOneOutPredictions(train, labels, weightA, weightB);
            int best = indexOfFewest(countMistakes(predictions, labels, weightA.length));
            this.alpha = -1.0;
            this.a = weightA[best];
            this.b = weightB[best];
            double[] bestAlphaPredictions = new double[predictions.length];
            for (int i = 0; i < predictions.length; ++i) {
                bestAlphaPredictions[i] = predictions[i][best];
            }
            return bestAlphaPredictions;
        }

        /** Returns the class value of every instance in {@code train}, in order. */
        private static double[] classLabels(Instances train) {
            double[] labels = new double[train.numInstances()];
            for (int i = 0; i < labels.length; ++i) {
                labels[i] = train.instance(i).classValue();
            }
            return labels;
        }

        /**
         * Runs leave-one-out 1-NN for every candidate weighting at once.
         *
         * @return {@code [i][c]} = label of instance i's nearest other instance under weights
         *     {@code (weightA[c], weightB[c])}; 0.0 if no neighbour scored below Double.MAX_VALUE
         */
        private double[][] leaveOneOutPredictions(Instances train, double[] labels, double[] weightA, double[] weightB) {
            int n = train.numInstances();
            int k = weightA.length;
            double[][] predictions = new double[n][];
            for (int i = 0; i < n; ++i) {
                double[] bestDist = new double[k];
                double[] bestLabel = new double[k];
                for (int c = 0; c < k; ++c) {
                    bestDist[c] = Double.MAX_VALUE;
                }
                for (int j = 0; j < n; ++j) {
                    if (i == j) {
                        continue;
                    }
                    // Virtual call: the DTW subclass supplies its own raw/derivative distances.
                    double[] raw = this.getNonScaledDistances(train.instance(i), train.instance(j));
                    for (int c = 0; c < k; ++c) {
                        double d = weightA[c] * raw[0] + weightB[c] * raw[1];
                        if (d < bestDist[c]) {
                            bestDist[c] = d;
                            bestLabel[c] = labels[j];
                        }
                    }
                }
                predictions[i] = bestLabel;
            }
            return predictions;
        }

        /** Counts, per weighting, how many predictions differ from the true labels. */
        private static int[] countMistakes(double[][] predictions, double[] labels, int k) {
            int[] mistakes = new int[k];
            for (int i = 0; i < predictions.length; ++i) {
                for (int c = 0; c < k; ++c) {
                    if (predictions[i][c] != labels[i]) {
                        ++mistakes[c];
                    }
                }
            }
            return mistakes;
        }

        /** Returns the first index with the smallest count, or -1 if {@code counts} is empty. */
        private static int indexOfFewest(int[] counts) {
            int bestIndex = -1;
            int bestCount = Integer.MAX_VALUE;
            for (int c = 0; c < counts.length; ++c) {
                if (counts[c] < bestCount) {
                    bestCount = counts[c];
                    bestIndex = c;
                }
            }
            return bestIndex;
        }

        /** Returns the weight on the raw-series distance. */
        public double getA() {
            return this.a;
        }

        /** Returns the weight on the derivative-series distance. */
        public double getB() {
            return this.b;
        }
    }

    /** Base distance for the derivative-weighted measure. */
    public static enum DistanceType {
        EUCLIDEAN,
        DTW;

    }
}

