/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers;

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import timeseriesweka.classifiers.AbstractClassifierWithTrainingData;
import timeseriesweka.classifiers.ParameterSplittable;
import utilities.ClassifierTools;
import utilities.InstanceTools;
import utilities.StatisticalUtilities;
import weka.clusterers.SimpleKMeans;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;

public class LearnShapelets
extends AbstractClassifierWithTrainingData
implements ParameterSplittable {
    boolean suppressOutput = false;
    long seed;
    public int seriesLength;
    public int[] L;
    public int K;
    public int C;
    public int[] numberOfSegments;
    int L_min;
    double[][][] Shapelets;
    double[][][] W;
    double[] biasW;
    double[][][] GradHistShapelets;
    double[][][] GradHistW;
    double[] GradHistBiasW;
    public double lambdaW = 0.01;
    public int R = 3;
    public double percentageOfSeriesLength = 0.2;
    public double eta = 0.1;
    public double alpha = -30.0;
    public int maxIter = 300;
    public Instances trainSet;
    public double[][] train;
    public double[][] classValues_train;
    public List<Double> nominalLabels;
    double[][][][] D_train;
    double[][][][] E_train;
    double[][][] M_train;
    double[][][] Psi_train;
    double[][] sigY_train;
    double[][][] D_test;
    double[][][] E_test;
    double[][] M_test;
    double[][] Psi_test;
    double[] sigY_test;
    double[][] tmp2;
    double regWConst;
    double tmp1;
    double tmp3;
    double dLdY;
    double gradW_crk;
    double gradS_rkl;
    double gradBiasW_c;
    double eps = 1.0E-21;
    Random rand = new Random();
    List<List<Integer>> posIdxs;
    List<List<Integer>> negIdxs;
    List<Integer> instanceIdxs;
    public boolean enableParallel = true;
    boolean paraSearch = false;
    double[] lambdaWRange = new double[]{0.01, 0.1};
    double[] percentageOfSeriesLengthRange = new double[]{0.15};
    int[] shapeletLengthScaleRange = new int[]{2, 4};
    double maxAcc;

    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "J. Grabocka, N. Schilling, M. Wistuba and L. Schmidt-Thieme");
        result.setValue(TechnicalInformation.Field.TITLE, "Learning Time-Series Shapelets");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Proc. 20th SIGKDD");
        result.setValue(TechnicalInformation.Field.YEAR, "2014");
        return result;
    }

    @Override
    public void setParamSearch(boolean b) {
        this.paraSearch = b;
    }

    public void fixParameters() {
        this.lambdaW = 0.01;
        this.R = 3;
        this.percentageOfSeriesLength = 0.2;
        this.eta = 0.1;
        this.alpha = -30.0;
        this.maxIter = 300;
    }

    @Override
    public void setParametersFromIndex(int x) {
        this.lambdaW = x <= 4 ? this.lambdaWRange[0] : this.lambdaWRange[1];
        this.percentageOfSeriesLength = x == 1 || x == 2 || x == 5 || x == 6 ? this.percentageOfSeriesLengthRange[0] : this.percentageOfSeriesLengthRange[1];
        this.R = x % 2 == 1 ? this.shapeletLengthScaleRange[0] : this.shapeletLengthScaleRange[1];
    }

    @Override
    public String getParas() {
        return this.lambdaW + "," + this.percentageOfSeriesLength + "," + this.R;
    }

    @Override
    public double getAcc() {
        return this.maxAcc;
    }

    public void setSeed(long seed) {
        this.seed = seed;
        this.rand = new Random(seed);
    }

    public void initialize() throws Exception {
        int k;
        int r;
        int i;
        int r2;
        if (this.K == 0) {
            this.K = 1;
        }
        this.L_min = (int)(this.percentageOfSeriesLength * (double)this.seriesLength);
        this.createOneVsAllTargets();
        this.Shapelets = new double[this.R][][];
        this.numberOfSegments = new int[this.R];
        this.L = new int[this.R];
        int totalSegments = 0;
        for (r2 = 0; r2 < this.R; ++r2) {
            this.L[r2] = (r2 + 1) * this.L_min;
            this.numberOfSegments[r2] = this.seriesLength - this.L[r2];
            totalSegments += this.train.length * this.numberOfSegments[r2];
        }
        this.K = (int)(Math.log(totalSegments) * (double)(this.C - 1));
        this.D_train = new double[this.train.length][this.R][this.K][];
        this.E_train = new double[this.train.length][this.R][this.K][];
        for (i = 0; i < this.train.length; ++i) {
            for (r = 0; r < this.R; ++r) {
                for (k = 0; k < this.K; ++k) {
                    this.D_train[i][r][k] = new double[this.numberOfSegments[r]];
                    this.E_train[i][r][k] = new double[this.numberOfSegments[r]];
                }
            }
        }
        this.M_train = new double[this.train.length][this.R][this.K];
        this.Psi_train = new double[this.train.length][this.R][this.K];
        this.sigY_train = new double[this.train.length][this.C];
        this.W = new double[this.C][this.R][this.K];
        this.biasW = new double[this.C];
        for (int c = 0; c < this.C; ++c) {
            for (r = 0; r < this.R; ++r) {
                for (k = 0; k < this.K; ++k) {
                    this.W[c][r][k] = 2.0 * this.eps * this.rand.nextDouble() - 1.0;
                }
            }
            this.biasW[c] = 2.0 * this.eps * this.rand.nextDouble() - 1.0;
        }
        this.GradHistW = new double[this.C][this.R][this.K];
        this.GradHistBiasW = new double[this.C];
        this.GradHistShapelets = new double[this.R][][];
        for (r2 = 0; r2 < this.R; ++r2) {
            this.GradHistShapelets[r2] = new double[this.K][this.L[r2]];
        }
        this.initializeShapeletsKMeans();
        this.print("Initialization completed: L_min=" + this.L_min + ", K=" + this.K + ", R=" + this.R + ", C=" + this.C + ", lambdaW=" + this.lambdaW);
        this.tmp2 = new double[this.R][];
        for (r2 = 0; r2 < this.R; ++r2) {
            this.tmp2[r2] = new double[this.numberOfSegments[r2]];
        }
        this.regWConst = 2.0 * this.lambdaW / (double)this.train.length;
        this.instanceIdxs = new ArrayList<Integer>();
        for (i = 0; i < this.train.length; ++i) {
            this.instanceIdxs.add(i);
        }
    }

    public void createOneVsAllTargets() {
        this.C = this.nominalLabels.size();
        this.classValues_train = new double[this.train.length][this.C];
        for (int i = 0; i < this.train.length; ++i) {
            for (int c = 0; c < this.C; ++c) {
                this.classValues_train[i][c] = 0.0;
            }
            int indexLabel = this.nominalLabels.indexOf(this.trainSet.get(i).classValue());
            this.classValues_train[i][indexLabel] = 1.0;
        }
        this.posIdxs = new ArrayList<List<Integer>>();
        this.negIdxs = new ArrayList<List<Integer>>();
        for (int c = 0; c < this.C; ++c) {
            ArrayList<Integer> posIdx_c = new ArrayList<Integer>();
            ArrayList<Integer> negIdx_c = new ArrayList<Integer>();
            for (int i = 0; i < this.train.length; ++i) {
                if (this.classValues_train[i][c] == 1.0) {
                    posIdx_c.add(i);
                    continue;
                }
                negIdx_c.add(i);
            }
            this.posIdxs.add(posIdx_c);
            this.negIdxs.add(negIdx_c);
        }
    }

    public void initializeShapeletsKMeans() throws Exception {
        for (int r = 0; r < this.R; ++r) {
            int j;
            int i;
            double[][] segments_r = new double[this.train.length * this.numberOfSegments[r]][this.L[r]];
            for (i = 0; i < this.train.length; ++i) {
                for (j = 0; j < this.numberOfSegments[r]; ++j) {
                    for (int l = 0; l < this.L[r]; ++l) {
                        segments_r[i * this.numberOfSegments[r] + j][l] = this.train[i][j + l];
                    }
                }
            }
            for (i = 0; i < this.train.length; ++i) {
                for (j = 0; j < this.numberOfSegments[r]; ++j) {
                    segments_r[i * this.numberOfSegments[r] + j] = StatisticalUtilities.normalize(segments_r[i * this.numberOfSegments[r] + j]);
                }
            }
            Instances ins = InstanceTools.toWekaInstances(segments_r);
            SimpleKMeans skm = new SimpleKMeans();
            skm.setNumClusters(this.K);
            skm.setMaxIterations(100);
            skm.setSeed((int)(this.rand.nextDouble() * 1000.0));
            skm.buildClusterer(ins);
            Instances centroidsWeka = skm.getClusterCentroids();
            this.Shapelets[r] = InstanceTools.fromWekaInstancesArray(centroidsWeka, false);
            if (this.Shapelets[r] != null) continue;
            this.print("P not set");
        }
    }

    public double predict_i(double[][] M, int c) {
        double Y_hat_ic = this.biasW[c];
        for (int r = 0; r < this.R; ++r) {
            for (int k = 0; k < this.K; ++k) {
                Y_hat_ic += M[r][k] * this.W[c][r][k];
            }
        }
        return Y_hat_ic;
    }

    public void preCompute(double[][][] D, double[][][] E, double[][] Psi, double[][] M, double[] sigY, double[] series) {
        for (int r = 0; r < this.R; ++r) {
            for (int k = 0; k < this.Shapelets[r].length; ++k) {
                int j;
                for (j = 0; j < this.numberOfSegments[r]; ++j) {
                    D[r][k][j] = 0.0;
                    double err = 0.0;
                    for (int l = 0; l < this.L[r]; ++l) {
                        err = series[j + l] - this.Shapelets[r][k][l];
                        double[] dArray = D[r][k];
                        int n = j;
                        dArray[n] = dArray[n] + err * err;
                    }
                    double[] dArray = D[r][k];
                    int n = j;
                    dArray[n] = dArray[n] / (double)this.L[r];
                    E[r][k][j] = Math.exp(this.alpha * D[r][k][j]);
                }
                Psi[r][k] = 0.0;
                for (j = 0; j < this.numberOfSegments[r]; ++j) {
                    double[] dArray = Psi[r];
                    int n = k;
                    dArray[n] = dArray[n] + Math.exp(this.alpha * D[r][k][j]);
                }
                M[r][k] = 0.0;
                for (j = 0; j < this.numberOfSegments[r]; ++j) {
                    double[] dArray = M[r];
                    int n = k;
                    dArray[n] = dArray[n] + D[r][k][j] * E[r][k][j];
                }
                double[] dArray = M[r];
                int n = k;
                dArray[n] = dArray[n] / Psi[r][k];
            }
        }
        for (int c = 0; c < this.C; ++c) {
            sigY[c] = StatisticalUtilities.calculateSigmoid(this.predict_i(M, c));
        }
    }

    public double accuracyLoss(double[][] M, double[] classValues, int c) {
        double Y_hat_ic = this.predict_i(M, c);
        double sig_y_ic = StatisticalUtilities.calculateSigmoid(Y_hat_ic);
        double returnVal = -classValues[c] * Math.log(sig_y_ic) - (1.0 - classValues[c]) * Math.log(1.0 - sig_y_ic);
        return returnVal;
    }

    public double accuracyLossTrainSet() {
        double accuracyLoss = 0.0;
        for (int i = 0; i < this.train.length; ++i) {
            this.preCompute(this.D_train[i], this.E_train[i], this.Psi_train[i], this.M_train[i], this.sigY_train[i], this.train[i]);
            for (int c = 0; c < this.C; ++c) {
                accuracyLoss += this.accuracyLoss(this.M_train[i], this.classValues_train[i], c);
            }
        }
        return accuracyLoss / (double)this.train.length;
    }

    public void learnF(int c, int i) {
        this.preCompute(this.D_train[i], this.E_train[i], this.Psi_train[i], this.M_train[i], this.sigY_train[i], this.train[i]);
        this.dLdY = -(this.classValues_train[i][c] - this.sigY_train[i][c]);
        for (int r = 0; r < this.R; ++r) {
            for (int k = 0; k < this.Shapelets[r].length; ++k) {
                this.gradW_crk = this.dLdY * this.M_train[i][r][k] + this.regWConst * this.W[c][r][k];
                double[] dArray = this.GradHistW[c][r];
                int n = k;
                dArray[n] = dArray[n] + this.gradW_crk * this.gradW_crk;
                double[] dArray2 = this.W[c][r];
                int n2 = k;
                dArray2[n2] = dArray2[n2] - this.eta / (Math.sqrt(this.GradHistW[c][r][k]) + this.eps) * this.gradW_crk;
                this.tmp1 = 2.0 / ((double)this.L[r] * this.Psi_train[i][r][k]);
                for (int j = 0; j < this.numberOfSegments[r]; ++j) {
                    this.tmp2[r][j] = this.E_train[i][r][k][j] * (1.0 + this.alpha * (this.D_train[i][r][k][j] - this.M_train[i][r][k]));
                }
                for (int l = 0; l < this.L[r]; ++l) {
                    this.tmp3 = 0.0;
                    for (int j = 0; j < this.numberOfSegments[r]; ++j) {
                        this.tmp3 += this.tmp2[r][j] * (this.Shapelets[r][k][l] - this.train[i][j + l]);
                    }
                    this.gradS_rkl = this.dLdY * this.W[c][r][k] * this.tmp1 * this.tmp3;
                    double[] dArray3 = this.GradHistShapelets[r][k];
                    int n3 = l;
                    dArray3[n3] = dArray3[n3] + this.gradS_rkl * this.gradS_rkl;
                    double[] dArray4 = this.Shapelets[r][k];
                    int n4 = l;
                    dArray4[n4] = dArray4[n4] - this.eta / (Math.sqrt(this.GradHistShapelets[r][k][l]) + this.eps) * this.gradS_rkl;
                }
            }
        }
        this.gradBiasW_c = this.dLdY;
        int n = c;
        this.GradHistBiasW[n] = this.GradHistBiasW[n] + this.gradBiasW_c * this.gradBiasW_c;
        int n5 = c;
        this.biasW[n5] = this.biasW[n5] - this.eta / (Math.sqrt(this.GradHistBiasW[c]) + this.eps) * this.gradBiasW_c;
    }

    public void learnF() {
        for (int c = 0; c < this.C; ++c) {
            for (int i = 0; i < this.train.length; ++i) {
                int posIdx = this.posIdxs.get(c).get(this.rand.nextInt(this.posIdxs.get(c).size()));
                int negIdx = this.negIdxs.get(c).get(this.rand.nextInt(this.negIdxs.get(c).size()));
                this.learnF(c, posIdx);
                this.learnF(c, negIdx);
            }
        }
    }

    @Override
    public void buildClassifier(Instances trainData) throws Exception {
        this.trainResults.buildTime = System.currentTimeMillis();
        if (this.paraSearch) {
            double[] paramsLambdaW = this.lambdaWRange;
            double[] paramsPercentageOfSeriesLength = this.percentageOfSeriesLengthRange;
            int[] paramsShapeletLengthScale = this.shapeletLengthScaleRange;
            int noFolds = 2;
            double bsfAccuracy = 0.0;
            int[] params = new int[]{0, 0, 0};
            double accuracy = 0.0;
            trainData.randomize(this.rand);
            trainData.stratify(noFolds);
            int numHpsCombinations = 1;
            for (int i = 0; i < paramsLambdaW.length; ++i) {
                for (int j = 0; j < paramsPercentageOfSeriesLength.length; ++j) {
                    for (int k = 0; k < paramsShapeletLengthScale.length; ++k) {
                        this.percentageOfSeriesLength = paramsPercentageOfSeriesLength[j];
                        this.R = paramsShapeletLengthScale[k];
                        this.lambdaW = paramsLambdaW[i];
                        this.print("HPS Combination #" + numHpsCombinations + ": {R=" + this.R + ", L=" + this.percentageOfSeriesLength + ", lambdaW=" + this.lambdaW + "}");
                        this.print("--------------------------------------");
                        double sumAccuracy = 0.0;
                        for (int l = 0; l < noFolds; ++l) {
                            Instances trainCV = trainData.trainCV(noFolds, l);
                            Instances testCV = trainData.testCV(noFolds, l);
                            this.eta = 0.1;
                            this.alpha = -30.0;
                            this.maxIter = 300;
                            this.print("Learn model for Fold-" + l + ":");
                            this.train(trainCV);
                            accuracy = ClassifierTools.accuracy(testCV, this);
                            sumAccuracy += accuracy;
                            this.print("Accuracy-Fold-" + l + " = " + accuracy);
                            trainCV = null;
                            testCV = null;
                        }
                        this.print("Accuracy-CV = " + (sumAccuracy /= (double)noFolds));
                        this.print("--------------------------------------");
                        if (sumAccuracy > bsfAccuracy) {
                            int[] p = new int[]{i, j, k};
                            params = p;
                            bsfAccuracy = sumAccuracy;
                        }
                        ++numHpsCombinations;
                    }
                }
            }
            System.gc();
            this.maxAcc = bsfAccuracy;
            this.lambdaW = paramsLambdaW[params[0]];
            this.percentageOfSeriesLength = paramsPercentageOfSeriesLength[params[1]];
            this.R = paramsShapeletLengthScale[params[2]];
            this.eta = 0.1;
            this.alpha = -30.0;
            this.maxIter = 600;
            this.print("Learn final model with best hyper-parameters: R=" + this.R + ", L=" + this.percentageOfSeriesLength + ", lambdaW=" + this.lambdaW);
        } else {
            this.fixParameters();
            this.print("Fixed parameters: R=" + this.R + ", L=" + this.percentageOfSeriesLength + ", lambdaW=" + this.lambdaW);
        }
        this.train(trainData);
        this.trainResults.buildTime = System.currentTimeMillis() - this.trainResults.buildTime;
    }

    private void train(Instances data) throws Exception {
        this.trainSet = data;
        this.seriesLength = this.trainSet.numAttributes() - 1;
        this.nominalLabels = LearnShapelets.readNominalTargets(this.trainSet);
        if (this.nominalLabels.size() < 2) {
            System.err.println("Fatal error: Number of classes is " + this.nominalLabels.size());
            return;
        }
        this.train = InstanceTools.fromWekaInstancesArray(this.trainSet, true);
        this.initialize();
        for (int iter = 0; iter <= this.maxIter; ++iter) {
            this.learnF();
            if (iter % (this.maxIter / 3) != 0 || iter <= 0) continue;
            double lossTrain = this.accuracyLossTrainSet();
            this.print("Iter=" + iter + ", Loss=" + lossTrain);
            if (Double.isNaN(lossTrain)) break;
        }
    }

    @Override
    public double classifyInstance(Instance instance) throws Exception {
        double[] temp = instance.toDoubleArray();
        double[] test = new double[temp.length - 1];
        System.arraycopy(temp, 0, test, 0, temp.length - 1);
        test = StatisticalUtilities.normalize(test);
        this.D_test = new double[this.R][this.K][];
        this.E_test = new double[this.R][this.K][];
        for (int r = 0; r < this.R; ++r) {
            for (int k = 0; k < this.K; ++k) {
                this.D_test[r][k] = new double[this.numberOfSegments[r]];
                this.E_test[r][k] = new double[this.numberOfSegments[r]];
            }
        }
        this.M_test = new double[this.R][this.K];
        this.Psi_test = new double[this.R][this.K];
        this.sigY_test = new double[this.C];
        this.preCompute(this.D_test, this.E_test, this.Psi_test, this.M_test, this.sigY_test, test);
        double max_Y_hat_ic = Double.MIN_VALUE;
        int label_i = 0;
        for (int c = 0; c < this.C; ++c) {
            double Y_hat_ic = StatisticalUtilities.calculateSigmoid(this.predict_i(this.M_test, c));
            if (!(Y_hat_ic > max_Y_hat_ic)) continue;
            max_Y_hat_ic = Y_hat_ic;
            label_i = c;
        }
        return this.nominalLabels.get(label_i);
    }

    public void suppressOutput() {
        this.suppressOutput = true;
    }

    void print(String s) {
        if (!this.suppressOutput) {
            System.out.println(s);
        }
    }

    @Override
    public Capabilities getCapabilities() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    public static ArrayList<Double> readNominalTargets(Instances instances) {
        if (instances.size() <= 0) {
            return null;
        }
        ArrayList<Double> nominalLabels = new ArrayList<Double>();
        for (Instance ins : instances) {
            boolean alreadyAdded = false;
            for (Double nominalLabel : nominalLabels) {
                if (nominalLabel.doubleValue() != ins.classValue()) continue;
                alreadyAdded = true;
                break;
            }
            if (alreadyAdded) continue;
            nominalLabels.add(ins.classValue());
        }
        Collections.sort(nominalLabels);
        return nominalLabels;
    }

    public static void main(String[] args) throws Exception {
        if (args.length == 0) {
            args = new String[]{"C:\\LocalData\\Dropbox\\TSC Problems", "OliveOil"};
        }
        String dataset = args[1];
        String fileExtension = File.separator + dataset + File.separator + dataset;
        String samplePath = args[0] + fileExtension;
        Instances testSet = ClassifierTools.loadData(samplePath + "_TEST");
        Instances trainSet = ClassifierTools.loadData(samplePath + "_TRAIN");
        LearnShapelets ls = new LearnShapelets();
        ls.setSeed(0L);
        ls.buildClassifier(trainSet);
        double accuracy = ClassifierTools.accuracy(testSet, ls);
        System.out.println(dataset + ", LS= " + (1.0 - accuracy));
    }
}

