/*
 * Decompiled with CFR 0.152.
 */
package vector_classifiers;

import development.CollateResults;
import development.DataSets;
import development.Experiments;
import development.MultipleClassifierEvaluation;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Arrays;
import timeseriesweka.classifiers.cote.HiveCoteModule;
import timeseriesweka.classifiers.ensembles.EnsembleModule;
import timeseriesweka.classifiers.ensembles.voting.MajorityConfidence;
import timeseriesweka.classifiers.ensembles.voting.MajorityVote;
import timeseriesweka.classifiers.ensembles.voting.ModuleVotingScheme;
import timeseriesweka.classifiers.ensembles.weightings.ModuleWeightingScheme;
import timeseriesweka.classifiers.ensembles.weightings.TrainAcc;
import timeseriesweka.classifiers.ensembles.weightings.TrainAccByClass;
import timeseriesweka.filters.SAX;
import utilities.ClassifierResults;
import utilities.ClassifierTools;
import utilities.CrossValidator;
import utilities.DebugPrinting;
import utilities.ErrorReport;
import utilities.GenericTools;
import utilities.InstanceTools;
import utilities.SaveParameterInfo;
import utilities.StatisticalUtilities;
import utilities.TrainAccuracyEstimate;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.functions.Logistic;
import weka.classifiers.functions.MultilayerPerceptron;
import weka.classifiers.functions.SMO;
import weka.classifiers.functions.supportVector.PolyKernel;
import weka.classifiers.lazy.kNN;
import weka.classifiers.meta.RotationForest;
import weka.classifiers.trees.J48;
import weka.classifiers.trees.RandomForest;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.SimpleBatchFilter;

