/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers.ensembles.elastic_ensemble;

import java.util.Comparator;
import java.util.PriorityQueue;
import timeseriesweka.classifiers.ensembles.elastic_ensemble.DTW1NN;
import utilities.GenericTools;
import utilities.generic_storage.Pair;
import weka.core.Instance;

/**
 * k-nearest-neighbour variant of {@link DTW1NN}.
 *
 * <p>For a test instance the {@code k} training instances with the smallest
 * {@code distance(...)} value are retained, and each one votes for its class
 * with a weight that decreases as its distance grows. The resulting votes are
 * normalised into a class distribution.
 */
public class DTWKNN
extends DTW1NN {
    /** Number of neighbours consulted per prediction; must be at least 1 when classifying. */
    public int k;

    /**
     * Orders pairs by DESCENDING distance, so the head of a queue built with it
     * is the worst (farthest) neighbour currently retained.
     */
    private static final Comparator<Pair<Double, Integer>> comparator = new Comparator<Pair<Double, Integer>>(){

        @Override
        public int compare(Pair<Double, Integer> o1, Pair<Double, Integer> o2) {
            return ((Double)o2.var1).compareTo((Double)o1.var1);
        }
    };

    /** Creates a classifier with the parent's default settings and {@code k = 5}. */
    public DTWKNN() {
        this.k = 5;
    }

    /**
     * Creates a classifier with the parent's default settings.
     *
     * @param k number of neighbours to consult
     */
    public DTWKNN(int k) {
        this.k = k;
    }

    /**
     * Creates a classifier with {@code k = 5}.
     *
     * @param r value forwarded unchanged to the {@link DTW1NN} constructor
     */
    public DTWKNN(double r) {
        super(r);
        this.k = 5;
    }

    /**
     * @param r value forwarded unchanged to the {@link DTW1NN} constructor
     * @param k number of neighbours to consult
     */
    public DTWKNN(double r, int k) {
        super(r);
        this.k = k;
    }

    /**
     * Predicts the class whose entry in {@link #distributionForInstance(Instance)} is largest.
     *
     * @param instance the instance to classify
     * @return the index of the winning class, as produced by {@code GenericTools.indexOfMax}
     * @throws Exception if the distribution cannot be computed
     */
    @Override
    public double classifyInstance(Instance instance) throws Exception {
        return GenericTools.indexOfMax(this.distributionForInstance(instance));
    }

    /**
     * Computes a distance-weighted class distribution from the {@code k} nearest
     * training instances (or all of them when there are fewer than {@code k}).
     *
     * @param testInst the instance to classify
     * @return an array with one entry per training class; the entries sum to 1
     * @throws IllegalArgumentException if {@code k} is less than 1
     * @throws IllegalStateException if the training set contains no instances
     * @throws Exception if the underlying distance computation fails
     */
    @Override
    public double[] distributionForInstance(Instance testInst) throws Exception {
        if (this.k < 1) {
            throw new IllegalArgumentException("k must be at least 1, but was " + this.k);
        }
        int numTrain = this.train.numInstances();
        if (numTrain == 0) {
            throw new IllegalStateException("Cannot classify: the training set contains no instances");
        }
        // Never more than numTrain neighbours are held, so do not pre-allocate for a larger k.
        PriorityQueue<Pair<Double, Integer>> topK =
                new PriorityQueue<Pair<Double, Integer>>(Math.min(this.k, numTrain), comparator);
        for (int i = 0; i < numTrain; ++i) {
            Instance trainInst = this.train.instance(i);
            boolean full = topK.size() >= this.k;
            // NOTE(review): the third argument is presumably an early-abandon cutoff - confirm
            // against DTW1NN. While the queue is still filling, every candidate is kept, so its
            // distance must not be bounded by the current worst neighbour; Double.MAX_VALUE is
            // used as the "no bound" value instead.
            double cutoff = full ? ((Double)topK.peek().var1).doubleValue() : Double.MAX_VALUE;
            double thisDist = this.distance(testInst, trainInst, cutoff);
            if (full) {
                if (!(thisDist < cutoff)) {
                    continue;
                }
                // Evict the farthest retained neighbour to make room for the closer one.
                topK.poll();
            }
            topK.add(new Pair<Double, Integer>(thisDist, (int)trainInst.classValue()));
        }
        return DTWKNN.weightedVotes(topK, this.train.numClasses());
    }

    /**
     * Turns the retained neighbours into a normalised class distribution. Each
     * neighbour's weight is {@code 1 - distance / sumOfDistances}; weights are then
     * divided by their total. When those weights are unusable (a single neighbour,
     * or all distances zero, both of which would otherwise yield NaN) every
     * neighbour receives an equal vote instead.
     */
    private static double[] weightedVotes(PriorityQueue<Pair<Double, Integer>> neighbours, int numClasses) {
        int count = neighbours.size();
        double[] distances = new double[count];
        int[] labels = new int[count];
        double distanceSum = 0.0;
        int next = 0;
        // Copy out of the queue; its elements are left untouched so the heap order stays valid.
        for (Pair<Double, Integer> neighbour : neighbours) {
            distances[next] = ((Double)neighbour.var1).doubleValue();
            labels[next] = ((Integer)neighbour.var2).intValue();
            distanceSum += distances[next];
            ++next;
        }
        double[] weights = new double[count];
        double weightSum = 0.0;
        for (int i = 0; i < count; ++i) {
            weights[i] = 1.0 - distances[i] / distanceSum;
            weightSum += weights[i];
        }
        // False for zero, NaN and infinite totals alike.
        boolean usable = weightSum > 0.0 && !Double.isInfinite(weightSum);
        double[] distribution = new double[numClasses];
        for (int i = 0; i < count; ++i) {
            distribution[labels[i]] += usable ? weights[i] / weightSum : 1.0 / count;
        }
        return distribution;
    }
}

