/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers;

import java.io.File;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.Random;
import java.util.Scanner;
import timeseriesweka.classifiers.AbstractClassifierWithTrainingData;
import timeseriesweka.classifiers.cote.HiveCoteModule;
import timeseriesweka.classifiers.ensembles.elastic_ensemble.DTW1NN;
import timeseriesweka.classifiers.ensembles.elastic_ensemble.ED1NN;
import timeseriesweka.classifiers.ensembles.elastic_ensemble.ERP1NN;
import timeseriesweka.classifiers.ensembles.elastic_ensemble.Efficient1NN;
import timeseriesweka.classifiers.ensembles.elastic_ensemble.LCSS1NN;
import timeseriesweka.classifiers.ensembles.elastic_ensemble.MSM1NN;
import timeseriesweka.classifiers.ensembles.elastic_ensemble.TWE1NN;
import timeseriesweka.classifiers.ensembles.elastic_ensemble.WDTW1NN;
import timeseriesweka.filters.DerivativeFilter;
import utilities.ClassifierResults;
import utilities.ClassifierTools;
import utilities.TrainAccuracyEstimate;
import utilities.WritableTestResults;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;

public class ElasticEnsemble
extends AbstractClassifierWithTrainingData
implements HiveCoteModule,
WritableTestResults,
TrainAccuracyEstimate {
    private final ConstituentClassifiers[] classifiersToUse;
    private String datasetName;
    private int resampleId;
    private String resultsDir;
    private double[] cvAccs;
    private double[][] cvPreds;
    private boolean buildFromFile = false;
    private boolean writeToFile = false;
    private Instances train;
    private Instances derTrain;
    private Efficient1NN[] classifiers = null;
    private boolean writeEnsembleTrainingFile = false;
    private String ensembleTrainFilePathAndName = null;
    private boolean usesDer = false;
    private static DerivativeFilter df = new DerivativeFilter();
    double[] previousPredictions = null;
    double ensembleCvAcc = -1.0;
    double[] ensembleCvPreds = null;

    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "J. Lines and A. Bagnall");
        result.setValue(TechnicalInformation.Field.TITLE, "Time Series Classification with Ensembles of Elastic Distance Measures");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Data Mining and Knowledge Discovery");
        result.setValue(TechnicalInformation.Field.VOLUME, "29");
        result.setValue(TechnicalInformation.Field.NUMBER, "3");
        result.setValue(TechnicalInformation.Field.PAGES, "565-592");
        result.setValue(TechnicalInformation.Field.YEAR, "2015");
        return result;
    }

    public static boolean isDerivative(ConstituentClassifiers classifier) {
        return classifier == ConstituentClassifiers.DDTW_R1_1NN || classifier == ConstituentClassifiers.DDTW_Rn_1NN || classifier == ConstituentClassifiers.WDDTW_1NN;
    }

    public static boolean isFixedParam(ConstituentClassifiers classifier) {
        return classifier == ConstituentClassifiers.DDTW_R1_1NN || classifier == ConstituentClassifiers.DTW_R1_1NN || classifier == ConstituentClassifiers.Euclidean_1NN;
    }

    @Override
    public Capabilities getCapabilities() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    public String[] getIndividualClassifierNames() {
        String[] names = new String[this.classifiersToUse.length];
        for (int i = 0; i < this.classifiersToUse.length; ++i) {
            names[i] = this.classifiersToUse[i].toString();
        }
        return names;
    }

    public double[] getIndividualCVAccs() {
        return this.cvAccs;
    }

    @Override
    public double getEnsembleCvAcc() {
        if (this.ensembleCvAcc != -1.0 && this.ensembleCvPreds != null) {
            return this.ensembleCvAcc;
        }
        this.getEnsembleCvPreds();
        return this.ensembleCvAcc;
    }

    @Override
    public double[] getEnsembleCvPreds() {
        if (this.ensembleCvPreds != null) {
            return this.ensembleCvPreds;
        }
        this.ensembleCvPreds = new double[this.train.numInstances()];
        int correct = 0;
        for (int i = 0; i < this.train.numInstances(); ++i) {
            double actual = this.train.instance(i).classValue();
            ArrayList<Double> bsfClassVals = null;
            double bsfWeight = -1.0;
            double[] weightByClass = new double[this.train.numClasses()];
            for (int c = 0; c < this.classifiers.length; ++c) {
                int n = (int)this.cvPreds[c][i];
                weightByClass[n] = weightByClass[n] + this.cvAccs[c];
                if (weightByClass[(int)this.cvPreds[c][i]] > bsfWeight) {
                    bsfWeight = weightByClass[(int)this.cvPreds[c][i]];
                    bsfClassVals = new ArrayList<Double>();
                    bsfClassVals.add(this.cvPreds[c][i]);
                    continue;
                }
                if (weightByClass[(int)this.cvPreds[c][i]] != bsfWeight) continue;
                bsfClassVals.add(this.cvPreds[c][i]);
            }
            double pred = bsfClassVals.size() > 1 ? ((Double)bsfClassVals.get(new Random().nextInt(bsfClassVals.size()))).doubleValue() : ((Double)bsfClassVals.get(0)).doubleValue();
            if (pred == actual) {
                ++correct;
            }
            this.ensembleCvPreds[i] = pred;
        }
        this.ensembleCvAcc = (double)correct / (double)this.train.numInstances();
        return this.ensembleCvPreds;
    }

    public double[] getIndividualCvAccs() {
        return this.cvAccs;
    }

    public double[][] getIndividualCvPredictions() {
        return this.cvPreds;
    }

    public ElasticEnsemble() {
        this.classifiersToUse = ConstituentClassifiers.values();
    }

    public ElasticEnsemble(ConstituentClassifiers[] classifiersToUse) {
        this.classifiersToUse = classifiersToUse;
    }

    public ElasticEnsemble(String resultsDir, String datasetName, int resampleId) {
        this.resultsDir = resultsDir;
        this.datasetName = datasetName;
        this.resampleId = resampleId;
        this.classifiersToUse = ConstituentClassifiers.values();
        this.buildFromFile = true;
    }

    public ElasticEnsemble(String resultsDir, String datasetName, int resampleId, ConstituentClassifiers[] classifiersToUse) {
        this.resultsDir = resultsDir;
        this.datasetName = datasetName;
        this.resampleId = resampleId;
        this.classifiersToUse = classifiersToUse;
        this.buildFromFile = true;
    }

    public void setInternalFileWritingOn(String resultsDir, String datasetName, int resampleId) {
        this.resultsDir = resultsDir;
        this.datasetName = datasetName;
        this.resampleId = resampleId;
        this.writeToFile = true;
    }

    @Override
    public void writeCVTrainToFile(String outputPathAndName) {
        this.writeEnsembleTrainingFile = true;
        this.ensembleTrainFilePathAndName = outputPathAndName;
    }

    @Override
    public boolean findsTrainAccuracyEstimate() {
        return this.writeEnsembleTrainingFile;
    }

    @Override
    public ClassifierResults getTrainResults() {
        this.trainResults.acc = this.ensembleCvAcc;
        return this.trainResults;
    }

    @Override
    public void buildClassifier(Instances train) throws Exception {
        this.trainResults.buildTime = System.currentTimeMillis();
        this.train = train;
        this.derTrain = null;
        this.usesDer = false;
        this.classifiers = new Efficient1NN[this.classifiersToUse.length];
        this.cvAccs = new double[this.classifiers.length];
        this.cvPreds = new double[this.classifiers.length][this.train.numInstances()];
        for (int c = 0; c < this.classifiers.length; ++c) {
            this.classifiers[c] = ElasticEnsemble.getClassifier(this.classifiersToUse[c]);
            if (!ElasticEnsemble.isDerivative(this.classifiersToUse[c])) continue;
            this.usesDer = true;
        }
        if (this.usesDer) {
            this.derTrain = df.process(train);
        }
        if (this.buildFromFile) {
            for (int c = 0; c < this.classifiers.length; ++c) {
                File existingTrainOut = new File(this.resultsDir + (Object)((Object)this.classifiersToUse[c]) + "/Predictions/" + this.datasetName + "/trainFold" + this.resampleId + ".csv");
                if (!existingTrainOut.exists()) {
                    throw new Exception("Error: training file doesn't exist for " + existingTrainOut.getAbsolutePath());
                }
                Scanner scan = new Scanner(existingTrainOut);
                scan.useDelimiter("\n");
                scan.next();
                int paramId = Integer.parseInt(scan.next().trim().split(",")[0]);
                double cvAcc = Double.parseDouble(scan.next().trim().split(",")[0]);
                for (int i = 0; i < train.numInstances(); ++i) {
                    this.cvPreds[c][i] = Double.parseDouble(scan.next().split(",")[1]);
                }
                scan.close();
                if (ElasticEnsemble.isDerivative(this.classifiersToUse[c])) {
                    if (!ElasticEnsemble.isFixedParam(this.classifiersToUse[c])) {
                        this.classifiers[c].setParamsFromParamId(this.derTrain, paramId);
                    }
                    this.classifiers[c].buildClassifier(this.derTrain);
                } else {
                    if (!ElasticEnsemble.isFixedParam(this.classifiersToUse[c])) {
                        this.classifiers[c].setParamsFromParamId(train, paramId);
                    }
                    this.classifiers[c].buildClassifier(train);
                }
                this.cvAccs[c] = cvAcc;
            }
        } else {
            for (int c = 0; c < this.classifiers.length; ++c) {
                if (this.writeToFile) {
                    this.classifiers[c].setFileWritingOn(this.resultsDir, this.datasetName, this.resampleId);
                }
                double[] cvAccAndPreds = ElasticEnsemble.isDerivative(this.classifiersToUse[c]) ? this.classifiers[c].loocv(this.derTrain) : this.classifiers[c].loocv(train);
                this.cvAccs[c] = cvAccAndPreds[0];
                for (int i = 0; i < train.numInstances(); ++i) {
                    this.cvPreds[c][i] = cvAccAndPreds[i + 1];
                }
            }
            if (this.writeEnsembleTrainingFile) {
                StringBuilder output = new StringBuilder();
                double[] ensembleCvPreds = this.getEnsembleCvPreds();
                output.append(train.relationName()).append(",EE,train\n");
                output.append(this.getParameters()).append("\n");
                output.append(this.getEnsembleCvAcc()).append("\n");
                for (int i = 0; i < train.numInstances(); ++i) {
                    output.append(train.instance(i).classValue()).append(",").append(ensembleCvPreds[i]).append("\n");
                }
                FileWriter fullTrain = new FileWriter(this.ensembleTrainFilePathAndName);
                fullTrain.append(output);
                fullTrain.close();
            }
        }
        this.trainResults.buildTime = System.currentTimeMillis() - this.trainResults.buildTime;
    }

    public static Efficient1NN getClassifier(ConstituentClassifiers classifier) throws Exception {
        Efficient1NN knn = null;
        switch (classifier) {
            case Euclidean_1NN: {
                return new ED1NN();
            }
            case DTW_R1_1NN: {
                return new DTW1NN(1.0);
            }
            case DDTW_R1_1NN: {
                knn = new DTW1NN(1.0);
                knn.setClassifierIdentifier(classifier.toString());
                return knn;
            }
            case DTW_Rn_1NN: {
                return new DTW1NN();
            }
            case DDTW_Rn_1NN: {
                knn = new DTW1NN();
                knn.setClassifierIdentifier(classifier.toString());
                return knn;
            }
            case WDTW_1NN: {
                return new WDTW1NN();
            }
            case WDDTW_1NN: {
                knn = new WDTW1NN();
                knn.setClassifierIdentifier(classifier.toString());
                return knn;
            }
            case LCSS_1NN: {
                return new LCSS1NN();
            }
            case ERP_1NN: {
                return new ERP1NN();
            }
            case MSM_1NN: {
                return new MSM1NN();
            }
            case TWE_1NN: {
                return new TWE1NN();
            }
        }
        throw new Exception("Unsupported classifier type");
    }

    @Override
    public double classifyInstance(Instance instance) throws Exception {
        if (this.classifiers == null) {
            throw new Exception("Error: classifier not built");
        }
        Instance derIns = null;
        if (this.usesDer) {
            Instances temp = new Instances(this.derTrain, 1);
            temp.add(instance);
            temp = df.process(temp);
            derIns = temp.instance(0);
        }
        double bsfVote = -1.0;
        double[] classTotals = new double[this.train.numClasses()];
        ArrayList<Double> bsfClassVal = null;
        this.previousPredictions = new double[this.classifiers.length];
        for (int c = 0; c < this.classifiers.length; ++c) {
            double pred = ElasticEnsemble.isDerivative(this.classifiersToUse[c]) ? this.classifiers[c].classifyInstance(derIns) : this.classifiers[c].classifyInstance(instance);
            this.previousPredictions[c] = pred;
            try {
                int n = (int)pred;
                classTotals[n] = classTotals[n] + this.cvAccs[c];
            }
            catch (Exception e) {
                System.out.println("cv accs " + this.cvAccs.length);
                System.out.println(pred);
                throw e;
            }
            if (classTotals[(int)pred] > bsfVote) {
                bsfClassVal = new ArrayList<Double>();
                bsfClassVal.add(pred);
                bsfVote = classTotals[(int)pred];
                continue;
            }
            if (classTotals[(int)pred] != bsfVote) continue;
            bsfClassVal.add(pred);
        }
        if (bsfClassVal.size() > 1) {
            return (Double)bsfClassVal.get(new Random(46L).nextInt(bsfClassVal.size()));
        }
        return (Double)bsfClassVal.get(0);
    }

    public double[] classifyInstanceByConstituents(Instance instance) throws Exception {
        Instance ins = instance;
        double[] predsByClassifier = new double[this.classifiers.length];
        for (int i = 0; i < this.classifiers.length; ++i) {
            predsByClassifier[i] = this.classifiers[i].classifyInstance(ins);
        }
        return predsByClassifier;
    }

    public double[] getPreviousPredictions() throws Exception {
        if (this.previousPredictions == null) {
            throw new Exception("Error: no previous instance found");
        }
        return this.previousPredictions;
    }

    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        int c;
        if (this.classifiers == null) {
            throw new Exception("Error: classifier not built");
        }
        Instance derIns = null;
        if (this.usesDer) {
            Instances temp = new Instances(this.derTrain, 1);
            temp.add(instance);
            temp = df.process(temp);
            derIns = temp.instance(0);
        }
        double[] classTotals = new double[this.train.numClasses()];
        double cvSum = 0.0;
        for (c = 0; c < this.classifiers.length; ++c) {
            double pred = ElasticEnsemble.isDerivative(this.classifiersToUse[c]) ? this.classifiers[c].classifyInstance(derIns) : this.classifiers[c].classifyInstance(instance);
            try {
                int n = (int)pred;
                classTotals[n] = classTotals[n] + this.cvAccs[c];
            }
            catch (Exception e) {
                System.out.println("cv accs " + this.cvAccs.length);
                System.out.println(pred);
                throw e;
            }
            cvSum += this.cvAccs[c];
        }
        c = 0;
        while (c < classTotals.length) {
            int n = c++;
            classTotals[n] = classTotals[n] / cvSum;
        }
        return classTotals;
    }

    public double[] getCVAccs() throws Exception {
        if (this.cvAccs == null) {
            throw new Exception("Error: classifier not built yet");
        }
        return this.cvAccs;
    }

    private String getClassifierInfo() {
        StringBuilder st = new StringBuilder();
        st.append("EE using:\n");
        st.append("=====================\n");
        for (int c = 0; c < this.classifiers.length; ++c) {
            st.append((Object)this.classifiersToUse[c]).append(" ").append(this.classifiers[c].getClassifierIdentifier()).append(" ").append(this.cvAccs[c]).append("\n");
        }
        return st.toString();
    }

    @Override
    public String getParameters() {
        StringBuilder params = new StringBuilder();
        params.append(super.getParameters()).append(",");
        for (int c = 0; c < this.classifiers.length; ++c) {
            params.append(this.classifiers[c].getClassifierIdentifier()).append(",").append(this.classifiers[c].getParamInformationString()).append(",");
        }
        return params.toString();
    }

    public String toString() {
        return super.toString() + "\n" + this.getClassifierInfo();
    }

    public static void exampleUsage(String datasetName, int resampeId, String outputResultsDirName) throws Exception {
        System.out.println("to do");
    }

    public static void main(String[] args) throws Exception {
        ElasticEnsemble ee = new ElasticEnsemble();
        Instances train = ClassifierTools.loadData("C:/users/sjx07ngu/dropbox/tsc problems/ItalyPowerDemand/ItalyPowerDemand_TRAIN");
        Instances test = ClassifierTools.loadData("C:/users/sjx07ngu/dropbox/tsc problems/ItalyPowerDemand/ItalyPowerDemand_TEST");
        ee.buildClassifier(train);
        int correct = 0;
        for (int i = 0; i < test.numInstances(); ++i) {
            if (test.instance(i).classValue() != ee.classifyInstance(test.instance(i))) continue;
            ++correct;
        }
        System.out.println("correct: " + correct + "/" + test.numInstances());
        System.out.println((double)correct / (double)test.numInstances());
        System.out.println(ee.getEnsembleCvAcc());
    }

    public static enum ConstituentClassifiers {
        Euclidean_1NN,
        DTW_R1_1NN,
        DTW_Rn_1NN,
        WDTW_1NN,
        DDTW_R1_1NN,
        DDTW_Rn_1NN,
        WDDTW_1NN,
        LCSS_1NN,
        MSM_1NN,
        TWE_1NN,
        ERP_1NN;

    }
}

