/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers;

import java.io.File;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import java.util.Scanner;
import timeseriesweka.classifiers.AbstractClassifierWithTrainingData;
import timeseriesweka.classifiers.BOSS;
import timeseriesweka.classifiers.ElasticEnsemble;
import timeseriesweka.classifiers.RISE;
import timeseriesweka.classifiers.TSF;
import timeseriesweka.classifiers.cote.HiveCoteModule;
import timeseriesweka.filters.shapelet_transforms.ShapeletTransform;
import timeseriesweka.filters.shapelet_transforms.ShapeletTransformTimingUtilities;
import utilities.ClassifierTools;
import vector_classifiers.HESCA;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

public class HiveCote
extends AbstractClassifierWithTrainingData {
    private ArrayList<Classifier> classifiers;
    private ArrayList<String> names;
    private ConstituentHiveEnsemble[] modules;
    private boolean verbose = false;
    private int maxCvFolds = 10;
    private boolean fileWriting = false;
    private String fileOutputDir;
    private String fileOutputDataset;
    private String fileOutputResampleId;

    public HiveCote() {
        this.setDefaultEnsembles();
    }

    public HiveCote(ArrayList<Classifier> classifiers, ArrayList<String> classifierNames) {
        this.classifiers = classifiers;
        this.names = classifierNames;
    }

    private void setDefaultEnsembles() {
        this.classifiers = new ArrayList();
        this.names = new ArrayList();
        this.classifiers.add(new ElasticEnsemble());
        HESCA h = new HESCA();
        h.setTransform(new DefaultShapeletTransformPlaceholder());
        this.classifiers.add(h);
        RISE rise = new RISE();
        rise.setTransformType(RISE.Filter.PS_ACF);
        this.classifiers.add(rise);
        this.classifiers.add(new BOSS());
        this.classifiers.add(new TSF());
        this.names.add("EE");
        this.names.add("ST");
        this.names.add("RISE");
        this.names.add("BOSS");
        this.names.add("TSF");
    }

    public void turnOnFileWriting(String outputDir, String datasetName) {
        this.turnOnFileWriting(outputDir, datasetName, "0");
    }

    public void turnOnFileWriting(String outputDir, String datasetName, String resampleFoldIdentifier) {
        this.fileWriting = true;
        this.fileOutputDir = outputDir;
        this.fileOutputDataset = datasetName;
        this.fileOutputResampleId = resampleFoldIdentifier;
    }

    @Override
    public void buildClassifier(Instances train) throws Exception {
        this.trainResults.buildTime = System.currentTimeMillis();
        this.optionalOutputLine("Start of training");
        this.modules = new ConstituentHiveEnsemble[this.classifiers.size()];
        System.out.println("modules include:");
        for (int i = 0; i < this.classifiers.size(); ++i) {
            System.out.println(this.names.get(i));
        }
        for (int i = 0; i < this.classifiers.size(); ++i) {
            if (this.classifiers.get(i) instanceof HESCA && ((HESCA)this.classifiers.get(i)).getTransform() instanceof DefaultShapeletTransformPlaceholder) {
                this.classifiers.remove(i);
                ShapeletTransform shoutyThing = ShapeletTransformTimingUtilities.createTransformWithTimeLimit(train, 24.0);
                shoutyThing.supressOutput();
                HESCA h = new HESCA();
                h.setTransform(shoutyThing);
                this.classifiers.add(i, h);
            }
            if (this.classifiers.get(i) instanceof HiveCoteModule) {
                this.optionalOutputLine("training (group a): " + this.names.get(i));
                this.classifiers.get(i).buildClassifier(train);
                this.modules[i] = new ConstituentHiveEnsemble(this.names.get(i), this.classifiers.get(i), ((HiveCoteModule)((Object)this.classifiers.get(i))).getEnsembleCvAcc());
                if (this.fileWriting) {
                    String outputFilePathAndName = this.fileOutputDir + this.names.get(i) + "/Predictions/" + this.fileOutputDataset + "/trainFold" + this.fileOutputResampleId + ".csv";
                    HiveCote.genericCvResultsFileWriter(outputFilePathAndName, train, ((HiveCoteModule)((Object)this.modules[i].classifier)).getEnsembleCvPreds(), this.fileOutputDataset, this.modules[i].classifierName, ((HiveCoteModule)((Object)this.modules[i].classifier)).getParameters(), this.modules[i].ensembleCvAcc);
                }
            } else {
                this.optionalOutputLine("crossval (group b): " + this.names.get(i));
                double ensembleAcc = this.crossValidateWithFileWriting(this.classifiers.get(i), train, this.maxCvFolds, this.names.get(i));
                this.optionalOutputLine("training (group b): " + this.names.get(i));
                this.classifiers.get(i).buildClassifier(train);
                this.modules[i] = new ConstituentHiveEnsemble(this.names.get(i), this.classifiers.get(i), ensembleAcc);
            }
            this.optionalOutputLine("done " + this.modules[i].classifierName);
        }
        if (this.verbose) {
            this.printModuleCvAccs();
        }
        this.trainResults.buildTime = System.currentTimeMillis() - this.trainResults.buildTime;
    }

    private static void genericCvResultsFileWriter(String outFilePathAndName, Instances instances, String classifierName, double[] preds, double cvAcc) throws Exception {
        HiveCote.genericCvResultsFileWriter(outFilePathAndName, instances, preds, instances.relationName(), classifierName, "noParamInfo", cvAcc);
    }

    private static void genericCvResultsFileWriter(String outFilePathAndName, Instances instances, double[] preds, String datasetName, String classifierName, String paramInfo, double cvAcc) throws Exception {
        if (instances.numInstances() != preds.length) {
            throw new Exception("Error: num instances doesn't match num preds.");
        }
        File outPath = new File(outFilePathAndName);
        outPath.getParentFile().mkdirs();
        FileWriter out = new FileWriter(outFilePathAndName);
        out.append(datasetName + "," + classifierName + ",train\n");
        out.append(paramInfo + "\n");
        out.append(cvAcc + "\n");
        for (int i = 0; i < instances.numInstances(); ++i) {
            out.append(instances.instance(i).classValue() + "," + preds[i] + "\n");
        }
        out.close();
    }

    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        return this.distributionForInstance(instance, null);
    }

    private double[] distributionForInstance(Instance instance, StringBuilder[] outputFileBuilders) throws Exception {
        double bsfClassWeight;
        double bsfClassVal;
        StringBuilder moduleString;
        if (outputFileBuilders != null && outputFileBuilders.length != this.modules.length + 1) {
            throw new Exception("Error: to write test files, there must be m+1 output StringBuilders (where m is the number of modules)");
        }
        double[] hiveDists = new double[instance.numClasses()];
        double cvAccSum = 0.0;
        for (int m = 0; m < this.modules.length; ++m) {
            double[] moduleDists = this.modules[m].classifier.distributionForInstance(instance);
            moduleString = new StringBuilder();
            double moduleWeight = this.modules[m].ensembleCvAcc;
            bsfClassVal = -1.0;
            bsfClassWeight = -1.0;
            for (int c = 0; c < hiveDists.length; ++c) {
                int n = c;
                hiveDists[n] = hiveDists[n] + moduleDists[c] * moduleWeight;
                if (outputFileBuilders == null) continue;
                if (moduleDists[c] > bsfClassWeight) {
                    bsfClassWeight = moduleDists[c];
                    bsfClassVal = c;
                }
                moduleString.append(",").append(moduleDists[c]);
            }
            if (outputFileBuilders != null) {
                outputFileBuilders[m].append(instance.classValue()).append(",").append(bsfClassVal).append(",").append(moduleString.toString() + "\n");
            }
            cvAccSum += this.modules[m].ensembleCvAcc;
        }
        int h = 0;
        while (h < hiveDists.length) {
            int n = h++;
            hiveDists[n] = hiveDists[n] / cvAccSum;
        }
        if (outputFileBuilders != null) {
            bsfClassVal = -1.0;
            bsfClassWeight = -1.0;
            moduleString = new StringBuilder();
            for (int c = 0; c < hiveDists.length; ++c) {
                if (hiveDists[c] > bsfClassWeight) {
                    bsfClassWeight = hiveDists[c];
                    bsfClassVal = c;
                }
                moduleString.append(",").append(hiveDists[c]);
            }
            outputFileBuilders[outputFileBuilders.length - 1].append(instance.classValue()).append(",").append(bsfClassVal).append(",").append(moduleString.toString() + "\n");
        }
        return hiveDists;
    }

    public double[] classifyInstanceByEnsemble(Instance instance) throws Exception {
        double[] output = new double[this.modules.length];
        for (int m = 0; m < this.modules.length; ++m) {
            output[m] = this.modules[m].classifier.classifyInstance(instance);
        }
        return output;
    }

    public void printModuleCvAccs() throws Exception {
        if (this.modules == null) {
            throw new Exception("Error: modules don't exist. Train classifier first.");
        }
        System.out.println("CV accs by module:");
        System.out.println("------------------");
        StringBuilder line1 = new StringBuilder();
        StringBuilder line2 = new StringBuilder();
        for (ConstituentHiveEnsemble module : this.modules) {
            line1.append(module.classifierName).append(",");
            line2.append(module.ensembleCvAcc).append(",");
        }
        System.out.println(line1);
        System.out.println(line2);
        System.out.println();
    }

    public void makeShouty() {
        this.verbose = true;
    }

    private void optionalOutputLine(String message) {
        if (this.verbose) {
            System.out.println(message);
        }
    }

    public void setMaxCvFolds(int maxFolds) {
        this.maxCvFolds = maxFolds;
    }

    public void writeTestPredictionsToFile(Instances test, String outputDir, String datasetName) throws Exception {
        this.writeTestPredictionsToFile(test, outputDir, datasetName, "0");
    }

    public void writeTestPredictionsToFile(Instances test, String outputDir, String datasetName, String datasetResampleIdentifier) throws Exception {
        String[] lineParts;
        Scanner scan;
        int correct;
        File dir;
        int m;
        this.fileOutputDir = outputDir;
        this.fileOutputDataset = datasetName;
        this.fileOutputResampleId = datasetResampleIdentifier;
        StringBuilder[] outputs = new StringBuilder[this.modules.length + 1];
        for (int m2 = 0; m2 < outputs.length; ++m2) {
            outputs[m2] = new StringBuilder();
        }
        for (int i = 0; i < test.numInstances(); ++i) {
            this.distributionForInstance(test.instance(i), outputs);
        }
        for (m = 0; m < this.modules.length; ++m) {
            dir = new File(this.fileOutputDir + this.modules[m].classifierName + "/Predictions/" + this.fileOutputDataset + "/");
            if (!dir.exists()) {
                dir.mkdirs();
            }
            correct = 0;
            scan = new Scanner(outputs[m].toString());
            scan.useDelimiter("\n");
            while (scan.hasNext()) {
                lineParts = scan.next().split(",");
                if (!lineParts[0].trim().equalsIgnoreCase(lineParts[1].trim())) continue;
                ++correct;
            }
            scan.close();
            FileWriter out = new FileWriter(this.fileOutputDir + this.modules[m].classifierName + "/Predictions/" + this.fileOutputDataset + "/testFold" + this.fileOutputResampleId + ".csv");
            out.append(this.fileOutputDataset + "," + this.modules[m].classifierName + ",test\n");
            out.append("builtInHive\n");
            out.append((double)correct / (double)test.numInstances() + "\n");
            out.append(outputs[m]);
            out.close();
        }
        correct = 0;
        scan = new Scanner(outputs[outputs.length - 1].toString());
        scan.useDelimiter("\n");
        while (scan.hasNext()) {
            lineParts = scan.next().split(",");
            if (!lineParts[0].trim().equalsIgnoreCase(lineParts[1].trim())) continue;
            ++correct;
        }
        scan.close();
        dir = new File(this.fileOutputDir + "HIVE-COTE/Predictions/" + this.fileOutputDataset + "/");
        if (!dir.exists()) {
            dir.mkdirs();
        }
        FileWriter out = new FileWriter(this.fileOutputDir + "HIVE-COTE/Predictions/" + this.fileOutputDataset + "/testFold" + this.fileOutputResampleId + ".csv");
        out.append(this.fileOutputDataset + ",HIVE-COTE,test\nconstituentCvAccs,");
        for (m = 0; m < this.modules.length; ++m) {
            out.append(this.modules[m].classifierName + "," + this.modules[m].ensembleCvAcc + ",");
        }
        out.append("\n" + (double)correct / (double)test.numInstances() + "\n");
        out.append("\n" + outputs[outputs.length - 1]);
        out.close();
    }

    public double crossValidate(Classifier classifier, Instances train, int maxFolds) throws Exception {
        return this.crossValidateWithFileWriting(classifier, train, maxFolds, null);
    }

    public double crossValidateWithFileWriting(Classifier classifier, Instances train, int maxFolds, String classifierName) throws Exception {
        int numFolds = maxFolds;
        if (numFolds <= 1 || numFolds > train.numInstances()) {
            numFolds = train.numInstances();
        }
        Random r = new Random();
        ArrayList<Instances> folds = new ArrayList<Instances>();
        ArrayList foldIndexing = new ArrayList();
        for (int i = 0; i < numFolds; ++i) {
            folds.add(new Instances(train, 0));
            foldIndexing.add(new ArrayList());
        }
        ArrayList<Integer> instanceIds = new ArrayList<Integer>();
        for (int i = 0; i < train.numInstances(); ++i) {
            instanceIds.add(i);
        }
        Collections.shuffle(instanceIds, r);
        ArrayList<Instances> byClass = new ArrayList<Instances>();
        ArrayList byClassIndices = new ArrayList();
        for (int i = 0; i < train.numClasses(); ++i) {
            byClass.add(new Instances(train, 0));
            byClassIndices.add(new ArrayList());
        }
        for (int i = 0; i < train.numInstances(); ++i) {
            int thisInstanceId = (Integer)instanceIds.get(i);
            double thisClassVal = train.instance(thisInstanceId).classValue();
            ((Instances)byClass.get((int)thisClassVal)).add(train.instance(thisInstanceId));
            ((ArrayList)byClassIndices.get((int)thisClassVal)).add(thisInstanceId);
        }
        Instances strat = new Instances(train, 0);
        ArrayList stratIndices = new ArrayList();
        int stratCount = 0;
        int[] classCounters = new int[train.numClasses()];
        while (stratCount < train.numInstances()) {
            for (int c = 0; c < train.numClasses(); ++c) {
                if (classCounters[c] >= ((Instances)byClass.get(c)).size()) continue;
                strat.add(((Instances)byClass.get(c)).instance(classCounters[c]));
                stratIndices.add(((ArrayList)byClassIndices.get(c)).get(classCounters[c]));
                int n = c;
                classCounters[n] = classCounters[n] + 1;
                ++stratCount;
            }
        }
        train = strat;
        instanceIds = stratIndices;
        double foldSize = (double)train.numInstances() / (double)numFolds;
        double thisSum = 0.0;
        double lastSum = 0.0;
        int foldSum = 0;
        int currentStart = 0;
        for (int f = 0; f < numFolds; ++f) {
            thisSum = lastSum + foldSize + 1.0E-12;
            int floor = (int)thisSum;
            if (f == numFolds - 1) {
                floor = train.numInstances();
            }
            for (int i = currentStart; i < floor; ++i) {
                ((Instances)folds.get(f)).add(train.instance(i));
                ((ArrayList)foldIndexing.get(f)).add(instanceIds.get(i));
            }
            foldSum += floor - currentStart;
            currentStart = floor;
            lastSum = thisSum;
        }
        if (foldSum != train.numInstances()) {
            throw new Exception("Error! Some instances got lost file creating folds (maybe a double precision bug). Training instances contains " + train.numInstances() + ", but the sum of the training folds is " + foldSum);
        }
        double[] predictions = new double[train.numInstances()];
        int correct = 0;
        for (int testFold = 0; testFold < numFolds; ++testFold) {
            Instances trainLoocv = null;
            Instances testLoocv = new Instances((Instances)folds.get(testFold));
            for (int f = 0; f < numFolds; ++f) {
                if (f == testFold) continue;
                Instances temp = new Instances((Instances)folds.get(f));
                if (trainLoocv == null) {
                    trainLoocv = temp;
                    continue;
                }
                trainLoocv.addAll(temp);
            }
            classifier.buildClassifier(trainLoocv);
            for (int i = 0; i < testLoocv.numInstances(); ++i) {
                double pred = classifier.classifyInstance(testLoocv.instance(i));
                double actual = testLoocv.instance(i).classValue();
                predictions[((Integer)((ArrayList)foldIndexing.get((int)testFold)).get((int)i)).intValue()] = pred;
                if (pred != actual) continue;
                ++correct;
            }
        }
        double cvAcc = (double)correct / (double)train.numInstances();
        if (this.fileWriting) {
            String outputFilePathAndName = this.fileOutputDir + classifierName + "/Predictions/" + this.fileOutputDataset + "/trainFold" + this.fileOutputResampleId + ".csv";
            HiveCote.genericCvResultsFileWriter(outputFilePathAndName, train, predictions, this.fileOutputDataset, classifierName, "genericInternalCv,numFolds," + numFolds, cvAcc);
        }
        return cvAcc;
    }

    public static void main(String[] args) throws Exception {
        String datasetName = "ItalyPowerDemand";
        Instances train = ClassifierTools.loadData("C:/users/sjx07ngu/dropbox/tsc problems/" + datasetName + "/" + datasetName + "_TRAIN");
        Instances test = ClassifierTools.loadData("C:/users/sjx07ngu/dropbox/tsc problems/" + datasetName + "/" + datasetName + "_TEST");
        HiveCote hive = new HiveCote();
        hive.makeShouty();
        hive.buildClassifier(train);
        hive.writeTestPredictionsToFile(test, "prototypeSheets/", datasetName, "0");
        int correct = 0;
        int[] correctByEnsemble = new int[hive.modules.length];
        for (int i = 0; i < test.numInstances(); ++i) {
            if (hive.classifyInstance(test.instance(i)) == test.instance(i).classValue()) {
                ++correct;
            }
            double[] predByEnsemble = hive.classifyInstanceByEnsemble(test.instance(i));
            for (int m = 0; m < predByEnsemble.length; ++m) {
                if (predByEnsemble[m] != test.instance(i).classValue()) continue;
                int n = m;
                correctByEnsemble[n] = correctByEnsemble[n] + 1;
            }
        }
        System.out.println("Overall Acc: " + (double)correct / (double)test.numInstances());
        System.out.println("Acc by Module:");
        StringBuilder line1 = new StringBuilder();
        StringBuilder line2 = new StringBuilder();
        for (int m = 0; m < hive.modules.length; ++m) {
            line1.append(hive.modules[m].classifierName).append(",");
            line2.append((double)correctByEnsemble[m] / (double)test.numInstances()).append(",");
        }
        System.out.println(line1);
        System.out.println(line2);
    }

    public static class DefaultShapeletTransformPlaceholder
    extends ShapeletTransform {
    }

    private class ConstituentHiveEnsemble {
        public final Classifier classifier;
        public final double ensembleCvAcc;
        public final String classifierName;

        public ConstituentHiveEnsemble(String classifierName, Classifier classifier, double ensembleCvAcc) {
            this.classifierName = classifierName;
            this.classifier = classifier;
            this.ensembleCvAcc = ensembleCvAcc;
        }
    }
}

