/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers.cote;

import development.DataSets;
import java.util.ArrayList;
import timeseriesweka.classifiers.cote.AbstractPostProcessedCote;

/**
 * HIVE-COTE ensemble formed by post-processing the stored outputs of its constituent classifiers.
 *
 * <p>Each constituent {@code c} contributes with weight {@code cvAccs[c]^alpha}. In probability
 * mode (the default) the weighted class distributions in {@code testDists} are summed. In voting
 * mode each constituent casts one weighted vote for its highest-scoring class. Either way, the
 * combined vector is divided by the total weight before being returned.
 */
public class HiveCotePostProcessed extends AbstractPostProcessedCote {
    /** Exponent applied to each constituent's {@code cvAccs} entry to form its weight. */
    private double alpha = 1.0;
    /** When {@code true}, combine by weighted voting; otherwise by weighted probabilities. */
    private boolean useVoting = false;

    /**
     * Creates the ensemble over an explicit list of constituents for one resample.
     *
     * @param resultsDir directory the constituents' results are read from (presumably by the
     *     superclass when results are loaded — confirm against {@code AbstractPostProcessedCote})
     * @param datasetName name of the dataset whose results are combined
     * @param resampleId resample (fold) index
     * @param classifierNames names of the constituent classifiers to combine
     */
    public HiveCotePostProcessed(String resultsDir, String datasetName, int resampleId, ArrayList<String> classifierNames) {
        CLASSIFIER_NAME = "HIVE-COTE";
        this.resultsDir = resultsDir;
        this.datasetName = datasetName;
        this.resampleId = resampleId;
        this.classifierNames = classifierNames;
    }

    /**
     * Creates the ensemble over an explicit list of constituents for resample 0.
     *
     * @param resultsDir directory the constituents' results are read from
     * @param datasetName name of the dataset whose results are combined
     * @param classifierNames names of the constituent classifiers to combine
     */
    public HiveCotePostProcessed(String resultsDir, String datasetName, ArrayList<String> classifierNames) {
        this(resultsDir, datasetName, 0, classifierNames);
    }

    /**
     * Creates the ensemble over the default constituents (EE, ST, RISE, BOSS, TSF).
     *
     * @param resultsDir directory the constituents' results are read from
     * @param datasetName name of the dataset whose results are combined
     * @param resampleId resample (fold) index
     */
    public HiveCotePostProcessed(String resultsDir, String datasetName, int resampleId) {
        this(resultsDir, datasetName, resampleId, getDefaultClassifierNames());
    }

    /**
     * Creates the ensemble over the default constituents (EE, ST, RISE, BOSS, TSF) for resample 0.
     *
     * @param resultsDir directory the constituents' results are read from
     * @param datasetName name of the dataset whose results are combined
     */
    public HiveCotePostProcessed(String resultsDir, String datasetName) {
        this(resultsDir, datasetName, 0, getDefaultClassifierNames());
    }

    /**
     * Sets the exponent applied to each constituent's {@code cvAccs} entry when weighting it.
     *
     * @param alpha weighting exponent; the field default is 1.0
     */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /** Switches to weighted voting: each constituent votes for its highest-scoring class. */
    public void useVotes() {
        this.useVoting = true;
    }

    /** Switches to weighted probabilities: constituents' class distributions are summed (default). */
    public void useProbs() {
        this.useVoting = false;
    }

    /**
     * Returns the default constituent names.
     *
     * <p>This method is static so the chained constructors can call it inside {@code this(...)}.
     *
     * @return a new mutable list {@code [EE, ST, RISE, BOSS, TSF]}
     */
    private static ArrayList<String> getDefaultClassifierNames() {
        ArrayList<String> names = new ArrayList<String>();
        names.add("EE");
        names.add("ST");
        names.add("RISE");
        names.add("BOSS");
        names.add("TSF");
        return names;
    }

    /**
     * Combines the constituents' outputs for one test instance, using the currently selected mode.
     *
     * @param testInstanceId index of the test instance in {@code testDists}
     * @return the normalised combined class distribution
     * @throws Exception if results have not been loaded
     */
    @Override
    public double[] distributionForInstance(int testInstanceId) throws Exception {
        if (this.useVoting) {
            return this.distributionForInstanceWithVoting(testInstanceId);
        }
        return this.distributionForInstanceWithProbs(testInstanceId);
    }