public class CAWPE
extends AbstractClassifier
implements HiveCoteModule,
SaveParameterInfo,
DebugPrinting,
TrainAccuracyEstimate {
    /** Assigns the modules' weights; buildClassifier calls defineWeightings(modules, numClasses) on it. Default: TrainAcc(4.0). */
    protected ModuleWeightingScheme weightingScheme = new TrainAcc(4.0);
    /** Combines module outputs into the ensemble distribution; trained in buildClassifier, queried in doEnsembleCV/distributionForInstance. */
    protected ModuleVotingScheme votingScheme = new MajorityConfidence();
    /** Ensemble members, one per base classifier, created by setClassifiers. */
    protected EnsembleModule[] modules;
    /** When true, {@code seed} is given to the seeded members built by the settings methods and to the CrossValidator. */
    protected boolean setSeed = true;
    /** Seed used when {@code setSeed} is true. */
    protected int seed = 0;
    /** Optional filter applied to the training data and to every instance before classification; null means none. */
    protected SimpleBatchFilter transform = null;
    /** Copy (or transformed version) of the data passed to buildClassifier. */
    protected Instances train;
    /** Set by writeCVTrainToFile: write the ensemble CV results to {@code outputEnsembleTrainingPathAndFile} after building. */
    protected boolean writeEnsembleTrainingFile = false;
    /** Target file for the ensemble CV results when {@code writeEnsembleTrainingFile} is set. */
    protected String outputEnsembleTrainingPathAndFile;
    /** Whether buildClassifier estimates the ensemble's own train accuracy (see doEnsembleCV). */
    protected boolean performEnsembleCV = true;
    /** Fold generator built in initialiseModules when willNeedToDoCV() is true. */
    protected CrossValidator cv = null;
    /** Ensemble CV results, assigned by buildClassifier when performEnsembleCV is true. */
    protected ClassifierResults ensembleTrainResults = null;
    /** Accumulates the distributions produced by distributionForInstance. */
    protected ClassifierResults ensembleTestResults = null;
    /** Number of instances in {@code train}, recorded in buildClassifier. */
    protected int numTrainInsts;
    /** Number of attributes in {@code train}, recorded in buildClassifier. */
    protected int numAttributes;
    /** Number of classes in {@code train}, recorded in buildClassifier. */
    protected int numClasses;
    /** Count of distinct test instances seen since the last build; also indexes loaded module test results. */
    protected int testInstCounter = 0;
    /** Test-set size read from loaded module test results (set in loadModules); -1 until then. */
    protected int numTestInsts = -1;
    /** Instance object passed to the previous distributionForInstance call; the counter only advances on a different object. */
    protected Instance prevTestInstance = null;
    /** Build modules from existing results files instead of training them (mutually exclusive with writing, via the setters). */
    protected boolean readIndividualsResults = false;
    /** Write each module's results files while training (mutually exclusive with reading, via the setters). */
    protected boolean writeIndividualsResults = false;
    /** Set once setResultsFileLocationParameters has been called. */
    protected boolean resultsFilesParametersInitialised;
    /** One directory shared by all modules, or one per module; concatenated directly with module names (presumably must end in a separator). */
    protected String[] readResultsFilesDirectories = null;
    /** Output directory for results files; defaults to the first read directory if still null when building. */
    protected String writeResultsFilesDirectory = null;
    /** Name of the ensemble's own results files; also prefixed to module names in results-file headers. */
    protected String ensembleIdentifier = "CAWPE";
    /** Resample/fold id used in results file names ({@code trainFold<id>.csv} / {@code testFold<id>.csv}). */
    protected int resampleIdentifier;
    /** Dataset name used in results file paths and headers. */
    protected String datasetName;

    /**
     * Creates an ensemble pre-configured with the default CAWPE members and weighting/voting
     * schemes (see {@link #setDefaultCAWPESettings()}).
     */
    public CAWPE() {
        setDefaultCAWPESettings();
    }

    /**
     * Collects the base classifier of every module.
     *
     * @return a new array with one classifier per module, in module order
     */
    public Classifier[] getClassifiers() {
        final Classifier[] result = new Classifier[modules.length];
        int idx = 0;
        for (EnsembleModule module : modules) {
            result[idx++] = module.getClassifier();
        }
        return result;
    }

    /**
     * Replaces the ensemble members with one module per classifier.
     *
     * <p>Either {@code classifiers} or {@code classifierNames} may be null, but not both. A null
     * classifier array yields one module with a null classifier per name. Null names default to
     * each classifier's simple class name, and null parameters default to empty strings. The
     * arguments are validated before the existing modules are replaced, so a failed call leaves
     * the ensemble unchanged.
     *
     * @param classifiers the base classifiers, or null
     * @param classifierNames module names, or null to derive them from the classifiers
     * @param classifierParameters per-module parameter strings, or null for empty strings
     * @throws IllegalArgumentException if both classifiers and names are null, if names must be
     *     derived from a null classifier, or if the arrays do not all have the same length
     */
    public void setClassifiers(Classifier[] classifiers, String[] classifierNames, String[] classifierParameters) {
        if (classifiers == null && classifierNames == null) {
            throw new IllegalArgumentException("CAWPE.setClassifiers: classifiers and classifierNames cannot both be null");
        }
        if (classifiers == null) {
            // A freshly allocated array is already filled with nulls.
            classifiers = new Classifier[classifierNames.length];
        }
        if (classifierNames == null) {
            classifierNames = new String[classifiers.length];
            for (int i = 0; i < classifiers.length; ++i) {
                if (classifiers[i] == null) {
                    throw new IllegalArgumentException("CAWPE.setClassifiers: cannot derive a name for null classifier at index " + i);
                }
                classifierNames[i] = classifiers[i].getClass().getSimpleName();
            }
        }
        if (classifierParameters == null) {
            classifierParameters = new String[classifiers.length];
            Arrays.fill(classifierParameters, "");
        }
        if (classifierNames.length != classifiers.length || classifierParameters.length != classifiers.length) {
            throw new IllegalArgumentException("CAWPE.setClassifiers: mismatched array lengths, classifiers=" + classifiers.length
                    + ", classifierNames=" + classifierNames.length + ", classifierParameters=" + classifierParameters.length);
        }
        EnsembleModule[] newModules = new EnsembleModule[classifiers.length];
        for (int m = 0; m < newModules.length; ++m) {
            newModules[m] = new EnsembleModule(classifierNames[m], classifiers[m], classifierParameters[m]);
        }
        this.modules = newModules;
    }

    /**
     * Configures the ensemble as the original HESCA: TrainAcc weighting, MajorityVote voting and
     * eight members named NN, NB, C4.5, SVML, SVMQ, RandF, RotF and bayesNet. The two SMOs and the
     * two forests receive {@code seed} only when {@code setSeed} is true.
     */
    public final void setOriginalHESCASettings() {
        weightingScheme = new TrainAcc();
        votingScheme = new MajorityVote();

        // kNN(100) with setCrossValidate(true), no normalisation, Euclidean distance.
        final kNN nearestNeighbour = new kNN(100);
        nearestNeighbour.setCrossValidate(true);
        nearestNeighbour.normalise(false);
        nearestNeighbour.setDistanceFunction(new EuclideanDistance());

        final NaiveBayes naiveBayes = new NaiveBayes();
        final J48 c45 = new J48();

        // SMO with a polynomial kernel of exponent 1.
        final SMO linearSvm = new SMO();
        linearSvm.turnChecksOff();
        final PolyKernel linearKernel = new PolyKernel();
        linearKernel.setExponent(1.0);
        linearSvm.setKernel(linearKernel);

        // SMO with a polynomial kernel of exponent 2.
        final SMO quadraticSvm = new SMO();
        quadraticSvm.turnChecksOff();
        final PolyKernel quadraticKernel = new PolyKernel();
        quadraticKernel.setExponent(2.0);
        quadraticSvm.setKernel(quadraticKernel);

        final RandomForest randomForest = new RandomForest();
        randomForest.setNumTrees(500);

        final RotationForest rotationForest = new RotationForest();
        rotationForest.setNumIterations(50);

        if (setSeed) {
            linearSvm.setRandomSeed(seed);
            quadraticSvm.setRandomSeed(seed);
            randomForest.setSeed(seed);
            rotationForest.setSeed(seed);
        }

        final Classifier[] members = {
            nearestNeighbour, naiveBayes, c45, linearSvm, quadraticSvm, randomForest, rotationForest, new BayesNet()
        };
        final String[] memberNames = {"NN", "NB", "C4.5", "SVML", "SVMQ", "RandF", "RotF", "bayesNet"};
        this.setClassifiers(members, memberNames, null);
    }

    /**
     * Configures the default CAWPE ensemble: TrainAcc(4.0) weighting, MajorityConfidence voting and
     * five members: an SMO with a degree-1 polynomial kernel and logistic models ("SVML"), kNN
     * ("NN"), J48 ("C4.5"), Logistic ("Logistic") and MultilayerPerceptron ("MLP"). The SMO
     * receives {@code seed} only when {@code setSeed} is true. Invoked by the constructor.
     */
    public final void setDefaultCAWPESettings() {
        weightingScheme = new TrainAcc(4.0);
        votingScheme = new MajorityConfidence();

        final SMO linearSvm = new SMO();
        linearSvm.turnChecksOff();
        linearSvm.setBuildLogisticModels(true);
        final PolyKernel linearKernel = new PolyKernel();
        linearKernel.setExponent(1.0);
        linearSvm.setKernel(linearKernel);
        if (setSeed) {
            linearSvm.setRandomSeed(seed);
        }

        // kNN(100) with setCrossValidate(true), no normalisation, Euclidean distance.
        final kNN nearestNeighbour = new kNN(100);
        nearestNeighbour.setCrossValidate(true);
        nearestNeighbour.normalise(false);
        nearestNeighbour.setDistanceFunction(new EuclideanDistance());

        final Classifier[] members = {
            linearSvm, nearestNeighbour, new J48(), new Logistic(), new MultilayerPerceptron()
        };
        final String[] memberNames = {"SVML", "NN", "C4.5", "Logistic", "MLP"};
        this.setClassifiers(members, memberNames, null);
    }

    /**
     * Controls whether buildClassifier estimates the ensemble's own train accuracy by combining
     * the modules' cross-validation predictions.
     *
     * @param b true to perform the ensemble CV
     */
    public void setPerformCV(boolean b) {
        performEnsembleCV = b;
    }

    /**
     * Enables seeding and stores the seed.
     *
     * <p>The seed is read when the CrossValidator is built and when the settings methods create
     * members. Members that already exist (e.g. those created by the constructor) are not reseeded
     * by this call.
     *
     * @param seed the random seed
     */
    public void setRandSeed(int seed) {
        this.seed = seed;
        setSeed = true;
    }

    /**
     * Returns the number of cross-validation folds to use for the given training data.
     *
     * @param train the training data (currently ignored)
     * @return always 10
     */
    public static int findNumFolds(Instances train) {
        final int defaultNumFolds = 10;
        return defaultNumFolds;
    }

    /**
     * Trains the ensemble on {@code data}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Copy the data, or pass it through {@link #transform} when one is set.</li>
     *   <li>Train the modules, or load them from results files.</li>
     *   <li>Fit the weighting and voting schemes.</li>
     *   <li>If {@code performEnsembleCV} is set, estimate the ensemble's own train accuracy from the
     *       modules' cross-validation predictions.</li>
     *   <li>Reset test-time bookkeeping so the next test pass starts with fresh results.</li>
     * </ol>
     *
     * @param data the training data
     * @throws Exception if several read directories are given but their count differs from the
     *     number of modules, or if training/loading any module fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.printlnDebug("**CAWPE TRAIN**");
        if (this.resultsFilesParametersInitialised) {
            if (this.readResultsFilesDirectories.length > 1 && this.readResultsFilesDirectories.length != this.modules.length) {
                throw new Exception("CAWPE.buildClassifier: more than one results path given, but number given does not align with the number of classifiers/modules.");
            }
            if (this.writeResultsFilesDirectory == null) {
                this.writeResultsFilesDirectory = this.readResultsFilesDirectories[0];
            }
        }
        long startTime = System.currentTimeMillis();
        this.train = this.transform == null ? new Instances(data) : this.transform.process(data);
        this.numTrainInsts = this.train.numInstances();
        this.numClasses = this.train.numClasses();
        this.numAttributes = this.train.numAttributes();
        this.initialiseModules();
        this.weightingScheme.defineWeightings(this.modules, this.numClasses);
        this.votingScheme.trainVotingScheme(this.modules, this.numClasses);
        if (this.performEnsembleCV) {
            // Build time deliberately excludes the ensemble CV itself.
            long buildTime = System.currentTimeMillis() - startTime;
            // The CV folds and numTrainInsts come from this.train, so read labels from the same set
            // (the raw data can differ when a transform is applied).
            this.ensembleTrainResults = this.doEnsembleCV(this.train);
            this.ensembleTrainResults.buildTime = buildTime;
            if (this.writeEnsembleTrainingFile) {
                this.writeEnsembleCVResults(this.train);
            }
        }
        // Reset test bookkeeping. Without clearing prevTestInstance, distributionForInstance would
        // keep appending to the previous build's test results.
        this.testInstCounter = 0;
        this.prevTestInstance = null;
        this.ensembleTestResults = null;
    }

    /**
     * Writes the ensemble's CV train results to {@code outputEnsembleTrainingPathAndFile}.
     *
     * <p>Layout:
     * <ul>
     *   <li>a header line {@code relation,ensembleIdentifier,train};</li>
     *   <li>the parameter string;</li>
     *   <li>the train accuracy;</li>
     *   <li>one line per train instance: {@code actual,predicted,} followed by {@code ,p} for each
     *       class probability, which leaves an empty column between the prediction and the
     *       probabilities.</li>
     * </ul>
     *
     * @param data the training data whose class values are written
     * @throws IOException if the file cannot be written
     */
    protected void writeEnsembleCVResults(Instances data) throws IOException {
        StringBuilder output = new StringBuilder();
        output.append(data.relationName()).append(",").append(this.ensembleIdentifier).append(",train\n");
        output.append(this.getParameters()).append("\n");
        output.append(this.ensembleTrainResults.acc).append("\n");
        double[] predClassVals = this.ensembleTrainResults.getPredClassVals();
        for (int i = 0; i < this.numTrainInsts; ++i) {
            output.append(data.instance(i).classValue()).append(",").append(predClassVals[i]).append(",");
            double[] distForInst = this.ensembleTrainResults.getDistributionForInstance(i);
            for (int j = 0; j < this.numClasses; ++j) {
                output.append(",").append(distForInst[j]);
            }
            output.append("\n");
        }
        // getParentFile() is null for a bare file name; there is nothing to create in that case.
        File parentDir = new File(this.outputEnsembleTrainingPathAndFile).getParentFile();
        if (parentDir != null) {
            parentDir.mkdirs();
        }
        // try-with-resources closes the writer even if append fails.
        try (FileWriter fullTrain = new FileWriter(this.outputEnsembleTrainingPathAndFile)) {
            fullTrain.append(output);
        }
    }

    /**
     * Prepares every module's train results.
     *
     * <p>Builds the CrossValidator when needed, then either trains the modules or loads them from
     * results files. Afterwards it stamps each module's train results with the class and instance
     * counts and asks them to compute their statistics.
     *
     * @throws Exception if loading is requested without results-file parameters, or if training or
     *     loading fails
     */
    protected void initialiseModules() throws Exception {
        if (willNeedToDoCV()) {
            final int folds = setNumberOfFolds(train);
            cv = new CrossValidator();
            if (setSeed) {
                cv.setSeed(seed);
            }
            cv.setNumFolds(folds);
            cv.buildFolds(train);
        }

        if (!readIndividualsResults) {
            trainModules();
        } else if (resultsFilesParametersInitialised) {
            loadModules();
        } else {
            throw new Exception("Trying to load CAWPE modules from file, but parameters for results file reading have not been initialised");
        }

        for (EnsembleModule module : modules) {
            module.trainResults.setNumClasses(numClasses);
            module.trainResults.setNumInstances(numTrainInsts);
            module.trainResults.findAllStatsOnce();
        }
    }

    /**
     * Decides whether cross-validation folds must be built.
     *
     * @return true if the ensemble CV is requested or any module's classifier cannot estimate its
     *     own train accuracy (is not a TrainAccuracyEstimate)
     */
    protected boolean willNeedToDoCV() {
        if (performEnsembleCV) {
            return true;
        }
        for (EnsembleModule module : modules) {
            if (!(module.getClassifier() instanceof TrainAccuracyEstimate)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Trains every module on {@code train} and records its train results.
     *
     * <p>Classifiers that are TrainAccuracyEstimates supply their own train results after building.
     * Every other classifier is first cross-validated with {@code cv}, then built on the full train
     * set, and that build time is prepended to the module's parameters. When
     * {@code writeIndividualsResults} is set, each module's train results file is written.
     *
     * @throws Exception if building, cross-validating or writing fails
     */
    protected void trainModules() throws Exception {
        for (EnsembleModule module : modules) {
            final Classifier classifier = module.getClassifier();
            if (classifier instanceof TrainAccuracyEstimate) {
                classifier.buildClassifier(train);
                module.trainResults = ((TrainAccuracyEstimate) classifier).getTrainResults();
                if (writeIndividualsResults) {
                    final String params = classifier instanceof SaveParameterInfo
                            ? ((SaveParameterInfo) classifier).getParameters()
                            : module.getParameters();
                    writeResultsFile(module.getModuleName(), params, module.trainResults, "train");
                    printlnDebug(module.getModuleName() + " writing train file data gotten through TrainAccuracyEstimate...");
                }
            } else {
                printlnDebug(module.getModuleName() + " performing cv...");
                module.trainResults = cv.crossValidateWithStats(classifier, train);
                final long buildStart = System.currentTimeMillis();
                classifier.buildClassifier(train);
                module.trainResults.buildTime = System.currentTimeMillis() - buildStart;
                module.setParameters("BuildTime," + module.trainResults.buildTime + "," + module.getParameters());
                if (writeIndividualsResults) {
                    writeResultsFile(module.getModuleName(), module.getParameters(), module.trainResults, "train");
                    printlnDebug(module.getModuleName() + " writing train file with full preds from scratch...");
                }
            }
        }
    }

    /**
     * Loads each module's train and test results from results files instead of training it.
     *
     * <p>A single read directory is shared by all modules; otherwise directory {@code m} is used
     * for module {@code m}. Problems are collected for all modules and reported together:
     * <ul>
     *   <li>missing files;</li>
     *   <li>train files without probabilities when individual train predictions are needed;</li>
     *   <li>empty test files.</li>
     * </ul>
     * {@code numTestInsts} ends up holding the test-instance count of the last module whose test
     * file was loaded.
     *
     * @throws Exception listing every problem found, or if a results file cannot be parsed
     */
    protected void loadModules() throws Exception {
        final ErrorReport errors = new ErrorReport("Errors while loading modules from file. Directories given: " + Arrays.toString(readResultsFilesDirectories));
        final boolean sharedDirectory = readResultsFilesDirectories.length == 1;
        // Common suffix of every error message: "' on '<dataset>' fold '<id>'".
        final String where = "' on '" + datasetName + "' fold '" + resampleIdentifier + "'";

        for (int m = 0; m < modules.length; ++m) {
            final EnsembleModule module = modules[m];
            final String moduleName = module.getModuleName();
            final String directory = sharedDirectory ? readResultsFilesDirectories[0] : readResultsFilesDirectories[m];

            final File trainFile = findResultsFile(directory, moduleName, "train");
            final boolean trainResultsLoaded = trainFile != null;
            if (trainResultsLoaded) {
                printlnDebug(moduleName + " train loading... " + trainFile.getAbsolutePath());
                module.trainResults = new ClassifierResults(trainFile.getAbsolutePath());
            }

            final File testFile = findResultsFile(directory, moduleName, "test");
            final boolean testResultsLoaded = testFile != null;
            if (testResultsLoaded) {
                printlnDebug(moduleName + " test loading..." + testFile.getAbsolutePath());
                module.testResults = new ClassifierResults(testFile.getAbsolutePath());
                numTestInsts = module.testResults.numInstances();
            }

            if (!trainResultsLoaded) {
                errors.log("\nTRAIN results files for '" + moduleName + where + " not found. ");
            } else if (needIndividualTrainPreds() && module.trainResults.predictedClassProbabilities.isEmpty()) {
                errors.log("\nNo pred/distribution for instance data found in TRAIN results file for '" + moduleName + where + ". ");
            }

            if (!testResultsLoaded) {
                errors.log("\nTEST results files for '" + moduleName + where + " not found. ");
            } else if (module.testResults.numInstances() == 0) {
                errors.log("\nNo prediction data found in TEST results file for '" + moduleName + where + ". ");
            }
        }
        errors.throwIfErrors();
    }

    /**
     * Reports whether per-instance train predictions of the modules are required.
     *
     * @return true if the ensemble CV is performed or either scheme declares it needs train preds
     */
    protected boolean needIndividualTrainPreds() {
        if (performEnsembleCV) {
            return true;
        }
        return weightingScheme.needTrainPreds || votingScheme.needTrainPreds;
    }

    /**
     * Locates {@code <dir><classifierName>/Predictions/<dataset>/<trainOrTest>Fold<resample>.csv}.
     * The directory is concatenated as-is, so it presumably needs a trailing separator.
     *
     * @param readResultsFilesDirectory base results directory
     * @param classifierName module/classifier name
     * @param trainOrTest "train" or "test"
     * @return the file if it exists and is non-empty, otherwise null
     */
    protected File findResultsFile(String readResultsFilesDirectory, String classifierName, String trainOrTest) {
        final String path = readResultsFilesDirectory + classifierName + "/Predictions/" + datasetName + "/" + trainOrTest + "Fold" + resampleIdentifier + ".csv";
        final File candidate = new File(path);
        final boolean usable = candidate.exists() && candidate.length() != 0L;
        return usable ? candidate : null;
    }

    /**
     * Writes a results file to
     * {@code <writeDir><classifierName>/Predictions/<dataset>/<trainOrTest>Fold<resample>.csv}.
     *
     * <p>The header line is {@code dataset,<ensembleIdentifier><classifierName>,<trainOrTest>}. It
     * is followed by the parameter string, the accuracy and the per-instance predictions.
     *
     * @param classifierName module (or ensemble) name used in the path and header
     * @param parameters parameter string written on the second line
     * @param results results whose accuracy and predictions are written
     * @param trainOrTest "train" or "test"
     * @throws IOException if the file cannot be written
     */
    protected void writeResultsFile(String classifierName, String parameters, ClassifierResults results, String trainOrTest) throws IOException {
        StringBuilder st = new StringBuilder();
        st.append(this.datasetName).append(",").append(this.ensembleIdentifier).append(classifierName).append(",").append(trainOrTest).append("\n");
        st.append(parameters).append("\n");
        st.append(results.acc).append("\n");
        st.append(results.writeInstancePredictions());
        String fullPath = this.writeResultsFilesDirectory + classifierName + "/Predictions/" + this.datasetName;
        new File(fullPath).mkdirs();
        // try-with-resources closes the writer even if append fails.
        try (FileWriter out = new FileWriter(fullPath + "/" + trainOrTest + "Fold" + this.resampleIdentifier + ".csv")) {
            out.append(st);
        }
    }

    /**
     * Configures results-file reading with one directory shared by all modules.
     *
     * @param individualResultsFilesDirectory base directory of the modules' results files
     * @param datasetName dataset name used in results paths
     * @param resampleIdentifier resample/fold id used in results file names
     */
    public void setResultsFileLocationParameters(String individualResultsFilesDirectory, String datasetName, int resampleIdentifier) {
        this.datasetName = datasetName;
        this.resampleIdentifier = resampleIdentifier;
        this.readResultsFilesDirectories = new String[]{individualResultsFilesDirectory};
        this.resultsFilesParametersInitialised = true;
    }

    /**
     * Configures results-file reading with either a single shared directory or one directory per
     * module (buildClassifier checks that a multi-entry array matches the number of modules). The
     * array is stored by reference, not copied.
     *
     * @param individualResultsFilesDirectories results directories
     * @param datasetName dataset name used in results paths
     * @param resampleIdentifier resample/fold id used in results file names
     */
    public void setResultsFileLocationParameters(String[] individualResultsFilesDirectories, String datasetName, int resampleIdentifier) {
        this.datasetName = datasetName;
        this.resampleIdentifier = resampleIdentifier;
        this.readResultsFilesDirectories = individualResultsFilesDirectories;
        this.resultsFilesParametersInitialised = true;
    }

    /**
     * Sets the directory results files are written to. If left null, buildClassifier falls back to
     * the first read directory.
     *
     * @param writeResultsFilesDirectory output directory, concatenated directly with module names
     */
    public void setResultsFileWritingLocation(String writeResultsFilesDirectory) {
        this.writeResultsFilesDirectory = writeResultsFilesDirectory;
    }

    /**
     * Chooses whether modules are loaded from existing results files instead of being trained.
     * Enabling this also disables writing of individual results.
     *
     * @param b true to load modules from results files
     */
    public void setBuildIndividualsFromResultsFiles(boolean b) {
        readIndividualsResults = b;
        writeIndividualsResults = writeIndividualsResults && !b;
    }

    /**
     * Chooses whether each module's results files are written while training. Enabling this also
     * disables loading modules from results files.
     *
     * @param b true to write individual results files
     */
    public void setWriteIndividualsTrainResultsFiles(boolean b) {
        writeIndividualsResults = b;
        readIndividualsResults = readIndividualsResults && !b;
    }

    /**
     * Estimates the ensemble's train accuracy from the modules' train predictions, fold by fold.
     *
     * <p>For every instance of every CV fold, the voting scheme produces the ensemble distribution
     * from the modules' stored train results. The predicted class is the index of the maximum
     * entry. The standard deviation of per-fold accuracies is computed around the overall accuracy.
     *
     * @param data the data whose class values are the ground truth (indexed by CV fold indices)
     * @return ensemble train results holding predictions, distributions and accuracy statistics
     * @throws Exception if the voting scheme fails
     */
    protected ClassifierResults doEnsembleCV(Instances data) throws Exception {
        final int numFolds = cv.getNumFolds();
        final double[] preds = new double[numTrainInsts];
        final double[][] dists = new double[numTrainInsts][];
        final double[] accPerFold = new double[numFolds];
        double correct = 0.0;

        for (int fold = 0; fold < numFolds; ++fold) {
            final int foldSize = cv.getFoldIndices().get(fold).size();
            for (int pos = 0; pos < foldSize; ++pos) {
                final int instIndex = cv.getFoldIndices().get(fold).get(pos);
                final double[] dist = votingScheme.distributionForTrainInstance(modules, instIndex);
                final double pred = GenericTools.indexOfMax(dist);
                if (pred == data.instance(instIndex).classValue()) {
                    correct += 1.0;
                    accPerFold[fold] += 1.0;
                }
                preds[instIndex] = pred;
                dists[instIndex] = dist;
            }
            // Turn the fold's correct count into an accuracy.
            accPerFold[fold] /= (double) foldSize;
        }

        final double acc = correct / (double) numTrainInsts;
        final double stddevOverFolds = StatisticalUtilities.standardDeviation(accPerFold, false, acc);
        final double[] actualClassVals = data.attributeToDoubleArray(data.classIndex());
        final ClassifierResults results = new ClassifierResults(acc, actualClassVals, preds, dists, stddevOverFolds, numClasses);
        results.setNumClasses(numClasses);
        results.setNumInstances(numTrainInsts);
        return results;
    }

    /**
     * Finalises every module's test results against the true test class values.
     *
     * @param testSetClassVals true class values of the test set
     * @throws Exception if finalising any module's results fails
     */
    public void finaliseIndividualModuleTestResults(double[] testSetClassVals) throws Exception {
        for (int m = 0; m < modules.length; ++m) {
            modules[m].testResults.finaliseResults(testSetClassVals);
        }
    }

    /**
     * Finalises the ensemble's accumulated test predictions against the true class values and
     * records the number of classes and test instances on them.
     *
     * @param testSetClassVals true class values of the test set, in the order instances were
     *     classified
     * @throws Exception if finalising the results fails
     */
    public void finaliseEnsembleTestResults(double[] testSetClassVals) throws Exception {
        this.ensembleTestResults.finaliseResults(testSetClassVals);
        this.ensembleTestResults.setNumClasses(this.numClasses);
        // numTestInsts is set by loadModules; when the modules were trained rather than loaded it is
        // still -1, so fall back to the size of the supplied test set.
        int numInsts = this.numTestInsts >= 0 ? this.numTestInsts : testSetClassVals.length;
        this.ensembleTestResults.setNumInstances(numInsts);
    }

    /**
     * Finalises each module's test results and writes them to
     * {@code <writeDir><moduleName>/Predictions/<dataset>/testFold<resample>.csv}.
     *
     * @param testSetClassVals true class values of the test set
     * @param throwExceptionOnFileParamsNotSetProperly if true, throw when writing of individual
     *     results is not enabled or the results-file parameters are not set; otherwise silently
     *     return in that case
     * @throws Exception if the preconditions fail (and throwing was requested), or on write errors
     */
    public void writeIndividualTestFiles(double[] testSetClassVals, boolean throwExceptionOnFileParamsNotSetProperly) throws Exception {
        if (!this.writeIndividualsResults || !this.resultsFilesParametersInitialised) {
            if (throwExceptionOnFileParamsNotSetProperly) {
                // Name the setter that actually exists on this class.
                throw new Exception("to call writeIndividualTestFiles(), must have called setResultsFileLocationParameters(...) and setWriteIndividualsTrainResultsFiles(true)");
            }
            return;
        }
        this.finaliseIndividualModuleTestResults(testSetClassVals);
        for (EnsembleModule module : this.modules) {
            this.writeResultsFile(module.getModuleName(), module.getParameters(), module.testResults, "test");
        }
    }

    /**
     * Writes the ensemble's own results files under its {@code ensembleIdentifier}. The train file
     * is written only if ensemble train results exist. The test results are finalised against
     * {@code testSetClassVals} and then written.
     *
     * @param testSetClassVals true class values of the test set
     * @param throwExceptionOnFileParamsNotSetProperly if true, throw when results-file parameters
     *     are not set; otherwise silently return in that case
     * @throws Exception if the preconditions fail (and throwing was requested), or on write errors
     */
    public void writeEnsembleTrainTestFiles(double[] testSetClassVals, boolean throwExceptionOnFileParamsNotSetProperly) throws Exception {
        if (resultsFilesParametersInitialised) {
            if (ensembleTrainResults != null) {
                writeResultsFile(ensembleIdentifier, getParameters(), ensembleTrainResults, "train");
            }
            ensembleTestResults.finaliseResults(testSetClassVals);
            writeResultsFile(ensembleIdentifier, getParameters(), ensembleTestResults, "test");
        } else if (throwExceptionOnFileParamsNotSetProperly) {
            throw new Exception("to call writeEnsembleTrainTestFiles(), must have called setResultsFileLocationParameters(...)");
        }
    }

    /**
     * Returns the ensemble members.
     *
     * @return the internal module array (not a copy)
     */
    public EnsembleModule[] getModules() {
        return modules;
    }

    /**
     * Returns the CrossValidator built during training.
     *
     * @return the cross validator, or null if none was needed
     */
    public CrossValidator getCrossValidator() {
        return cv;
    }

    /**
     * Lists the module names in module order.
     *
     * @return a new array of module names
     */
    public String[] getClassifierNames() {
        final String[] names = new String[modules.length];
        int idx = 0;
        for (EnsembleModule module : modules) {
            names[idx++] = module.getModuleName();
        }
        return names;
    }

    /**
     * Returns the ensemble's predicted class values from its train CV.
     *
     * @return predicted class values per train instance
     */
    @Override
    public double[] getEnsembleCvPreds() {
        return ensembleTrainResults.getPredClassVals();
    }

    /**
     * Returns the ensemble's train CV accuracy.
     *
     * @return the accuracy stored in the ensemble train results
     */
    @Override
    public double getEnsembleCvAcc() {
        return ensembleTrainResults.acc;
    }

    /**
     * Returns the name used for the ensemble's own results files.
     *
     * @return the ensemble identifier
     */
    public String getEnsembleIdentifier() {
        return ensembleIdentifier;
    }

    /**
     * Sets the name used for the ensemble's own results files and results-file headers.
     *
     * @param ensembleIdentifier the new identifier
     */
    public void setEnsembleIdentifier(String ensembleIdentifier) {
        this.ensembleIdentifier = ensembleIdentifier;
    }

    /**
     * Gathers each module's posterior weights.
     *
     * @return one row per module; each row is the module's own array (not copied)
     */
    public double[][] getPosteriorIndividualWeights() {
        final double[][] weights = new double[modules.length][];
        int idx = 0;
        for (EnsembleModule module : modules) {
            weights[idx++] = module.posteriorWeights;
        }
        return weights;
    }

    /**
     * Returns the scheme that combines module outputs.
     *
     * @return the voting scheme
     */
    public ModuleVotingScheme getVotingScheme() {
        return votingScheme;
    }

    /**
     * Replaces the scheme that combines module outputs; takes effect at the next build.
     *
     * @param votingScheme the new voting scheme
     */
    public void setVotingScheme(ModuleVotingScheme votingScheme) {
        this.votingScheme = votingScheme;
    }

    /**
     * Returns the scheme that assigns module weights.
     *
     * @return the weighting scheme
     */
    public ModuleWeightingScheme getWeightingScheme() {
        return weightingScheme;
    }

    /**
     * Replaces the scheme that assigns module weights; takes effect at the next build.
     *
     * @param weightingScheme the new weighting scheme
     */
    public void setWeightingScheme(ModuleWeightingScheme weightingScheme) {
        this.weightingScheme = weightingScheme;
    }

    /**
     * Collects each module's train (CV) accuracy.
     *
     * @return a new array of accuracies in module order
     */
    public double[] getIndividualCvAccs() {
        final double[] accs = new double[modules.length];
        int idx = 0;
        for (EnsembleModule module : modules) {
            accs[idx++] = module.trainResults.acc;
        }
        return accs;
    }

    /**
     * Collects each module's predicted class values on the train data.
     *
     * @return one row of predictions per module, in module order
     */
    public double[][] getIndividualCvPredictions() {
        final double[][] preds = new double[modules.length][];
        int idx = 0;
        for (EnsembleModule module : modules) {
            preds[idx++] = module.trainResults.getPredClassVals();
        }
        return preds;
    }

    /**
     * Returns the filter applied before training and classification.
     *
     * @return the transform, or null if none
     */
    public SimpleBatchFilter getTransform() {
        return transform;
    }

    /**
     * Sets the filter applied to the training data and to each instance before classification.
     *
     * @param transform the filter, or null for none
     */
    public void setTransform(SimpleBatchFilter transform) {
        this.transform = transform;
    }

    /**
     * Requests that the ensemble CV results be written to {@code path} after building. This also
     * turns the ensemble CV on.
     *
     * @param path output file path
     */
    @Override
    public void writeCVTrainToFile(String path) {
        writeEnsembleTrainingFile = true;
        performEnsembleCV = true;
        outputEnsembleTrainingPathAndFile = path;
    }

    /**
     * Reports whether building will produce a train accuracy estimate for the ensemble.
     *
     * @return true if the ensemble CV is enabled
     */
    @Override
    public boolean findsTrainAccuracyEstimate() {
        return performEnsembleCV;
    }

    /**
     * Returns the ensemble's CV train results.
     *
     * @return the train results, or null if no ensemble CV has been performed
     */
    @Override
    public ClassifierResults getTrainResults() {
        return ensembleTrainResults;
    }

    /**
     * Returns the ensemble's accumulated test results.
     *
     * @return the test results, or null before any test instance has been classified
     */
    public ClassifierResults getTestResults() {
        return ensembleTestResults;
    }

    /**
     * Describes the ensemble as a comma-separated parameter string.
     *
     * <p>Format:
     * {@code BuildTime,<t>,Trainacc,<acc>,<weighting>,<voting>,name(prior/post1/post2...),...}
     * Build time and train accuracy are -1 when no ensemble train results exist. A module whose
     * posterior weights have not been assigned yet is written with its prior weight only, instead
     * of failing with a NullPointerException.
     *
     * @return the parameter string
     */
    @Override
    public String getParameters() {
        StringBuilder out = new StringBuilder();
        if (this.ensembleTrainResults != null) {
            out.append("BuildTime,").append(this.ensembleTrainResults.buildTime).append(",Trainacc,").append(this.ensembleTrainResults.acc).append(",");
        } else {
            out.append("BuildTime,").append("-1").append(",Trainacc,").append("-1").append(",");
        }
        out.append(this.weightingScheme.toString()).append(",").append(this.votingScheme.toString()).append(",");
        for (EnsembleModule module : this.modules) {
            out.append(module.getModuleName()).append("(").append(module.priorWeight);
            if (module.posteriorWeights != null) {
                for (double posterior : module.posteriorWeights) {
                    out.append("/").append(posterior);
                }
            }
            out.append("),");
        }
        return out.toString();
    }

    /**
     * Computes the ensemble's class distribution for one instance and records it in the test
     * results.
     *
     * <p>When modules were loaded from results files, the distribution comes from the stored test
     * predictions at index {@code testInstCounter}. Otherwise the modules classify the (possibly
     * transformed) instance. The counter only advances when a different Instance object arrives
     * than on the previous call.
     *
     * @param instance the instance to classify
     * @return the ensemble's class distribution
     * @throws Exception if more test instances arrive than were loaded, or if classification fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        Instance ins = instance;
        if (transform != null) {
            // The filter works on whole datasets, so wrap the single instance in an empty copy of its header.
            final Instances single = new Instances(instance.dataset(), 0);
            single.add(instance);
            ins = transform.process(single).instance(0);
        }

        final boolean startingNewTestRun = testInstCounter == 0 && prevTestInstance == null;
        if (ensembleTestResults == null || startingNewTestRun) {
            printlnDebug("\n**TEST**");
            ensembleTestResults = new ClassifierResults(numClasses);
        }

        if (readIndividualsResults && testInstCounter >= numTestInsts) {
            throw new Exception("Received more test instances than expected, when loading test results files, found " + numTestInsts + " test cases");
        }

        final double[] dist;
        if (readIndividualsResults) {
            dist = votingScheme.distributionForTestInstance(modules, testInstCounter);
        } else {
            dist = votingScheme.distributionForInstance(modules, ins);
        }
        ensembleTestResults.storeSingleResult(dist);

        if (instance != prevTestInstance) {
            ++testInstCounter;
        }
        prevTestInstance = instance;
        return dist;
    }

    /**
     * Predicts the class with the highest ensemble probability.
     *
     * @param instance the instance to classify
     * @return index of the maximum entry of {@link #distributionForInstance(Instance)}
     * @throws Exception if computing the distribution fails
     */
    @Override
    public double classifyInstance(Instance instance) throws Exception {
        return GenericTools.indexOfMax(distributionForInstance(instance));
    }

    /**
     * Asks every module's classifier for its own prediction on the (possibly transformed) instance.
     * Test bookkeeping is not touched.
     *
     * @param instance the instance to classify
     * @return one predicted class value per module, in module order
     * @throws Exception if the transform or any classifier fails
     */
    public double[] classifyInstanceByConstituents(Instance instance) throws Exception {
        final Instance ins;
        if (transform == null) {
            ins = instance;
        } else {
            final Instances wrapper = new Instances(instance.dataset(), 0);
            wrapper.add(instance);
            ins = transform.process(wrapper).instance(0);
        }
        final double[] predsByClassifier = new double[modules.length];
        int idx = 0;
        for (EnsembleModule module : modules) {
            predsByClassifier[idx++] = module.getClassifier().classifyInstance(ins);
        }
        return predsByClassifier;
    }

    /**
     * Asks every module's classifier for its own class distribution on the (possibly transformed)
     * instance. Test bookkeeping is not touched.
     *
     * @param instance the instance to classify
     * @return one distribution per module, in module order
     * @throws Exception if the transform or any classifier fails
     */
    public double[][] distributionForInstanceByConstituents(Instance instance) throws Exception {
        final Instance ins;
        if (transform == null) {
            ins = instance;
        } else {
            final Instances wrapper = new Instances(instance.dataset(), 0);
            wrapper.add(instance);
            ins = transform.process(wrapper).instance(0);
        }
        final double[][] distsByClassifier = new double[modules.length][];
        int idx = 0;
        for (EnsembleModule module : modules) {
            distsByClassifier[idx++] = module.getClassifier().distributionForInstance(ins);
        }
        return distsByClassifier;
    }

    /**
     * Trains a CAWPE ensemble on one train/test split and writes every module's train and test
     * results files, optionally also the ensemble's own files.
     *
     * @param defaultTrainPartition train data (copied, never modified)
     * @param defaultTestPartition test data (copied, never modified)
     * @param resultOutputDir base directory for results files
     * @param datasetIdentifier dataset name used in results paths
     * @param ensembleIdentifier unused here; the ensemble keeps its default identifier
     * @param resampleIdentifier resample id; when {@code resample} is true and this is > 0 the
     *     split is resampled with it
     * @param classifiers custom members, or null for the default CAWPE members
     * @param cNames names for the custom members (null derives them from the classifiers)
     * @param transform optional filter applied before training/classification
     * @param setSeed if true, seed the ensemble with {@code resampleIdentifier}
     * @param resample whether to resample the split
     * @param writeEnsembleResults whether to compute and write the ensemble's own results
     * @throws Exception if training, classification or writing fails
     */
    public static void buildAndWriteFullIndividualTrainTestResults(Instances defaultTrainPartition, Instances defaultTestPartition, String resultOutputDir, String datasetIdentifier, String ensembleIdentifier, int resampleIdentifier, Classifier[] classifiers, String[] cNames, SimpleBatchFilter transform, boolean setSeed, boolean resample, boolean writeEnsembleResults) throws Exception {
        Instances train = new Instances(defaultTrainPartition);
        Instances test = new Instances(defaultTestPartition);
        if (resample && resampleIdentifier > 0) {
            Instances[] temp = InstanceTools.resampleTrainAndTestInstances(train, test, resampleIdentifier);
            train = temp[0];
            test = temp[1];
        }
        CAWPE h = new CAWPE();
        if (classifiers != null) {
            h.setClassifiers(classifiers, cNames, null);
        }
        h.setTransform(transform);
        if (setSeed) {
            h.setRandSeed(resampleIdentifier);
        }
        h.setResultsFileLocationParameters(resultOutputDir, datasetIdentifier, resampleIdentifier);
        h.setWriteIndividualsTrainResultsFiles(true);
        // performEnsembleCV defaults to true. Set it explicitly so the ensemble CV is only run when
        // its results will actually be written; the modules' own results do not depend on it.
        h.setPerformCV(writeEnsembleResults);
        h.buildClassifier(train);
        for (Instance inst : test) {
            h.distributionForInstance(inst);
        }
        double[] classVals = test.attributeToDoubleArray(test.classIndex());
        h.writeIndividualTestFiles(classVals, true);
        if (writeEnsembleResults) {
            h.writeEnsembleTrainTestFiles(classVals, true);
        }
    }

    /**
     * Demonstrates the CAWPE configuration API on ItalyPowerDemand. Several calls intentionally
     * override earlier ones to show the available options; the comments say which call wins.
     *
     * @throws Exception if loading, training or writing fails
     */
    public static void exampleCAWPEUsage() throws Exception {
        final String datasetName = "ItalyPowerDemand";
        final String problemPath = "c:/tsc problems/" + datasetName + "/" + datasetName;
        final Instances train = ClassifierTools.loadData(problemPath + "_TRAIN");
        final Instances test = ClassifierTools.loadData(problemPath + "_TEST");

        final CAWPE cawpe = new CAWPE();

        // A transform may be applied to all data; this example sets one and then removes it again.
        cawpe.setTransform(new SAX());
        cawpe.setTransform(null);

        // Custom members and schemes can be supplied...
        cawpe.setClassifiers(new Classifier[]{new kNN()}, new String[]{"NN"}, new String[]{"k=1"});
        cawpe.setWeightingScheme(new TrainAccByClass());
        cawpe.setVotingScheme(new MajorityVote());
        // ...but this call restores the default members and schemes, overriding the three above.
        cawpe.setDefaultCAWPESettings();

        final int resampleID = 0;
        cawpe.setRandSeed(resampleID);
        cawpe.setResultsFileLocationParameters("CAWPETest/", datasetName, resampleID);
        // Reading and writing individual results are mutually exclusive; the later call wins, so
        // this run trains the members and writes their results files.
        cawpe.setBuildIndividualsFromResultsFiles(true);
        cawpe.setWriteIndividualsTrainResultsFiles(true);

        cawpe.buildClassifier(train);
        System.out.println(ClassifierTools.accuracy(test, cawpe));

        final boolean throwExceptionOnFileParamsNotSetProperly = false;
        cawpe.writeIndividualTestFiles(test.attributeToDoubleArray(test.classIndex()), throwExceptionOnFileParamsNotSetProperly);
        cawpe.writeEnsembleTrainTestFiles(test.attributeToDoubleArray(test.classIndex()), throwExceptionOnFileParamsNotSetProperly);
    }

    /**
     * Builds a human-readable report of this ensemble.
     *
     * <p>The report contains the configuration (identifier, modules, weighting and voting schemes,
     * dataset, fold), the ensemble train/test accuracy, and a table with one row per module:
     * train accuracy, prior weight, and per-class posterior weights. Optionally, it also contains
     * one line per train and test instance (see {@link #buildEnsemblePredsLine(boolean, int)}).
     *
     * @param printPreds    if true, append per-instance prediction lines for the train and test data
     * @param builtFromFile if true, report the ensemble's test accuracy; otherwise "NA" is printed in its place
     * @return the report text
     */
    public String buildEnsembleReport(boolean printPreds, boolean builtFromFile) {
        StringBuilder sb = new StringBuilder();
        sb.append("CAWPE REPORT");
        sb.append("\nname: ").append(this.ensembleIdentifier);
        sb.append("\nmodules: ");
        for (int i = 0; i < this.modules.length; ++i) {
            if (i > 0) {
                sb.append(",");
            }
            sb.append(this.modules[i].getModuleName());
        }
        sb.append("\nweight scheme: ").append(this.weightingScheme);
        sb.append("\nvote scheme: ").append(this.votingScheme);
        sb.append("\ndataset: ").append(this.datasetName);
        sb.append("\nfold: ").append(this.resampleIdentifier);
        sb.append("\ntrain acc: ").append(this.ensembleTrainResults.acc);
        sb.append("\ntest acc: ").append(builtFromFile ? Double.valueOf(this.ensembleTestResults.acc) : "NA");

        final int precision = 4;
        // Nominal printed width of one "%.<precision>f" value such as 0.1234.
        final int numWidth = precision + 2;
        // Widths are computed once and shared by the header and every row, so the columns line up.
        final int trainAccColWidth = Math.max(8, numWidth);
        final int priorWeightColWidth = Math.max(12, numWidth);
        // Each posterior weight is printed as a 2-character separator ("  " or ", ") plus one number.
        final int postWeightColWidth = Math.max(12, this.numClasses * (numWidth + 2));

        String moduleHeaderFormatString = "\n\n%20s | %" + trainAccColWidth + "s | %" + priorWeightColWidth + "s | %" + postWeightColWidth + "s";
        String moduleRowFormatString = "\n%20s | %" + trainAccColWidth + "." + precision + "f | %" + priorWeightColWidth + "." + precision + "f | %" + postWeightColWidth + "s";
        String postWeightFormat = "%." + precision + "f";

        sb.append(String.format(moduleHeaderFormatString, "modules", "trainacc", "priorweights", "postweights"));
        for (EnsembleModule module : this.modules) {
            StringBuilder postweights = new StringBuilder();
            for (int c = 0; c < this.numClasses; ++c) {
                postweights.append(c == 0 ? "  " : ", ").append(String.format(postWeightFormat, module.posteriorWeights[c]));
            }
            sb.append(String.format(moduleRowFormatString, module.getModuleName(), module.trainResults.acc, module.priorWeight, postweights.toString()));
        }

        if (printPreds) {
            sb.append("\n\nensemble train preds: ");
            sb.append("\ntrain acc: ").append(this.ensembleTrainResults.acc);
            sb.append("\n");
            for (int i = 0; i < this.ensembleTrainResults.numInstances(); ++i) {
                sb.append(this.buildEnsemblePredsLine(true, i)).append("\n");
            }
            sb.append("\n\nensemble test preds: ");
            sb.append("\ntest acc: ").append(builtFromFile ? Double.valueOf(this.ensembleTestResults.acc) : "NA");
            sb.append("\n");
            for (int i = 0; i < this.ensembleTestResults.numInstances(); ++i) {
                sb.append(this.buildEnsemblePredsLine(false, i)).append("\n");
            }
        }
        return sb.toString();
    }

    /**
     * Formats one comma-separated prediction line for a single train or test instance.
     *
     * <p>Columns, in order (an extra empty column separates each group):
     * the actual class value (taken from the first module's results), the ensemble's predicted
     * class, the ensemble's class distribution, the number of modules predicting each class, and
     * each module's predicted class.
     *
     * @param train true to read the train-side results, false for the test-side results
     * @param index position of the instance within those results
     * @return the formatted line, without a trailing newline
     */
    private String buildEnsemblePredsLine(boolean train, int index) {
        StringBuilder line = new StringBuilder();

        line.append(train
                ? this.modules[0].trainResults.actualClassValues.get(index)
                : this.modules[0].testResults.actualClassValues.get(index));
        line.append(",");
        line.append(train
                ? this.ensembleTrainResults.predictedClassValues.get(index)
                : this.ensembleTestResults.predictedClassValues.get(index));
        line.append(",");

        double[] ensembleDist = train
                ? this.ensembleTrainResults.getDistributionForInstance(index)
                : this.ensembleTestResults.getDistributionForInstance(index);
        for (double probability : ensembleDist) {
            line.append(",").append(probability);
        }
        line.append(",");

        // Tally of how many modules predicted each class.
        double[] votesPerClass = new double[this.numClasses];
        for (int k = 0; k < this.modules.length; ++k) {
            int votedClass = (int) (train
                    ? this.modules[k].trainResults.getPredClassValue(index)
                    : this.modules[k].testResults.getPredClassValue(index));
            votesPerClass[votedClass] += 1.0;
        }
        for (int cls = 0; cls < this.numClasses; ++cls) {
            line.append(",").append(votesPerClass[cls]);
        }
        line.append(",");

        for (int k = 0; k < this.modules.length; ++k) {
            line.append(",").append(train
                    ? this.modules[k].trainResults.getPredClassValue(index)
                    : this.modules[k].testResults.getPredClassValue(index));
        }
        return line.toString();
    }

    /**
     * Manual check that builds CAWPE from scratch and writes all result files.
     *
     * <p>For each of 5 resamples of breast-cancer-wisc-prog, it builds a CAWPE ensemble,
     * classifies the test split, writes the individual and ensemble result files under
     * {@code C:/JamesLPHD/CAWPETests<testID>/}, and prints the ensemble train and test accuracy.
     *
     * @param testID suffix used to name the output directory
     * @throws Exception if loading, building, classifying or writing fails
     */
    public static void testBuildingInds(int testID) throws Exception {
        System.out.println("testBuildingInds()");
        final String resultsDir = "C:/JamesLPHD/CAWPETests" + testID + "/";
        new File(resultsDir).mkdirs();
        final String dataset = "breast-cancer-wisc-prog";
        final int numFolds = 5;
        for (int fold = 0; fold < numFolds; fold++) {
            Instances fullData = ClassifierTools.loadData("C:/UCI Problems/" + dataset + "/" + dataset);
            Instances[] split = InstanceTools.resampleInstances(fullData, fold, 0.5);
            Instances trainSplit = split[0];
            Instances testSplit = split[1];

            CAWPE ensemble = new CAWPE();
            ensemble.setResultsFileLocationParameters(resultsDir, dataset, fold);
            ensemble.setWriteIndividualsTrainResultsFiles(true);
            ensemble.setPerformCV(true);
            ensemble.setRandSeed(fold);
            ensemble.buildClassifier(trainSplit);

            // Each test instance is classified for its effect on the ensemble.
            // NOTE(review): presumably this records the test predictions written out below.
            for (Instance inst : testSplit) {
                ensemble.classifyInstance(inst);
            }

            ensemble.writeIndividualTestFiles(testSplit.attributeToDoubleArray(testSplit.classIndex()), true);
            ensemble.writeEnsembleTrainTestFiles(testSplit.attributeToDoubleArray(testSplit.classIndex()), true);
            System.out.println("TrainAcc=" + ensemble.getTrainResults().acc);
            System.out.println("TestAcc=" + ensemble.getTestResults().acc);
        }
    }

    /**
     * Manual check that CAWPE can be built from previously written individual-classifier results.
     *
     * <p>The results are read from {@code C:/JamesLPHD/CAWPETests<testID>/}, presumably written
     * earlier by {@link #testBuildingInds(int)} with the same id. For each of 5 resamples of
     * breast-cancer-wisc-prog, it builds the ensemble from file, classifies the test split,
     * finalises the ensemble test results, and prints the train and test accuracy.
     *
     * @param testID suffix identifying the results directory
     * @throws Exception if loading, building or classifying fails
     */
    public static void testLoadingInds(int testID) throws Exception {
        // Previously printed "testBuildingInds()" (copy-paste), which mislabelled this run's output.
        System.out.println("testLoadingInds()");
        String resultsDir = "C:/JamesLPHD/CAWPETests" + testID + "/";
        new File(resultsDir).mkdirs();
        String dataset = "breast-cancer-wisc-prog";
        int numFolds = 5;
        for (int fold = 0; fold < numFolds; ++fold) {
            Instances all = ClassifierTools.loadData("C:/UCI Problems/" + dataset + "/" + dataset);
            Instances[] insts = InstanceTools.resampleInstances(all, fold, 0.5);
            Instances train = insts[0];
            Instances test = insts[1];
            CAWPE cawpe = new CAWPE();
            cawpe.setResultsFileLocationParameters(resultsDir, dataset, fold);
            cawpe.setBuildIndividualsFromResultsFiles(true);
            cawpe.setPerformCV(true);
            cawpe.setRandSeed(fold);
            cawpe.buildClassifier(train);
            // classifyInstance is called for its effect on the ensemble.
            // NOTE(review): presumably it records the predictions consumed by finaliseEnsembleTestResults.
            // The accuracy this loop used to compute locally was never used, so it has been dropped.
            for (Instance instance : test) {
                cawpe.classifyInstance(instance);
            }
            cawpe.finaliseEnsembleTestResults(test.attributeToDoubleArray(test.classIndex()));
            System.out.println("TrainAcc=" + cawpe.getTrainResults().acc);
            System.out.println("TestAcc=" + cawpe.getTestResults().acc);
        }
    }

    /**
     * Runs {@code Experiments.main} for every archive, classifier, dataset and fold.
     *
     * <p>This produces the base-classifier results files. Results for an archive are written
     * under {@code baseWritePath + dataHeaders[archive] + "/"}. Folds are passed to
     * {@code Experiments} 1-based.
     *
     * @param baseWritePath root results directory
     * @param dataHeaders   archive names, used as results sub-directories
     * @param dataPaths     data directory for each archive
     * @param datasetNames  dataset names for each archive
     * @param classifiers   names of the classifiers to run
     * @param numFolds      number of resamples per dataset
     * @throws Exception propagated from {@code Experiments.main}
     */
    public static void buildCAWPEPaper_BuildClassifierResultsFiles(String baseWritePath, String[] dataHeaders, String[] dataPaths, String[][] datasetNames, String[] classifiers, int numFolds) throws Exception {
        for (int a = 0; a < dataHeaders.length; a++) {
            String archiveResultsPath = baseWritePath + dataHeaders[a] + "/";
            for (String classifierName : classifiers) {
                System.out.println("\t" + classifierName);
                for (String datasetName : datasetNames[a]) {
                    System.out.println(datasetName);
                    for (int fold = 0; fold < numFolds; fold++) {
                        String[] experimentArgs = {
                            dataPaths[a], archiveResultsPath, "true", classifierName, datasetName, String.valueOf(fold + 1)
                        };
                        Experiments.main(experimentArgs);
                    }
                }
            }
        }
    }

    /**
     * Reproduces the results behind Figure 2 of the CAWPE paper on the UCI continuous archive.
     *
     * <p>It runs in three steps:
     * <ol>
     * <li>builds the base-classifier results files;</li>
     * <li>builds each heterogeneous ensemble from those files;</li>
     * <li>runs the multiple-classifier comparison over the ensembles.</li>
     * </ol>
     *
     * @throws Exception if any step fails or an ensemble class cannot be loaded
     */
    public static void buildCAWPEPaper_AllResultsForFigure2() throws Exception {
        final String[] dataHeaders = {"UCI"};
        final String[] dataPaths = {"Z:/Data/UCIContinuous/"};
        final String[][] datasets = {DataSets.UCIContinuousFileNames};
        final String writePathBase = "Z:/Results/CAWPEReproducabiltyTest/";
        final String writePathResults = writePathBase + "Results/";
        final String writePathAnalysis = writePathBase + "Analysis/";
        final int numFolds = 30;
        final String[] baseClassifiers = {"NN", "C45", "MLP", "Logistic", "SVML"};

        // Step 1: individual base-classifier results.
        CAWPE.buildCAWPEPaper_BuildClassifierResultsFiles(writePathResults, dataHeaders, dataPaths, datasets, baseClassifiers, numFolds);

        // Step 2: ensembles built from those results. The three tables below are index-aligned.
        final String[] ensembleIDsInStorage = {
            "CAWPE_BasicClassifiers", "EnsembleSelection_BasicClassifiers", "SMLR_BasicClassifiers",
            "SMLRE_BasicClassifiers", "SMM5_BasicClassifiers", "PickBest_BasicClassifiers",
            "MajorityVote_BasicClassifiers", "WeightMajorityVote_BasicClassifiers",
            "RecallCombiner_BasicClassifiers", "NaiveBayesCombiner_BasicClassifiers"
        };
        final String[] ensembleIDsOnFigures = {"CAWPE", "ES", "SMLR", "SMLRE", "SMM5", "PB", "MV", "WMV", "RC", "NBC"};
        final String[] ensembleClassNames = {
            "vector_classifiers.CAWPE",
            "vector_classifiers.EnsembleSelection",
            "vector_classifiers.stackers.SMLR",
            "vector_classifiers.stackers.SMLRE",
            "vector_classifiers.stackers.SMM5",
            "vector_classifiers.weightedvoters.CAWPE_PickBest",
            "vector_classifiers.weightedvoters.CAWPE_MajorityVote",
            "vector_classifiers.weightedvoters.CAWPE_WeightedMajorityVote",
            "vector_classifiers.weightedvoters.CAWPE_RecallCombiner",
            "vector_classifiers.weightedvoters.CAWPE_NaiveBayesCombiner"
        };
        Class<?>[] ensembleClasses = new Class<?>[ensembleClassNames.length];
        for (int idx = 0; idx < ensembleClassNames.length; idx++) {
            ensembleClasses[idx] = Class.forName(ensembleClassNames[idx]);
        }
        for (int e = 0; e < ensembleIDsInStorage.length; e++) {
            CAWPE.buildCAWPEPaper_BuildEnsembleFromResultsFiles(writePathResults, dataHeaders, dataPaths, datasets, baseClassifiers, numFolds, ensembleIDsInStorage[e], ensembleClasses[e]);
        }

        // Step 3: comparative analysis per archive.
        for (int a = 0; a < dataHeaders.length; a++) {
            String analysisName = dataHeaders[a] + "CAWPEvsHeteroEnsembles_BasicClassifiers";
            CAWPE.buildCAWPEPaper_BuildResultsAnalysis(writePathResults + dataHeaders[a] + "/", writePathAnalysis, analysisName, ensembleIDsInStorage, ensembleIDsOnFigures, datasets[a], numFolds);
        }
    }

    /**
     * Compares the given classifiers' results over the datasets with a {@code MultipleClassifierEvaluation}.
     *
     * <p>Both train and test results are considered, and Matlab diagrams are requested.
     *
     * @param resultsReadPath      directory the classifier results are read from
     * @param analysisWritePath    directory the analysis is written to
     * @param analysisName         name of this analysis
     * @param classifiersInStorage classifier names as stored on disk
     * @param classifiersOnFigs    display names, aligned with {@code classifiersInStorage}
     * @param datasets             datasets to include
     * @param numFolds             number of folds per dataset
     * @throws Exception propagated from the evaluation
     */
    public static void buildCAWPEPaper_BuildResultsAnalysis(String resultsReadPath, String analysisWritePath, String analysisName, String[] classifiersInStorage, String[] classifiersOnFigs, String[] datasets, int numFolds) throws Exception {
        System.out.println("buildCAWPEPaper_BuildResultsAnalysis");
        MultipleClassifierEvaluation evaluation = new MultipleClassifierEvaluation(analysisWritePath, analysisName, numFolds);
        evaluation.setTestResultsOnly(false)
                .setBuildMatlabDiagrams(true)
                .setDatasets(datasets)
                .readInClassifiers(classifiersInStorage, classifiersOnFigs, resultsReadPath)
                .runComparison();
    }

    /**
     * Builds an ensemble of type {@code ensembleClass} from pre-computed base-classifier results.
     *
     * <p>Test predictions are written for every archive/dataset/fold whose
     * {@code testFold<fold>.csv} does not already validate.
     *
     * <p>Data is loaded according to the archive header:
     * <ul>
     * <li>{@code "UCI"}: one {@code <dset>.arff} file, resampled 50/50 per fold.</li>
     * <li>Any header containing {@code "UCR"}: {@code <dset>_TRAIN.arff} and {@code <dset>_TEST.arff},
     * resampled per fold.</li>
     * </ul>
     *
     * @param baseWritePath   root directory; each archive uses {@code baseWritePath + header + "/"}
     *                        for both reading base results and writing ensemble predictions
     * @param dataHeaders     archive names
     * @param dataPaths       data directory for each archive
     * @param datasetNames    dataset names for each archive
     * @param baseClassifiers names of the base classifiers whose results files are read
     * @param numFolds        number of resamples per dataset
     * @param ensembleID      results sub-directory name for this ensemble
     * @param ensembleClass   {@code CAWPE} (sub)class with a public no-arg constructor
     * @throws IllegalArgumentException if a fold must be built for an archive whose header is
     *                                  neither "UCI" nor contains "UCR"
     * @throws Exception propagated from loading, instantiation or the experiment run
     */
    public static void buildCAWPEPaper_BuildEnsembleFromResultsFiles(String baseWritePath, String[] dataHeaders, String[] dataPaths, String[][] datasetNames, String[] baseClassifiers, int numFolds, String ensembleID, Class ensembleClass) throws Exception {
        for (int archive = 0; archive < dataHeaders.length; ++archive) {
            String header = dataHeaders[archive];
            boolean isUCI = header.equals("UCI");
            boolean isUCR = !isUCI && header.contains("UCR");
            String writePath = baseWritePath + header + "/";
            for (String dset : datasetNames[archive]) {
                System.out.println(dset);
                // Scoped per dataset so data can never leak over from a previous dataset or archive
                // (previously an unrecognised header silently reused the last resample).
                Instances all = null;
                Instances train = null;
                Instances test = null;
                if (isUCI) {
                    all = ClassifierTools.loadData(dataPaths[archive] + dset + "/" + dset + ".arff");
                } else if (isUCR) {
                    train = ClassifierTools.loadData(dataPaths[archive] + dset + "/" + dset + "_TRAIN.arff");
                    test = ClassifierTools.loadData(dataPaths[archive] + dset + "/" + dset + "_TEST.arff");
                }
                for (int fold = 0; fold < numFolds; ++fold) {
                    String predictions = writePath + ensembleID + "/Predictions/" + dset;
                    File f = new File(predictions);
                    if (!f.exists()) {
                        f.mkdirs();
                    }
                    // Skip folds whose results already exist and validate.
                    if (CollateResults.validateSingleFoldFile(predictions + "/testFold" + fold + ".csv")) {
                        continue;
                    }
                    Instances[] data;
                    if (isUCI) {
                        data = InstanceTools.resampleInstances(all, fold, 0.5);
                    } else if (isUCR) {
                        data = InstanceTools.resampleTrainAndTestInstances(train, test, fold);
                    } else {
                        throw new IllegalArgumentException("Unsupported data header '" + header
                                + "' for dataset " + dset + ": expected \"UCI\" or a header containing \"UCR\"");
                    }
                    CAWPE c = (CAWPE) ensembleClass.getConstructor().newInstance();
                    c.setClassifiers(null, baseClassifiers, null);
                    c.setBuildIndividualsFromResultsFiles(true);
                    c.setResultsFileLocationParameters(writePath, dset, fold);
                    c.setRandSeed(fold);
                    c.setPerformCV(true);
                    Experiments.singleClassifierAndFoldTrainTestSplit(data[0], data[1], c, fold, predictions);
                }
            }
        }
    }

    /**
     * Command-line entry point: reproduces the CAWPE paper's Figure 2 results.
     *
     * @param args ignored
     * @throws Exception propagated from the experiment pipeline
     */
    public static void main(String[] args) throws Exception {
        buildCAWPEPaper_AllResultsForFigure2();
    }
}

