/*
 * Decompiled with CFR 0.152.
 */
package vector_classifiers;

import development.CollateResults;
import fileIO.OutFile;
import java.io.File;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import timeseriesweka.classifiers.ParameterSplittable;
import utilities.ClassifierResults;
import utilities.ClassifierTools;
import utilities.CrossValidator;
import utilities.InstanceTools;
import utilities.SaveParameterInfo;
import utilities.TrainAccuracyEstimate;
import vector_classifiers.SaveEachParameter;
import vector_classifiers.TunedSVM;
import weka.classifiers.meta.Bagging;
import weka.classifiers.trees.RandomForest;
import weka.classifiers.trees.RandomTree;
import weka.core.Instances;
import weka.core.Utils;

/**
 * Random forest whose maximum tree depth, number of features sampled per split and number of
 * trees can be selected by cross-validated grid search before the final forest is built.
 *
 * <p>When {@code tuneParameters} is true, {@link #buildClassifier(Instances)} evaluates every
 * combination in {@code paraSpace1} (max depth) x {@code paraSpace2} (features per split) x
 * {@code paraSpace3} (number of trees), picks the lowest cross-validation error (ties broken at
 * random with {@code rng}) and then builds the final forest. With {@code saveEachParaAcc} set,
 * each grid point's results are written to {@code resultsPath + index + ".csv"} (1-based index
 * in loop order) and collated once every file is present.
 *
 * <p>A training accuracy estimate is produced either by a fresh cross-validation
 * ({@code crossValidate}) or from the out-of-bag predictions of the final forest.
 */
