/*
 * Decompiled with CFR 0.152.
 */
package vector_classifiers;

import development.CollateResults;
import development.DataSets;
import development.Experiments;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import timeseriesweka.classifiers.ensembles.EnsembleModule;
import timeseriesweka.classifiers.ensembles.voting.MajorityVote;
import timeseriesweka.classifiers.ensembles.weightings.EqualWeighting;
import utilities.ClassifierResults;
import utilities.ClassifierTools;
import utilities.GenericTools;
import utilities.InstanceTools;
import vector_classifiers.CAWPE;
import weka.core.Instances;

public class EnsembleSelection
extends CAWPE {
    /** Number of random bags of modules to run greedy selection on; set to 10 at build time if null. */
    Integer numBags = null;
    /** Fraction of all modules sampled (without replacement) into each bag; set to 0.5 at build time if null. */
    Double propOfModelsInEachBag = null;
    /** Number of a bag's best modules (by train accuracy) that seed its sub-ensemble; set to 1 at build time if null. */
    Integer numOfTopModelsToInitialiseBagWith = null;
    /** Upper bound on the number of (possibly repeated) members of a single sub-ensemble. */
    final int MAX_SUBENSEMBLE_SIZE = 100;
    /** Randomness used to sample bags; reseeded by {@link #setRandSeed(int)}. */
    Random rng;
    public static String[] CAWPE_basic = new String[]{"NN", "SVML", "C4.5", "Logistic", "MLP"};
    public static String[] CAWPE_bigClassifierList = new String[]{"RotFDefault", "RandF", "SVMQ", "NN", "SVML", "C4.5", "NB", "bayesNet", "DaggingDefault", "MultiBoostABDefault", "AdaBoostM1Default", "BaggingDefault", "LogitBoostDefault", "DecorateDefault", "ENDDefault", "RandomCommitteeDefault", "Logistic", "MLP", "DNN", "1NN", "DecisionTable", "REPTree"};

    /**
     * Creates the classifier with majority voting, equal weighting and a random generator seeded with 0.
     */
    public EnsembleSelection() {
        this.ensembleIdentifier = "EnsembleSelection";
        this.votingScheme = new MajorityVote();
        this.weightingScheme = new EqualWeighting();
        this.rng = new Random(0L);
    }

    /** @return the configured number of bags, or null if the build-time default should be used */
    public Integer getNumBags() {
        return this.numBags;
    }

    /** @param numBags number of bags to build, or null to use the build-time default */
    public void setNumBags(Integer numBags) {
        this.numBags = numBags;
    }

    /** @return the configured fraction of modules per bag, or null if the build-time default should be used */
    public Double getPropOfModelsInEachBag() {
        return this.propOfModelsInEachBag;
    }

    /** @param propOfModelsInEachBag fraction of all modules to sample into each bag, or null for the default */
    public void setPropOfModelsInEachBag(Double propOfModelsInEachBag) {
        this.propOfModelsInEachBag = propOfModelsInEachBag;
    }

    /** @return the configured number of seed models per bag, or null if the build-time default should be used */
    public Integer getNumOfTopModelsToInitialiseBagWith() {
        return this.numOfTopModelsToInitialiseBagWith;
    }

    /** @param numOfTopModelsToInitialiseBagWith number of a bag's best modules used to seed it, or null for the default */
    public void setNumOfTopModelsToInitialiseBagWith(Integer numOfTopModelsToInitialiseBagWith) {
        this.numOfTopModelsToInitialiseBagWith = numOfTopModelsToInitialiseBagWith;
    }

    /**
     * Sets the seed in the parent and also reseeds the generator used for bag sampling.
     */
    @Override
    public void setRandSeed(int seed) {
        super.setRandSeed(seed);
        this.rng = new Random(seed);
    }

    /**
     * Builds the ensemble from the modules' train results.
     *
     * <p>For each of {@code numBags} bags, a random subset of the modules is drawn. The bag's best
     * modules by train accuracy seed a sub-ensemble. Modules from the bag are then greedily added (with
     * replacement) while doing so strictly improves the train accuracy of the averaged class
     * distributions. Each module's {@code priorWeight} is finally set to the number of times it was
     * selected across all bags. The ensemble's train results are the average of the sub-ensembles' distributions.
     *
     * @param data the training data (passed through the transform first, if one is set)
     * @throws Exception if the results-file settings are inconsistent, the parameters are invalid,
     *     there are no modules, or building/combining module results fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        printlnDebug("**EnsembleSelection TRAIN**");
        if (resultsFilesParametersInitialised) {
            if (readResultsFilesDirectories.length > 1 && readResultsFilesDirectories.length != modules.length) {
                throw new Exception("EnsembleSelection.buildClassifier: more than one results path given, but number given does not align with the number of classifiers/modules.");
            }
            if (writeResultsFilesDirectory == null) {
                writeResultsFilesDirectory = readResultsFilesDirectories[0];
            }
        }
        long startTime = System.currentTimeMillis();
        train = transform == null ? new Instances(data) : transform.process(data);
        numTrainInsts = train.numInstances();
        numClasses = train.numClasses();
        numAttributes = train.numAttributes();
        initialiseModules();
        weightingScheme.defineWeightings(modules, numClasses);
        votingScheme.trainVotingScheme(modules, numClasses);

        if (numBags == null) {
            numBags = 10;
        }
        if (propOfModelsInEachBag == null) {
            propOfModelsInEachBag = 0.5;
        }
        if (numOfTopModelsToInitialiseBagWith == null) {
            numOfTopModelsToInitialiseBagWith = 1;
        }
        if (numBags < 1) {
            throw new IllegalArgumentException("EnsembleSelection.buildClassifier: numBags must be at least 1, was " + numBags);
        }
        if (modules.length == 0) {
            throw new IllegalStateException("EnsembleSelection.buildClassifier: there are no modules to select from");
        }

        // At least one model per bag, and never more than there are modules.
        int numModelsInEachBag = Math.min(modules.length,
                Math.max(1, (int) Math.round(propOfModelsInEachBag * (double) modules.length)));

        List<List<EnsembleModule>> subensembles = new ArrayList<List<EnsembleModule>>(numBags);
        ClassifierResults globalEnsembleResults = null;
        for (int bagID = 0; bagID < numBags; ++bagID) {
            List<EnsembleModule> subensemble = new ArrayList<EnsembleModule>();
            ClassifierResults subEnsembleResults = selectSubensemble(sample(modules, numModelsInEachBag), subensemble);
            subensembles.add(subensemble);
            // bagID equals the number of bags already averaged into globalEnsembleResults
            globalEnsembleResults = globalEnsembleResults == null
                    ? subEnsembleResults
                    : combinePredictions(globalEnsembleResults, bagID, subEnsembleResults);
        }

        countSelectionsAsPriorWeights(subensembles);

        ensembleTrainResults = globalEnsembleResults;
        ensembleTrainResults.setName("EnsembleSelection");
        ensembleTrainResults.buildTime = System.currentTimeMillis() - startTime;
        if (writeEnsembleTrainingFile) {
            writeEnsembleCVResults(train);
        }
        testInstCounter = 0;
    }

    /**
     * Greedily selects one sub-ensemble, with replacement, from a bag of candidate modules.
     *
     * @param bag candidate modules; not modified (a sorted copy is used)
     * @param subensemble empty list that receives the selected modules, in selection order
     * @return the averaged train results of the selected modules; a fresh object, never a module's own results
     * @throws Exception if combining predictions fails
     */
    private ClassifierResults selectSubensemble(List<EnsembleModule> bag, List<EnsembleModule> subensemble) throws Exception {
        // Ascending by train accuracy, so the bag's best models sit at the end.
        List<EnsembleModule> sorted = new ArrayList<EnsembleModule>(bag);
        Collections.sort(sorted, new SortByTrainAcc());

        ClassifierResults subEnsembleResults = null;
        int numToSeed = Math.min(numOfTopModelsToInitialiseBagWith, sorted.size());
        for (int i = 0; i < numToSeed; ++i) {
            EnsembleModule model = sorted.get(sorted.size() - 1 - i);
            // Combining with null creates a copy, so later setName/buildTime never touch a module's results.
            subEnsembleResults = combinePredictions(subEnsembleResults, subensemble.size(), model.trainResults);
            subensemble.add(model);
        }

        // With nothing seeded, the first greedy step must accept the best candidate unconditionally.
        double accSoFar = subEnsembleResults == null ? Double.NEGATIVE_INFINITY : subEnsembleResults.acc;
        while (subensemble.size() < MAX_SUBENSEMBLE_SIZE) {
            ClassifierResults[] candidateResults = new ClassifierResults[sorted.size()];
            double[] accs = new double[sorted.size()];
            for (int modelID = 0; modelID < sorted.size(); ++modelID) {
                candidateResults[modelID] = combinePredictions(subEnsembleResults, subensemble.size(), sorted.get(modelID).trainResults);
                accs[modelID] = candidateResults[modelID].acc;
            }
            int maxAccInd = (int) GenericTools.indexOfMax(accs);
            if (!(accs[maxAccInd] > accSoFar)) {
                break; // no candidate strictly improves train accuracy
            }
            accSoFar = accs[maxAccInd];
            subEnsembleResults = candidateResults[maxAccInd];
            subensemble.add(sorted.get(maxAccInd));
        }
        return subEnsembleResults;
    }

    /**
     * Sets each module's {@code priorWeight} to the number of times it appears across all sub-ensembles.
     *
     * @param subensembles the selected sub-ensembles
     * @throws IllegalStateException if a selected model is not (by identity) one of this ensemble's modules
     */
    private void countSelectionsAsPriorWeights(List<List<EnsembleModule>> subensembles) {
        for (EnsembleModule module : modules) {
            module.priorWeight = 0.0;
        }
        for (List<EnsembleModule> subensemble : subensembles) {
            for (EnsembleModule model : subensemble) {
                int ind = 0;
                while (ind < modules.length && modules[ind] != model) {
                    ++ind;
                }
                if (ind == modules.length) {
                    throw new IllegalStateException("EnsembleSelection: a selected model is not one of this ensemble's modules");
                }
                modules[ind].priorWeight += 1.0;
            }
        }
    }

    /**
     * Draws {@code numToPick} distinct modules uniformly at random (without replacement) from
     * {@code pool}, using this classifier's {@code rng}.
     *
     * @param pool modules to draw from; not modified
     * @param numToPick requested number of modules; clamped to {@code [0, pool.length]}
     * @return a new mutable list of the drawn modules, in draw order
     */
    public List<EnsembleModule> sample(EnsembleModule[] pool, int numToPick) {
        List<EnsembleModule> remaining = new ArrayList<EnsembleModule>(Arrays.asList(pool));
        int count = Math.max(0, Math.min(numToPick, remaining.size()));
        List<EnsembleModule> picked = new ArrayList<EnsembleModule>(count);
        for (int i = 0; i < count; ++i) {
            picked.add(remaining.remove(this.rng.nextInt(remaining.size())));
        }
        return picked;
    }

    /**
     * Averages a new model's per-instance class distributions into an existing ensemble's.
     *
     * <p>Each new distribution is {@code (ens * size + new) / (size + 1)}, i.e. the unweighted mean over
     * {@code ensembleSizeSoFar + 1} models.
     *
     * @param ensembleSoFarResults results of the ensemble so far, or null for an empty ensemble; then the
     *     result is a copy of {@code newModelResults}'s distributions and {@code ensembleSizeSoFar} is ignored
     * @param ensembleSizeSoFar number of models already averaged into {@code ensembleSoFarResults}
     * @param newModelResults results of the model being added
     * @return a new results object, finalised against the true class values of the reference results
     * @throws IllegalArgumentException if the two results hold different numbers of predictions
     * @throws Exception if finalising the results fails
     */
    public ClassifierResults combinePredictions(ClassifierResults ensembleSoFarResults, int ensembleSizeSoFar, ClassifierResults newModelResults) throws Exception {
        boolean emptySoFar = ensembleSoFarResults == null;
        ClassifierResults reference = emptySoFar ? newModelResults : ensembleSoFarResults;
        int priorSize = emptySoFar ? 0 : ensembleSizeSoFar;
        int numInsts = reference.predictedClassProbabilities.size();
        if (newModelResults.predictedClassProbabilities.size() != numInsts) {
            throw new IllegalArgumentException("EnsembleSelection.combinePredictions: ensemble has " + numInsts
                    + " predictions but new model has " + newModelResults.predictedClassProbabilities.size());
        }

        ClassifierResults newResults = new ClassifierResults(this.numClasses);
        for (int inst = 0; inst < numInsts; ++inst) {
            double[] ensDist = emptySoFar ? null : ensembleSoFarResults.predictedClassProbabilities.get(inst);
            double[] indDist = newModelResults.predictedClassProbabilities.get(inst);
            assert (ensDist == null || ensDist.length == this.numClasses);
            assert (indDist.length == this.numClasses);
            double[] newDist = new double[this.numClasses];
            for (int c = 0; c < this.numClasses; ++c) {
                double weightedPrior = ensDist == null ? 0.0 : ensDist[c] * (double) priorSize;
                newDist[c] = (weightedPrior + indDist[c]) / (double) (priorSize + 1);
            }
            newResults.storeSingleResult(newDist);
        }
        newResults.finaliseResults(reference.getTrueClassVals());
        return newResults;
    }

    public static void main(String[] args) throws Exception {
        EnsembleSelection.tests();
    }

    /**
     * Development harness: runs the classifier over the UCI continuous datasets for 30 folds.
     * Module train/test results are read from, and predictions written to, hard-coded local paths.
     * Folds whose test file already validates are skipped.
     */
    public static void tests() {
        String resPath = "C:/JamesLPHD/HESCA/UCI/UCIResults/";
        int numfolds = 30;
        String[] dsets = DataSets.UCIContinuousFileNames;
        String[] skipDsets = new String[]{};
        String classifier = "EnsembleSelectionAll22Classifiers_Preds";
        for (String dset : dsets) {
            if (Arrays.asList(skipDsets).contains(dset)) continue;
            System.out.println(dset);
            Instances all = ClassifierTools.loadData("C:/UCI Problems/" + dset + "/" + dset + ".arff");
            for (int fold = 0; fold < numfolds; ++fold) {
                String predictions = resPath + classifier + "/Predictions/" + dset;
                File f = new File(predictions);
                if (!f.exists()) {
                    f.mkdirs();
                }
                if (CollateResults.validateSingleFoldFile(predictions + "/testFold" + fold + ".csv")) continue;
                Instances[] data = InstanceTools.resampleInstances(all, fold, 0.5);
                EnsembleSelection c = new EnsembleSelection();
                c.setClassifiers(null, CAWPE_bigClassifierList, null);
                c.setNumOfTopModelsToInitialiseBagWith(2);
                c.setBuildIndividualsFromResultsFiles(true);
                c.setResultsFileLocationParameters(resPath, dset, fold);
                c.setRandSeed(fold);
                c.setPerformCV(true);
                c.setResultsFileWritingLocation(resPath);
                Experiments.singleClassifierAndFoldTrainTestSplit(data[0], data[1], c, fold, predictions);
            }
        }
    }

    /**
     * Orders modules by ascending train accuracy, so after sorting the most accurate module is last.
     */
    public static class SortByTrainAcc
    implements Comparator<EnsembleModule> {
        @Override
        public int compare(EnsembleModule o1, EnsembleModule o2) {
            return Double.compare(o1.trainResults.acc, o2.trainResults.acc);
        }
    }
}