    /**
     * Weighted-probability combination. The output is
     * {@code sum_c w_c * testDists[c][i]} divided by {@code sum_c w_c},
     * where {@code w_c = cvAccs[c]^alpha}.
     *
     * <p>If every weight is zero, the division is 0/0 and every entry is NaN.
     *
     * @param testInstanceId index of the test instance in {@code testDists}
     * @return the normalised combined class distribution
     * @throws Exception if results have not been loaded
     */
    public double[] distributionForInstanceWithProbs(int testInstanceId) throws Exception {
        this.checkInitialised();
        int numClasses = this.testDists[0][0].length;
        double[] outDist = new double[numClasses];
        double weightSum = 0.0;
        for (int classifier = 0; classifier < this.testDists.length; ++classifier) {
            // The weight is the same for every class of this constituent, so compute it once.
            double weight = Math.pow(this.cvAccs[classifier], this.alpha);
            double[] dist = this.testDists[classifier][testInstanceId];
            for (int classVal = 0; classVal < numClasses; ++classVal) {
                outDist[classVal] += dist[classVal] * weight;
            }
            weightSum += weight;
        }
        normalise(outDist, weightSum);
        return outDist;
    }

    /**
     * Weighted-voting combination. Each constituent adds its weight {@code cvAccs[c]^alpha} to
     * the class with its highest score; ties go to the lowest class index. The votes are then
     * divided by the total weight.
     *
     * @param testInstanceId index of the test instance in {@code testDists}
     * @return the normalised vote distribution
     * @throws Exception if results have not been loaded
     * @throws IllegalStateException if a constituent's row has no comparable score (e.g. all NaN)
     */
    public double[] distributionForInstanceWithVoting(int testInstanceId) throws Exception {
        this.checkInitialised();
        int numClasses = this.testDists[0][0].length;
        double[] outDist = new double[numClasses];
        double weightSum = 0.0;
        for (int classifier = 0; classifier < this.testDists.length; ++classifier) {
            int predicted = argMax(this.testDists[classifier][testInstanceId], numClasses);
            if (predicted < 0) {
                // Previously this surfaced as an opaque ArrayIndexOutOfBoundsException on outDist[-1].
                throw new IllegalStateException("Error: constituent " + classifier
                        + " has no valid class prediction for test instance " + testInstanceId + ".");
            }
            double weight = Math.pow(this.cvAccs[classifier], this.alpha);
            outDist[predicted] += weight;
            weightSum += weight;
        }
        normalise(outDist, weightSum);
        return outDist;
    }

    /** Throws the standard "not initialised" error when there are no test results to combine. */
    private void checkInitialised() throws Exception {
        if (this.testDists == null || this.testDists.length == 0) {
            throw new Exception("Error: classifier not initialised correctly. Load results before classifiying.");
        }
    }

    /**
     * Returns the first index of the maximum among the first {@code numClasses} entries.
     * Returns -1 if no entry compares greater than negative infinity (e.g. all NaN).
     */
    private static int argMax(double[] dist, int numClasses) {
        int maxId = -1;
        double bestSoFar = Double.NEGATIVE_INFINITY;
        for (int classVal = 0; classVal < numClasses; ++classVal) {
            // Strict '>' keeps the first index on ties.
            if (dist[classVal] > bestSoFar) {
                maxId = classVal;
                bestSoFar = dist[classVal];
            }
        }
        return maxId;
    }

    /** Divides every entry of {@code dist} in place by {@code weightSum}. */
    private static void normalise(double[] dist, double weightSum) {
        for (int classVal = 0; classVal < dist.length; ++classVal) {
            dist[classVal] /= weightSum;
        }
    }

    /**
     * Batch driver. Writes voting-mode HIVE-COTE test sheets for every dataset in
     * {@code DataSets.fileNames}, for resamples 0-99, at alpha 1.0 and 4.0.
     *
     * <p>This is a best-effort run: a failing dataset/resample is reported on stderr together with
     * its cause, and the loop continues.
     */
    public static void main(String[] args) throws Exception {
        double[] alphas = new double[]{1.0, 4.0};
        ArrayList<String> classifiersToUse = new ArrayList<String>();
        classifiersToUse.add("EE_proto");
        classifiersToUse.add("ST_HiveProto");
        classifiersToUse.add("RISE");
        classifiersToUse.add("BOSS");
        classifiersToUse.add("TSF");
        System.out.println("votes");
        for (double alpha : alphas) {
            for (String datasetName : DataSets.fileNames) {
                System.out.println(datasetName + " " + alpha);
                for (int resample = 0; resample < 100; ++resample) {
                    try {
                        HiveCotePostProcessed hcpp = new HiveCotePostProcessed("C:/3xsshare/Jay/LocalWork/coteConstituentResults/", datasetName, resample, classifiersToUse);
                        hcpp.setAlpha(alpha);
                        hcpp.useVotes();
                        hcpp.writeTestSheet("hiveWritingProtoRewrite_alpha" + alpha + "_votes/");
                    } catch (Exception e) {
                        // Report which job failed and why, rather than silently dropping the cause.
                        System.err.println(datasetName + "_" + resample + "_" + alpha + ": " + e);
                    }
                }
            }
        }
    }
}