public class TunedRandomForest
extends RandomForest
implements SaveParameterInfo,
TrainAccuracyEstimate,
SaveEachParameter,
ParameterSplittable {
    /** If true, buildClassifier grid-searches the parameters before building the final forest. */
    boolean tuneParameters = true;
    /** Candidate maximum depths (Weka RandomTree uses 0 for unlimited depth). */
    int[] paraSpace1;
    /** Candidate numbers of features sampled at each split. */
    int[] paraSpace2;
    /** Candidate numbers of trees. */
    int[] paraSpace3;
    /** The selected {maxDepth, numFeatures, numTrees}. */
    int[] paras;
    /** Maximum number of grid values for depth and features (reduced for small problems). */
    int maxPerPara = 10;
    /** If non-empty, train predictions are written here at the end of buildClassifier. */
    String trainPath = "";
    int seed;
    Random rng;
    /** Cross-validation accuracy of every grid point evaluated in this process. */
    ArrayList<Double> accuracy;
    boolean crossValidate = true;
    boolean estimateAcc = true;
    private long combinedBuildTime;
    protected String resultsPath;
    protected boolean saveEachParaAcc = false;
    private int numFeaturesInProblem = 0;
    private static final int MAX_FOLDS = 10;
    private ClassifierResults res = new ClassifierResults();
    /**
     * True when m_numFeatures holds the sqrt(#predictors) default chosen by buildClassifier
     * rather than a user/tuning supplied value, so the default is recomputed for new data.
     */
    private boolean numFeaturesIsDefault = false;

    public void setNumFeaturesInProblem(int m) {
        this.numFeaturesInProblem = m;
    }

    /** Sets the number of features sampled per split directly (bypassing Weka's setter). */
    public void setNumFeaturesForEachTree(int m) {
        this.m_numFeatures = m;
        this.numFeaturesIsDefault = false;
    }

    /**
     * Sets the number of features sampled per split; an explicitly set value is respected by
     * buildClassifier instead of being replaced with the sqrt default.
     */
    @Override
    public void setNumFeatures(int newNumFeatures) {
        super.setNumFeatures(newNumFeatures);
        this.numFeaturesIsDefault = false;
    }

    /** Enables/disables cross-validation for the train estimate; enabling also enables estimation. */
    public void setCrossValidate(boolean b) {
        if (b) {
            this.setEstimateAcc(b);
        }
        this.crossValidate = b;
    }

    public void setEstimateAcc(boolean b) {
        this.estimateAcc = b;
    }

    @Override
    public void setPathToSaveParameters(String r) {
        this.resultsPath = r;
        this.setSaveEachParaAcc(true);
    }

    @Override
    public void setSaveEachParaAcc(boolean b) {
        this.saveEachParaAcc = b;
    }

    @Override
    public String getParas() {
        return this.getParameters();
    }

    @Override
    public double getAcc() {
        return this.res.acc;
    }

    /**
     * Fixes the parameters to the grid point with 1-based index {@code x}, switching tuning off.
     * Index order is depth (slowest), features, trees (fastest), matching the enumeration order
     * of {@link #tuneRandomForest(Instances)} and the values of
     * {@link #setStandardParaSearchSpace(int)}.
     *
     * <p>NOTE(review): the trees index wraps at {@code maxPerPara} while the standard trees grid
     * always has 10 entries, and this method does not shrink {@code maxPerPara} for problems with
     * fewer than {@code maxPerPara} features; the mapping agrees with the search loop for the
     * default {@code maxPerPara == 10} and at least 10 features — confirm for other settings.
     *
     * @param x 1-based grid index in [1, maxPerPara^3]
     * @throws UnsupportedOperationException if x is out of range
     * @throws RuntimeException if setNumFeaturesInProblem has not been called
     */
    @Override
    public void setParametersFromIndex(int x) {
        this.tuneParameters = false;
        this.paras = new int[3];
        if (x < 1 || x > this.maxPerPara * this.maxPerPara * this.maxPerPara) {
            throw new UnsupportedOperationException("ERROR parameter index " + x + " out of range for TunedRandomForest");
        }
        if (this.numFeaturesInProblem == 0) {
            throw new RuntimeException("Error in TunedRandomForest in set ParametersFromIndex: we do not know the number of attributes, need to call setNumFeaturesInProblem before this call");
        }
        int zeroBased = x - 1;
        int numLevelsIndex = zeroBased / (this.maxPerPara * this.maxPerPara);
        int numFeaturesIndex = zeroBased / this.maxPerPara % this.maxPerPara;
        // Previously x % maxPerPara, which shifted the trees value by one relative to the loop.
        int numTreesIndex = zeroBased % this.maxPerPara;
        this.paras[0] = maxDepthValue(numLevelsIndex, this.numFeaturesInProblem, this.maxPerPara);
        this.paras[1] = numFeaturesValue(numFeaturesIndex, this.numFeaturesInProblem, this.maxPerPara);
        this.paras[2] = numTreesValue(numTreesIndex);
        this.setMaxDepth(this.paras[0]);
        this.setNumFeaturesForEachTree(this.paras[1]);
        this.setNumTrees(this.paras[2]);
        if (this.m_Debug) {
            System.out.println("Index =" + x + " Num Features =" + this.numFeaturesInProblem + " Max Depth=" + this.paras[0] + " Num Features =" + this.paras[1] + " Num Trees =" + this.paras[2]);
        }
    }

    @Override
    public ClassifierResults getTrainResults() {
        return this.res;
    }

    @Override
    public String getParameters() {
        String result = "BuildTime," + this.res.buildTime + ",CVAcc," + this.res.acc + ",";
        result = result + "MaxDepth," + this.getMaxDepth() + ",NumFeatures," + this.getNumFeatures() + ",NumTrees," + this.getNumTrees();
        return result;
    }

    @Override
    public void setParamSearch(boolean b) {
        this.tuneParameters = b;
    }

    /** Creates a forest of 500 trees, single execution slot, seed 0. */
    public TunedRandomForest() {
        this.m_numTrees = 500;
        this.m_numExecutionSlots = 1;
        this.m_bagger = new EnhancedBagging();
        this.rng = new Random();
        this.seed = 0;
        this.accuracy = new ArrayList<>();
    }

    /** Sets the seed used for bagging, cross-validation folds and tie breaking. */
    @Override
    public void setSeed(int s) {
        super.setSeed(s);
        this.seed = s;
        this.rng = new Random(this.seed);
    }

    public void debug(boolean b) {
        this.m_Debug = b;
    }

    public void tuneParameters(boolean b) {
        this.tuneParameters = b;
    }

    /** Sets the candidate numbers of trees (previously this wrongly replaced the depth grid). */
    public void setNumTreesRange(int[] d) {
        this.paraSpace3 = d;
    }

    /** Sets the candidate numbers of features sampled per split. */
    public void setNumFeaturesRange(int[] d) {
        this.paraSpace2 = d;
    }

    /** Sets the candidate maximum tree depths (0 means unlimited). */
    public void setMaxDepthRange(int[] d) {
        this.paraSpace1 = d;
    }

    @Override
    public void writeCVTrainToFile(String train) {
        this.trainPath = train;
        this.estimateAcc = true;
    }

    @Override
    public boolean findsTrainAccuracyEstimate() {
        return this.estimateAcc;
    }

    /** Depth value for grid position {@code index}: 0 (unlimited), then steps of m/(perPara-1). */
    private static int maxDepthValue(int index, int m, int perPara) {
        return index == 0 ? 0 : index * (m / (perPara - 1));
    }

    /** Features value for grid position {@code index}: sqrt(m), log2(m)+1, then (index-1)*m/perPara. */
    private static int numFeaturesValue(int index, int m, int perPara) {
        if (index == 0) {
            return (int)Math.sqrt(m);
        }
        if (index == 1) {
            return (int)Utils.log2(m) + 1;
        }
        return (index - 1) * m / perPara;
    }

    /** Trees value for grid position {@code index}: 10, then 100*index. */
    private static int numTreesValue(int index) {
        return index == 0 ? 10 : 100 * index;
    }

    /**
     * Builds the default grid for a problem with {@code m} predictors: up to {@code maxPerPara}
     * depths and feature counts (maxPerPara is reduced to m when m is smaller) and 10 tree counts.
     *
     * @param m number of predictor attributes
     * @throws IllegalArgumentException if m is less than 1
     */
    protected final void setStandardParaSearchSpace(int m) {
        if (m < 1) {
            throw new IllegalArgumentException("TunedRandomForest needs at least one predictor attribute, got " + m);
        }
        if (m < this.maxPerPara) {
            this.maxPerPara = m;
        }
        if (this.m_Debug) {
            System.out.println("Number of features =" + m + " max para values =" + this.maxPerPara);
            System.out.println("Setting defaults ....");
        }
        this.paraSpace1 = new int[this.maxPerPara];
        this.paraSpace2 = new int[this.maxPerPara];
        // A single loop avoids the old fixed write to paraSpace2[1], which overflowed when maxPerPara == 1.
        for (int i = 0; i < this.maxPerPara; ++i) {
            this.paraSpace1[i] = maxDepthValue(i, m, this.maxPerPara);
            this.paraSpace2[i] = numFeaturesValue(i, m, this.maxPerPara);
        }
        this.paraSpace3 = new int[10];
        for (int i = 0; i < this.paraSpace3.length; ++i) {
            this.paraSpace3[i] = numTreesValue(i);
        }
        if (this.m_Debug) {
            System.out.print(" m =" + m);
            System.out.print("Para 1 (Num levels) : ");
            for (int v : this.paraSpace1) {
                System.out.print(v + ", ");
            }
            System.out.print("\nPara 2 (Num features) : ");
            for (int v : this.paraSpace2) {
                System.out.print(v + ", ");
            }
            System.out.print("\nPara 3 (Num trees) : ");
            for (int v : this.paraSpace3) {
                System.out.print(v + ", ");
            }
        }
    }

    /**
     * Fills any search space the user has not supplied with the standard grid for
     * {@code numPredictors}, keeping user supplied spaces.
     */
    private void fillMissingParaSearchSpaces(int numPredictors) {
        if (this.paraSpace1 != null && this.paraSpace2 != null && this.paraSpace3 != null) {
            return;
        }
        int[] userDepths = this.paraSpace1;
        int[] userFeatures = this.paraSpace2;
        int[] userTrees = this.paraSpace3;
        this.setStandardParaSearchSpace(numPredictors);
        if (userDepths != null) {
            this.paraSpace1 = userDepths;
        }
        if (userFeatures != null) {
            this.paraSpace2 = userFeatures;
        }
        if (userTrees != null) {
            this.paraSpace3 = userTrees;
        }
    }

    /**
     * Cross-validates every grid combination on {@code train} and adopts the best one (lowest
     * error, random tie break). With {@code saveEachParaAcc}, each combination's results go to
     * {@code resultsPath + index + ".csv"} (existing valid files are reused), and the best is only
     * chosen once all files exist; the files are deleted after collation.
     *
     * @param train training data; copied before fold construction
     * @throws Exception from cross-validation or result file loading
     */
    public void tuneRandomForest(Instances train) throws Exception {
        this.paras = new int[3];
        int folds = Math.min(MAX_FOLDS, train.numInstances());
        double minErr = 1.0;
        this.setSeed(this.rng.nextInt());
        Instances trainCopy = new Instances(train);
        CrossValidator cv = new CrossValidator();
        cv.setSeed(this.seed);
        cv.setNumFolds(folds);
        cv.buildFolds(trainCopy);
        ArrayList<TunedSVM.ResultsHolder> ties = new ArrayList<>();
        int count = 0;
        for (int p1 : this.paraSpace1) {
            for (int p2 : this.paraSpace2) {
                for (int p3 : this.paraSpace3) {
                    // Always advance the index: it is the candidate's seed, name and file number.
                    ++count;
                    String paraFile = this.resultsPath + count + ".csv";
                    if (this.saveEachParaAcc) {
                        File f = new File(paraFile);
                        if (f.exists()) {
                            if (CollateResults.validateSingleFoldFile(paraFile)) {
                                continue;
                            }
                            System.out.println("Deleting file " + paraFile + " because size =" + f.length());
                        }
                    }
                    TunedRandomForest model = new TunedRandomForest();
                    model.setMaxDepth(p1);
                    model.setNumFeatures(p2);
                    model.setNumTrees(p3);
                    model.tuneParameters = false;
                    model.estimateAcc = false;
                    model.setSeed(count);
                    ClassifierResults tempResults = cv.crossValidateWithStats(model, trainCopy);
                    tempResults.setName("RandFPara" + count);
                    tempResults.setParas("maxDepth," + p1 + ",numFeatures," + p2 + ",numTrees," + p3);
                    double e = 1.0 - tempResults.acc;
                    if (this.m_Debug) {
                        System.out.println("Depth=" + p1 + ",Features" + p2 + ",Trees=" + p3 + " Acc = " + (1.0 - e));
                    }
                    this.accuracy.add(tempResults.acc);
                    if (this.saveEachParaAcc) {
                        OutFile temp = new OutFile(paraFile);
                        temp.writeLine(tempResults.writeResultsFileToString());
                        temp.closeFile();
                        File written = new File(paraFile);
                        if (written.exists()) {
                            written.setWritable(true, false);
                        }
                        continue;
                    }
                    if (e < minErr) {
                        minErr = e;
                        ties = new ArrayList<>();
                        ties.add(new TunedSVM.ResultsHolder(p1, p2, p3, tempResults));
                    } else if (e == minErr) {
                        ties.add(new TunedSVM.ResultsHolder(p1, p2, p3, tempResults));
                    }
                }
            }
        }
        minErr = 1.0;
        if (!this.saveEachParaAcc) {
            this.applyBestParameters(ties);
            return;
        }
        // count now equals the grid size; check every index, not just the last one.
        int missing = 0;
        for (int i = 1; i <= count; ++i) {
            File f = new File(this.resultsPath + i + ".csv");
            if (!f.exists() || f.length() <= 0L) {
                ++missing;
            }
        }
        if (missing != 0) {
            System.out.println(this.resultsPath + " error: missing  =" + missing + " parameter values");
            return;
        }
        this.combinedBuildTime = 0L;
        count = 0;
        for (int p1 : this.paraSpace1) {
            for (int p2 : this.paraSpace2) {
                for (int p3 : this.paraSpace3) {
                    String paraFile = this.resultsPath + ++count + ".csv";
                    ClassifierResults tempResults = new ClassifierResults();
                    tempResults.loadFromFile(paraFile);
                    this.combinedBuildTime += tempResults.buildTime;
                    double e = 1.0 - tempResults.acc;
                    if (e < minErr) {
                        minErr = e;
                        ties = new ArrayList<>();
                        ties.add(new TunedSVM.ResultsHolder(p1, p2, p3, tempResults));
                    } else if (e == minErr) {
                        ties.add(new TunedSVM.ResultsHolder(p1, p2, p3, tempResults));
                    }
                    File f = new File(paraFile);
                    if (!f.delete()) {
                        System.out.println("DELETE FAILED " + paraFile);
                    }
                }
            }
        }
        this.applyBestParameters(ties);
    }

    /**
     * Picks one of the tied best grid points at random and installs its parameters and results.
     *
     * @throws IllegalStateException if no grid point was evaluated
     */
    private void applyBestParameters(ArrayList<TunedSVM.ResultsHolder> ties) {
        if (ties.isEmpty()) {
            throw new IllegalStateException("TunedRandomForest: no parameter combination was evaluated, check the parameter search space");
        }
        TunedSVM.ResultsHolder best = ties.get(this.rng.nextInt(ties.size()));
        int bestNumLevels = (int)best.x;
        int bestNumFeatures = (int)best.y;
        int bestNumTrees = (int)best.z;
        this.paras[0] = bestNumLevels;
        this.paras[1] = bestNumFeatures;
        this.paras[2] = bestNumTrees;
        this.setMaxDepth(bestNumLevels);
        this.setNumFeatures(bestNumFeatures);
        this.setNumTrees(bestNumTrees);
        this.res = best.res;
        if (this.m_Debug) {
            System.out.println("Bestnum levels =" + bestNumLevels + " best num features = " + bestNumFeatures + " best num trees =" + bestNumTrees + " best train acc = " + this.res.acc);
        }
    }

    /**
     * Optionally tunes, then builds the forest on {@code data} (copy, missing-class rows removed),
     * optionally estimates train accuracy and writes it to {@code trainPath}.
     *
     * @param data training data
     * @throws Exception if the data is unsupported or any build/evaluation step fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        long startTime = System.currentTimeMillis();
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        int numPredictors = data.numAttributes() - 1;
        int folds = Math.min(MAX_FOLDS, data.numInstances());
        super.setSeed(this.seed);
        if (this.tuneParameters) {
            this.fillMissingParaSearchSpaces(numPredictors);
            this.tuneRandomForest(data);
        } else if (this.m_numFeatures < 1 || this.numFeaturesIsDefault) {
            // Only default when no explicit value was given; previously this overwrote the
            // numFeatures of every tuning candidate and of setParametersFromIndex.
            this.setNumFeatures(Math.max(1, (int)Math.sqrt(numPredictors)));
            this.numFeaturesIsDefault = true;
        }
        this.m_bagger = new EnhancedBagging();
        RandomTree rTree = new RandomTree();
        if (this.m_numFeatures > numPredictors) {
            this.m_numFeatures = numPredictors;
        }
        if (this.m_MaxDepth > numPredictors) {
            this.m_MaxDepth = 0;
        }
        this.m_KValue = this.m_numFeatures;
        rTree.setKValue(this.m_KValue);
        rTree.setMaxDepth(this.getMaxDepth());
        this.m_bagger.setClassifier(rTree);
        this.m_bagger.setSeed(this.seed);
        this.m_bagger.setNumIterations(this.m_numTrees);
        this.m_bagger.setCalcOutOfBag(true);
        this.m_bagger.setNumExecutionSlots(this.m_numExecutionSlots);
        this.m_bagger.buildClassifier(data);
        if (this.estimateAcc) {
            if (this.crossValidate) {
                this.estimateAccuracyByCrossValidation(data, folds);
            } else {
                this.estimateAccuracyOutOfBag(data);
            }
        }
        this.res.buildTime = System.currentTimeMillis() - startTime;
        // Content comparison; the old != "" compared references.
        if (this.trainPath != null && !this.trainPath.isEmpty()) {
            this.writeTrainResults(data);
        }
    }

    /** Replaces res with a cross-validation of a plain RandomForest using the chosen parameters. */
    private void estimateAccuracyByCrossValidation(Instances data, int folds) throws Exception {
        RandomForest t = new RandomForest();
        t.setNumFeatures(this.getNumFeatures());
        t.setNumTrees(this.getNumTrees());
        t.setMaxDepth(this.getMaxDepth());
        t.setSeed(this.seed);
        CrossValidator cv = new CrossValidator();
        cv.setSeed(this.seed);
        cv.setNumFolds(folds);
        cv.buildFolds(data);
        this.res = cv.crossValidateWithStats(t, data);
        if (this.m_Debug) {
            System.out.println("In cross  validate");
            System.out.println(this.getParameters());
        }
    }

    /** Replaces res with the out-of-bag accuracy and OOB class probabilities of the built forest. */
    private void estimateAccuracyOutOfBag(Instances data) throws Exception {
        // Fresh results object so OOB predictions are not appended to tuning or earlier results.
        this.res = new ClassifierResults();
        this.res.acc = 1.0 - this.measureOutOfBagError();
        if (this.m_Debug) {
            System.out.println("BAGGER CLASS = " + this.m_bagger.getClass().getName());
        }
        EnhancedBagging bagger = (EnhancedBagging)this.m_bagger;
        bagger.findOOBProbabilities();
        double[][] oobProbabilities = bagger.OOBProbabilities;
        for (int i = 0; i < data.numInstances(); ++i) {
            this.res.storeSingleResult(data.instance(i).classValue(), oobProbabilities[i]);
        }
    }

    /** Writes header, parameters, accuracy and per-instance predictions to trainPath. */
    private void writeTrainResults(Instances data) throws Exception {
        OutFile f = new OutFile(this.trainPath);
        f.writeLine(data.relationName() + ",TunedRandF,Train");
        f.writeLine(this.getParameters());
        f.writeLine(this.res.acc + "");
        f.writeLine(this.res.writeInstancePredictions());
        f.closeFile();
    }

    /**
     * Grows {@code n} extra trees on {@code data} and merges them, and their out-of-bag
     * statistics, into the existing forest.
     *
     * @param n number of trees to add
     * @param data the data the forest was built on (OOB bookkeeping assumes the same instances)
     * @throws Exception if building the new trees fails
     */
    public void addTrees(int n, Instances data) throws Exception {
        EnhancedBagging current = (EnhancedBagging)this.m_bagger;
        // OOB statistics must describe the current trees before aggregation changes the tree list.
        if (current.OOBProbabilities == null) {
            current.findOOBProbabilities();
        }
        EnhancedBagging newTrees = new EnhancedBagging();
        RandomTree rTree = new RandomTree();
        this.m_KValue = this.m_numFeatures;
        rTree.setKValue(this.m_KValue);
        rTree.setMaxDepth(this.getMaxDepth());
        // Drawn from the seeded rng so that seeded runs are reproducible.
        newTrees.setSeed(this.rng.nextInt());
        newTrees.setClassifier(rTree);
        newTrees.setNumIterations(n);
        newTrees.setCalcOutOfBag(true);
        newTrees.setNumExecutionSlots(this.m_numExecutionSlots);
        newTrees.buildClassifier(data);
        newTrees.findOOBProbabilities();
        this.m_bagger.aggregate(newTrees);
        this.m_bagger.finalizeAggregation();
        this.m_numTrees += n;
        this.m_bagger.setNumIterations(this.m_numTrees);
        current.mergeBaggers(newTrees);
    }

    public double getBaggingPercent() {
        return this.m_bagger.getBagSizePercent();
    }

    /** Recomputes the OOB probabilities and returns the OOB error rate. */
    public double findOOBError() throws Exception {
        ((EnhancedBagging)this.m_bagger).findOOBProbabilities();
        return ((EnhancedBagging)this.m_bagger).findOOBError();
    }

    /** Recomputes and returns the OOB class probabilities, one row per training instance. */
    public double[][] findOOBProbabilities() throws Exception {
        ((EnhancedBagging)this.m_bagger).findOOBProbabilities();
        return ((EnhancedBagging)this.m_bagger).OOBProbabilities;
    }

    /** Returns the last computed OOB class probabilities (may be null if never computed). */
    public double[][] getOBProbabilities() throws Exception {
        return ((EnhancedBagging)this.m_bagger).OOBProbabilities;
    }

    /** Development harness: resamples ItalyPowerDemand 50 times and prints train/test accuracy. */
    public static void jamesltests() {
        System.out.println("ranftestsWITHCHANGES");
        String dataset = "ItalyPowerDemand";
        Instances train = ClassifierTools.loadData("c:/tsc problems/" + dataset + "/" + dataset + "_TRAIN");
        Instances test = ClassifierTools.loadData("c:/tsc problems/" + dataset + "/" + dataset + "_TEST");
        int rs = 50;
        double[] trainAccs = new double[rs];
        double[] testAccs = new double[rs];
        double trainAcc = 0.0;
        double testAcc = 0.0;
        for (int r = 0; r < rs; ++r) {
            Instances[] data = InstanceTools.resampleTrainAndTestInstances(train, test, r);
            TunedRandomForest ranF = new TunedRandomForest();
            ranF.setCrossValidate(true);
            ranF.setEstimateAcc(true);
            try {
                ranF.buildClassifier(data[0]);
            }
            catch (Exception ex) {
                Logger.getLogger(TunedRandomForest.class.getName()).log(Level.SEVERE, null, ex);
            }
            trainAccs[r] = ranF.res.acc;
            trainAcc += trainAccs[r];
            testAccs[r] = ClassifierTools.accuracy(data[1], ranF);
            testAcc += testAccs[r];
            System.out.print(".");
        }
        System.out.println("\nacc=" + (trainAcc /= (double)rs));
        System.out.println(Arrays.toString(trainAccs));
        System.out.println("\nacc=" + (testAcc /= (double)rs));
        System.out.println(Arrays.toString(testAccs));
    }

    /**
     * Development harness. NOTE(review): exits right after printing the default search space for
     * 200 features, so the example build below is never executed.
     */
    public static void main(String[] args) {
        TunedRandomForest randF = new TunedRandomForest();
        randF.m_Debug = true;
        randF.setStandardParaSearchSpace(200);
        System.exit(0);
        DecimalFormat df = new DecimalFormat("##.###");
        try {
            String dset = "balloons";
            Instances all = ClassifierTools.loadData("C:\\Users\\ajb\\Dropbox\\UCI Problems\\" + dset + "\\" + dset);
            Instances[] split = InstanceTools.resampleInstances(all, 1L, 0.5);
            TunedRandomForest rf = new TunedRandomForest();
            rf.debug(true);
            rf.tuneParameters(true);
            rf.buildClassifier(split[0]);
            System.out.println(" bag percent =" + rf.getBaggingPercent() + " OOB error " + rf.measureOutOfBagError());
        }
        catch (Exception e) {
            System.out.println("Exception " + e);
            e.printStackTrace();
            System.exit(0);
        }
    }

    /**
     * Bagging that keeps its training data after building so out-of-bag class probabilities can
     * be computed and merged across separately built baggers.
     */
    protected static class EnhancedBagging
    extends Bagging {
        /** Mean OOB class distribution per training instance (all zeros if never out of bag). */
        double[][] OOBProbabilities;
        /** Number of trees for which each training instance was out of bag. */
        int[] counts;

        protected EnhancedBagging() {
        }

        /** Builds the bagger and then retains {@code data} for OOB calculations. */
        @Override
        public void buildClassifier(Instances data) throws Exception {
            super.buildClassifier(data);
            this.m_data = data;
        }

        /**
         * Merges {@code other}'s OOB statistics (count-weighted average) and in-bag flags into this
         * bagger, then refreshes the OOB error. Both must have been built on the same instances and
         * have computed their OOB probabilities.
         */
        public void mergeBaggers(EnhancedBagging other) {
            int numInstances = this.m_data.numInstances();
            int numClasses = this.m_data.numClasses();
            for (int i = 0; i < numInstances; ++i) {
                int total = this.counts[i] + other.counts[i];
                for (int j = 0; j < numClasses; ++j) {
                    double weighted = this.counts[i] * this.OOBProbabilities[i][j] + other.counts[i] * other.OOBProbabilities[i][j];
                    // Guard: an instance in-bag for every tree of both baggers has no OOB votes.
                    this.OOBProbabilities[i][j] = total == 0 ? 0.0 : weighted / total;
                }
                this.counts[i] = total;
            }
            boolean[][] mergedInBag = new boolean[this.m_inBag.length + other.m_inBag.length][];
            System.arraycopy(this.m_inBag, 0, mergedInBag, 0, this.m_inBag.length);
            System.arraycopy(other.m_inBag, 0, mergedInBag, this.m_inBag.length, other.m_inBag.length);
            this.m_inBag = mergedInBag;
            this.findOOBError();
        }

        /** Averages the class distributions of the trees for which each instance was out of bag. */
        public void findOOBProbabilities() throws Exception {
            int numInstances = this.m_data.numInstances();
            int numClasses = this.m_data.numClasses();
            this.OOBProbabilities = new double[numInstances][numClasses];
            this.counts = new int[numInstances];
            for (int i = 0; i < numInstances; ++i) {
                double[] probs = this.OOBProbabilities[i];
                for (int j = 0; j < this.m_Classifiers.length; ++j) {
                    if (this.m_inBag[j][i]) {
                        continue;
                    }
                    this.counts[i]++;
                    double[] treeProbs = this.m_Classifiers[j].distributionForInstance(this.m_data.instance(i));
                    for (int k = 0; k < numClasses; ++k) {
                        probs[k] += treeProbs[k];
                    }
                }
                // Avoid 0/0 = NaN for instances that were never out of bag.
                if (this.counts[i] > 0) {
                    for (int k = 0; k < numClasses; ++k) {
                        probs[k] /= this.counts[i];
                    }
                }
            }
        }

        /** Computes and stores the OOB error from the argmax of each OOB distribution (lowest index wins ties). */
        public double findOOBError() {
            double correct = 0.0;
            for (int i = 0; i < this.m_data.numInstances(); ++i) {
                double[] probs = this.OOBProbabilities[i];
                int vote = 0;
                for (int j = 1; j < probs.length; ++j) {
                    if (probs[vote] < probs[j]) {
                        vote = j;
                    }
                }
                if (this.m_data.instance(i).classValue() == (double)vote) {
                    correct += 1.0;
                }
            }
            this.m_OutOfBagError = 1.0 - correct / (double)this.m_data.numInstances();
            return this.m_OutOfBagError;
        }
    }
}

