/*
 * Decompiled with CFR 0.152.
 */
package vector_classifiers;

import development.CollateResults;
import fileIO.OutFile;
import java.io.File;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import timeseriesweka.classifiers.ParameterSplittable;
import utilities.ClassifierResults;
import utilities.ClassifierTools;
import utilities.CrossValidator;
import utilities.InstanceTools;
import utilities.SaveParameterInfo;
import utilities.TrainAccuracyEstimate;
import vector_classifiers.SaveEachParameter;
import vector_classifiers.TunedSVM;
import weka.classifiers.meta.RotationForest;
import weka.core.Instances;

/**
 * Rotation forest whose three main parameters (attribute group size, removed
 * percentage and number of trees) can be selected by a cross-validated grid
 * search on the training data.
 *
 * <p>Depending on the flags set before {@link #buildClassifier(Instances)} is
 * called, the classifier either tunes its parameters, estimates train accuracy
 * by cross-validation with the current parameters, or simply builds the forest.
 * When a results path has been set, the result of every parameter combination
 * is written to {@code resultsPath + index + ".csv"} (index starts at 1) so an
 * interrupted search can skip combinations that already have a valid file.
 */
public class TunedRotationForest
extends RotationForest
implements SaveParameterInfo,
TrainAccuracyEstimate,
SaveEachParameter,
ParameterSplittable {
    /** If true, {@link #buildClassifier(Instances)} runs the grid search first. */
    protected boolean tuneParameters = true;
    /** Candidate group sizes (applied as both min and max group). */
    protected int[] paraSpace1;
    /** Candidate removed percentages. */
    protected int[] paraSpace2;
    /** Candidate numbers of trees (iterations). */
    protected int[] paraSpace3;
    /** Selected parameters: [group size, removed percentage, number of trees]. */
    protected int[] paras;
    /** File the train results are written to after building; empty means "do not write". */
    protected String trainPath = "";
    protected boolean debug = false;
    /** If true (and not tuning), train accuracy is estimated by cross-validation. */
    protected boolean findTrainAcc = true;
    protected int seed;
    protected Random rng;
    /** Cross-validation accuracy of every parameter combination evaluated so far. */
    protected ArrayList<Double> accuracy;
    /** Not initialised or used anywhere in this class. */
    protected ArrayList<Double> buildTimes;
    protected ClassifierResults res = new ClassifierResults();
    /** Sum of the build times read back from the per-parameter result files. */
    protected long combinedBuildTime;
    protected static int MAX_FOLDS = 10;
    /** Prefix of the per-parameter result files. */
    protected String resultsPath;
    protected boolean saveEachParaAcc = false;
    /** Maximum number of candidate values per parameter in the standard search space. */
    private static final int MAX_PER_PARA = 10;
    /** Never set to true inside this class, so the partial-search branch is currently unused. */
    private boolean buildFromPartial = false;

    /** Creates a forest with 200 iterations, seed field 0 and an unseeded random generator. */
    public TunedRotationForest() {
        this.setNumIterations(200);
        this.rng = new Random();
        this.seed = 0;
        this.accuracy = new ArrayList<Double>();
    }

    /**
     * Describes the current state as a comma separated string: build time, CV
     * accuracy, the three forest parameters, then every accuracy recorded
     * during tuning.
     *
     * @return the comma separated parameter description
     */
    @Override
    public String getParameters() {
        StringBuilder result = new StringBuilder();
        result.append("BuildTime,").append(this.res.buildTime)
              .append(",CVAcc,").append(this.res.acc)
              .append(",RemovePercent,").append(this.getRemovedPercentage())
              .append(",NumFeatures,").append(this.getMaxGroup())
              .append(",numTrees,").append(this.getNumIterations());
        for (double acc : this.accuracy) {
            result.append(",").append(acc);
        }
        return result.toString();
    }

    /** Switches the parameter search on or off. */
    @Override
    public void setParamSearch(boolean b) {
        this.tuneParameters = b;
    }

    /**
     * Sets the path prefix for the per-parameter result files and enables
     * saving them.
     *
     * @param r path prefix; the combination index and ".csv" are appended
     */
    @Override
    public void setPathToSaveParameters(String r) {
        this.resultsPath = r;
        this.setSaveEachParaAcc(true);
    }

    /** Enables or disables writing one result file per parameter combination. */
    @Override
    public void setSaveEachParaAcc(boolean b) {
        this.saveEachParaAcc = b;
    }

    /** Returns the same string as {@link #getParameters()}. */
    @Override
    public String getParas() {
        return this.getParameters();
    }

    /** Returns the accuracy stored in the current train results. */
    @Override
    public double getAcc() {
        return this.res.acc;
    }

    /**
     * Fixes the three parameters from a single index and disables tuning.
     *
     * <p>The index is decoded as three base-{@code MAX_PER_PARA} digits: the
     * most significant selects the group size ({@code 3 + digit}), the middle
     * one the removed percentage ({@code (100 / MAX_PER_PARA) * digit}) and the
     * least significant the number of trees (10 for digit 0, otherwise
     * {@code 50 * digit}).
     *
     * @param x parameter combination index
     */
    @Override
    public void setParametersFromIndex(int x) {
        this.tuneParameters = false;
        this.paras = new int[3];
        int perGroupSize = MAX_PER_PARA * MAX_PER_PARA;
        int numPerGroupIndex = x / perGroupSize;
        // Middle digit. The previous expression (x % MAX_PER_PARA / MAX_PER_PARA)
        // always evaluated to 0, so the removed percentage could never vary.
        int removePercentIndex = (x % perGroupSize) / MAX_PER_PARA;
        int numTreesIndex = x % MAX_PER_PARA;
        this.paras[0] = 3 + numPerGroupIndex;
        this.paras[1] = 100 / MAX_PER_PARA * removePercentIndex;
        this.paras[2] = numTreesIndex == 0 ? 10 : 50 * numTreesIndex;
        this.setMaxGroup(this.paras[0]);
        this.setMinGroup(this.paras[0]);
        this.setRemovedPercentage(this.paras[1]);
        this.setNumIterations(this.paras[2]);
    }

    /**
     * Sets the seed on the underlying forest and reseeds the local random
     * generator with the same value.
     */
    @Override
    public void setSeed(int s) {
        super.setSeed(s);
        this.seed = s;
        this.rng = new Random(this.seed);
    }

    /** Enables or disables debug printing. */
    public void debug(boolean b) {
        this.debug = b;
    }

    /** Turns off train accuracy estimation, parameter tuning and debug output. */
    public void justBuildTheClassifier() {
        this.estimateAccFromTrain(false);
        this.tuneParameters(false);
        this.debug = false;
    }

    /** Sets whether train accuracy is estimated when not tuning. */
    public void estimateAccFromTrain(boolean b) {
        this.findTrainAcc = b;
    }

    /** Sets whether parameters are tuned in {@link #buildClassifier(Instances)}. */
    public void tuneParameters(boolean b) {
        this.tuneParameters = b;
    }

    /**
     * Sets the candidate numbers of trees used by the parameter search.
     *
     * @param d candidate tree counts
     */
    public void setNumTreesRange(int[] d) {
        // The search reads tree counts from paraSpace3; paraSpace2 holds the
        // removed percentages, so storing the range there tuned the wrong parameter.
        this.paraSpace3 = d;
    }

    /**
     * Sets the candidate group sizes used by the parameter search.
     *
     * @param d candidate group sizes
     */
    public void setNumFeaturesRange(int[] d) {
        this.paraSpace1 = d;
    }

    /**
     * Sets the file the train results are written to and enables train
     * accuracy estimation.
     *
     * @param train output file path
     */
    @Override
    public void writeCVTrainToFile(String train) {
        this.trainPath = train;
        this.findTrainAcc = true;
    }

    /** Returns whether train accuracy estimation is enabled. */
    @Override
    public boolean findsTrainAccuracyEstimate() {
        return this.findTrainAcc;
    }

    /** Returns the current train results object. */
    @Override
    public ClassifierResults getTrainResults() {
        return this.res;
    }

    /**
     * Installs the default search space for all three parameters.
     *
     * <p>Group sizes are 3, 4, ... with at most {@code MAX_PER_PARA} values and
     * no more than {@code m} values; removed percentages are
     * {@code 0, 100/MAX_PER_PARA, ...}; tree counts are 50, 100, ...
     *
     * @param m upper bound on the number of group sizes; values below 1 are
     *          treated as 1 so the space is never empty
     */
    protected final void setStandardParaSearchSpace(int m) {
        // Clamp to at least one value: a non-positive m previously caused an
        // array index / negative array size failure.
        int numParas = Math.max(1, Math.min(MAX_PER_PARA, m));
        this.paraSpace1 = new int[numParas];
        for (int i = 0; i < numParas; ++i) {
            this.paraSpace1[i] = 3 + i;
        }
        this.paraSpace2 = new int[MAX_PER_PARA];
        for (int i = 0; i < MAX_PER_PARA; ++i) {
            this.paraSpace2[i] = 100 / MAX_PER_PARA * i;
        }
        this.paraSpace3 = new int[MAX_PER_PARA];
        for (int i = 0; i < this.paraSpace3.length; ++i) {
            this.paraSpace3[i] = 50 * (i + 1);
        }
    }

    /**
     * Fills any search range that is still null with its standard default,
     * leaving ranges that were set explicitly untouched.
     */
    private void ensureParaSearchSpace(int m) {
        if (this.paraSpace1 != null && this.paraSpace2 != null && this.paraSpace3 != null) {
            return;
        }
        int[] userSpace1 = this.paraSpace1;
        int[] userSpace2 = this.paraSpace2;
        int[] userSpace3 = this.paraSpace3;
        this.setStandardParaSearchSpace(m);
        if (userSpace1 != null) {
            this.paraSpace1 = userSpace1;
        }
        if (userSpace2 != null) {
            this.paraSpace2 = userSpace2;
        }
        if (userSpace3 != null) {
            this.paraSpace3 = userSpace3;
        }
    }

    /** Builds the name of the result file for the given 1-based combination index. */
    private String parameterFilePath(int index) {
        return this.resultsPath + index + ".csv";
    }

    /**
     * Updates the list of best candidates with one result: a strictly lower
     * error replaces the list, an equal error is appended as a tie.
     *
     * @return the minimum error after considering this result
     */
    private double recordCandidate(ArrayList<TunedSVM.ResultsHolder> ties, double minErr,
            int p1, int p2, int p3, ClassifierResults results) {
        double err = 1.0 - results.acc;
        if (err < minErr) {
            ties.clear();
            ties.add(new TunedSVM.ResultsHolder(p1, p2, p3, results));
            return err;
        }
        if (err == minErr) {
            ties.add(new TunedSVM.ResultsHolder(p1, p2, p3, results));
        }
        return minErr;
    }

    /**
     * Picks one of the tied best candidates at random and applies its
     * parameters and results to this classifier.
     *
     * @throws IllegalStateException if there is no candidate to choose from
     */
    private void applyBestParameters(ArrayList<TunedSVM.ResultsHolder> ties) {
        if (ties.isEmpty()) {
            throw new IllegalStateException("No usable parameter results to select from (results path: " + this.resultsPath + ")");
        }
        TunedSVM.ResultsHolder best = ties.get(this.rng.nextInt(ties.size()));
        int bestNumAtts = (int)best.x;
        int bestRemovePercent = (int)best.y;
        int bestNumTrees = (int)best.z;
        this.paras[0] = bestNumAtts;
        this.paras[1] = bestRemovePercent;
        this.paras[2] = bestNumTrees;
        this.setNumIterations(bestNumTrees);
        this.setRemovedPercentage(bestRemovePercent);
        this.setMaxGroup(bestNumAtts);
        this.setMinGroup(bestNumAtts);
        this.res = best.res;
    }

    /**
     * Reads every existing per-parameter result file, accumulating the build
     * times and collecting the best candidates.
     *
     * @param ties list that receives the best candidates
     * @return the number of result files that were found and loaded
     */
    private int loadSavedResults(ArrayList<TunedSVM.ResultsHolder> ties) throws Exception {
        this.combinedBuildTime = 0L;
        double minErr = 1.0;
        int index = 0;
        int present = 0;
        for (int p1 : this.paraSpace1) {
            for (int p2 : this.paraSpace2) {
                for (int p3 : this.paraSpace3) {
                    ++index;
                    String path = this.parameterFilePath(index);
                    if (!new File(path).exists()) {
                        continue;
                    }
                    ++present;
                    ClassifierResults loaded = new ClassifierResults();
                    loaded.loadFromFile(path);
                    this.combinedBuildTime += loaded.buildTime;
                    minErr = this.recordCandidate(ties, minErr, p1, p2, p3, loaded);
                }
            }
        }
        return present;
    }

    /**
     * Grid-searches the three parameter ranges with cross-validation on the
     * given data and applies the best combination, breaking ties at random.
     *
     * <p>When per-parameter saving is enabled, each result is written to its
     * own file and combinations with an existing valid file are skipped. The
     * best combination is then chosen from the files, and the files deleted,
     * only if every combination has a non-empty file; otherwise a message is
     * printed and the parameters are left unchanged.
     *
     * @param train training data; a copy is cross-validated
     * @throws Exception if cross-validation or result file handling fails
     */
    public void tuneRotationForest(Instances train) throws Exception {
        this.paras = new int[3];
        int folds = Math.min(MAX_FOLDS, train.numInstances());
        this.setSeed(this.rng.nextInt());
        Instances trainCopy = new Instances(train);
        CrossValidator cv = new CrossValidator();
        cv.setSeed(this.seed);
        cv.setNumFolds(folds);
        cv.buildFolds(trainCopy);

        ArrayList<TunedSVM.ResultsHolder> ties = new ArrayList<TunedSVM.ResultsHolder>();
        double minErr = 1.0;
        int count = 0;
        for (int p1 : this.paraSpace1) {
            for (int p2 : this.paraSpace2) {
                for (int p3 : this.paraSpace3) {
                    // Always advance the index so result names are unique even
                    // when nothing is written to disk.
                    ++count;
                    String path = this.parameterFilePath(count);
                    if (this.saveEachParaAcc) {
                        File existing = new File(path);
                        if (existing.exists()) {
                            if (CollateResults.validateSingleFoldFile(path)) {
                                continue; // already evaluated
                            }
                            // Invalid file: it is overwritten by the fresh result below.
                            System.out.println("Deleted file " + path + " because size =" + existing.length());
                        }
                    }
                    TunedRotationForest candidate = new TunedRotationForest();
                    candidate.tuneParameters(false);
                    candidate.findTrainAcc = false;
                    candidate.setMaxGroup(p1);
                    candidate.setMinGroup(p1);
                    candidate.setRemovedPercentage(p2);
                    candidate.setNumIterations(p3);
                    candidate.paras = new int[3];
                    ClassifierResults tempResults = cv.crossValidateWithStats(candidate, trainCopy);
                    tempResults.setName("RotFPara" + count);
                    tempResults.setParas("NumFeatures," + p1 + ",RemovePercent," + p2 + ",numTrees," + p3);
                    if (this.debug) {
                        double e = 1.0 - tempResults.acc;
                        System.out.println("Group size=" + p1 + ",Remove Prop=" + p2 + ",Trees=" + p3 + ", Acc = " + (1.0 - e));
                    }
                    this.accuracy.add(tempResults.acc);
                    if (this.saveEachParaAcc) {
                        OutFile out = new OutFile(path);
                        out.writeLine(tempResults.writeResultsFileToString());
                        out.closeFile();
                        File written = new File(path);
                        if (written.exists()) {
                            written.setWritable(true, false);
                        }
                    } else {
                        minErr = this.recordCandidate(ties, minErr, p1, p2, p3, tempResults);
                    }
                }
            }
        }

        if (!this.saveEachParaAcc) {
            this.applyBestParameters(ties);
            return;
        }

        // Check every combination's file, not just the last one.
        int total = count;
        int missing = 0;
        for (int index = 1; index <= total; ++index) {
            File f = new File(this.parameterFilePath(index));
            if (!(f.exists() && f.length() > 0L)) {
                ++missing;
            }
        }
        if (missing != 0) {
            System.out.println(this.resultsPath + " error: missing  =" + missing + " parameter values");
            return;
        }

        this.loadSavedResults(ties);
        this.applyBestParameters(ties);
        if (this.debug) {
            System.out.println("Bestnum in group =" + this.paras[0] + "  best num trees =" + this.paras[2] + " best train acc = " + this.res.acc);
        }

        // The search is complete, so the per-parameter files are no longer needed.
        for (int index = 1; index <= total; ++index) {
            String path = this.parameterFilePath(index);
            File f = new File(path);
            if (f.delete()) {
                continue;
            }
            System.out.println("DELETE FAILED " + path);
            f.setReadable(true);
            f.setWritable(true);
            if (!f.delete()) {
                System.out.println("\t DELETE FAILED AGAIN" + path);
            }
        }
    }

    /**
     * Chooses the best parameters from whatever per-parameter result files
     * currently exist, without running any new cross-validation.
     *
     * @throws Exception if no result file exists for the results path
     */
    private void setParasFromPartiallyCompleteSearch() throws Exception {
        this.paras = new int[3];
        ArrayList<TunedSVM.ResultsHolder> ties = new ArrayList<TunedSVM.ResultsHolder>();
        int present = this.loadSavedResults(ties);
        if (present <= 0) {
            throw new Exception("Error, no parameter files for " + this.resultsPath);
        }
        System.out.println("Number of paras = " + present);
        System.out.println("Number of best = " + ties.size());
        this.applyBestParameters(ties);
    }

    /**
     * Builds the forest, optionally tuning parameters or estimating train
     * accuracy first, and writes the train results file if a path was set.
     *
     * @param data training data; copied, and instances with a missing class
     *             are removed from the copy
     * @throws Exception if the data fails the capability test or building fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        long startTime = System.currentTimeMillis();
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        // Folds are derived from the cleaned data so they never exceed the
        // number of instances actually used.
        int folds = Math.min(MAX_FOLDS, data.numInstances());
        super.setSeed(this.seed);
        if (this.buildFromPartial) {
            this.ensureParaSearchSpace(data.numAttributes() - 1);
            this.setParasFromPartiallyCompleteSearch();
        } else if (this.tuneParameters) {
            this.ensureParaSearchSpace(data.numAttributes() - 1);
            this.tuneRotationForest(data);
        } else if (this.findTrainAcc) {
            // Cross-validate a plain forest configured like this one.
            RotationForest t = new RotationForest();
            t.setMaxGroup(this.getMaxGroup());
            t.setMinGroup(this.getMinGroup());
            t.setRemovedPercentage(this.getRemovedPercentage());
            t.setNumIterations(this.getNumIterations());
            t.setSeed(this.seed);
            CrossValidator cv = new CrossValidator();
            cv.setSeed(this.seed);
            cv.setNumFolds(folds);
            cv.buildFolds(data);
            this.res = cv.crossValidateWithStats(t, data);
        }
        super.buildClassifier(data);
        this.res.buildTime = System.currentTimeMillis() - startTime;
        // Compare by content: reference comparison against "" is unreliable.
        if (this.trainPath != null && !this.trainPath.isEmpty()) {
            OutFile f = new OutFile(this.trainPath);
            f.writeLine(data.relationName() + ",TunedRotF,Train");
            f.writeLine(this.getParameters());
            f.writeLine(this.res.acc + "");
            f.writeString(this.res.writeInstancePredictions());
            f.closeFile();
            File x = new File(this.trainPath);
            x.setWritable(true, false);
        }
    }

    /**
     * Ad-hoc experiment: 50 resamples of a dataset loaded from a hard-coded
     * local path, printing train and test accuracies.
     */
    public static void jamesltests() {
        System.out.println("rotftests");
        String dataset = "ItalyPowerDemand";
        Instances train = ClassifierTools.loadData("c:/tsc problems/" + dataset + "/" + dataset + "_TRAIN");
        Instances test = ClassifierTools.loadData("c:/tsc problems/" + dataset + "/" + dataset + "_TEST");
        int rs = 50;
        double[] trainAccs = new double[rs];
        double[] testAccs = new double[rs];
        double trainAcc = 0.0;
        double testAcc = 0.0;
        for (int r = 0; r < rs; ++r) {
            Instances[] data = InstanceTools.resampleTrainAndTestInstances(train, test, r);
            TunedRotationForest rotF = new TunedRotationForest();
            rotF.estimateAccFromTrain(true);
            try {
                rotF.buildClassifier(data[0]);
            }
            catch (Exception ex) {
                Logger.getLogger(TunedRotationForest.class.getName()).log(Level.SEVERE, null, ex);
            }
            trainAccs[r] = rotF.res.acc;
            trainAcc += trainAccs[r];
            testAccs[r] = ClassifierTools.accuracy(data[1], rotF);
            testAcc += testAccs[r];
            System.out.print(".");
        }
        trainAcc /= (double)rs;
        System.out.println("\nacc=" + trainAcc);
        System.out.println(Arrays.toString(trainAccs));
        testAcc /= (double)rs;
        System.out.println("\nacc=" + testAcc);
        System.out.println(Arrays.toString(testAccs));
    }

    /**
     * Scratch entry point. Runs {@link #cheatOnMNIST()} and exits; the code
     * after the exit call is never reached at runtime.
     */
    public static void main(String[] args) {
        TunedRotationForest.cheatOnMNIST();
        System.exit(0);
        try {
            String dset = "balloons";
            Instances all = ClassifierTools.loadData("C:\\Users\\ajb\\Dropbox\\UCI Problems\\" + dset + "\\" + dset);
            Instances[] split = InstanceTools.resampleInstances(all, 1L, 0.5);
            TunedRotationForest rf = new TunedRotationForest();
            rf.debug(true);
            rf.tuneParameters(true);
            rf.buildClassifier(split[0]);
        }
        catch (Exception e) {
            System.out.println("Exception " + e);
            e.printStackTrace();
            System.exit(0);
        }
    }

    /**
     * Ad-hoc experiment: loads MNIST train/test from a hard-coded network path
     * and prints the test accuracy of a plain rotation forest, first with its
     * default settings and then for 50 to 500 trees in steps of 50.
     */
    public static void cheatOnMNIST() {
        Instances train = ClassifierTools.loadData("\\\\cmptscsvr.cmp.uea.ac.uk\\ueatsc\\Data\\LargeProblems\\MNIST\\MNIST_TRAIN");
        Instances test = ClassifierTools.loadData("\\\\cmptscsvr.cmp.uea.ac.uk\\ueatsc\\Data\\LargeProblems\\MNIST\\MNIST_TEST");
        RotationForest rf = new RotationForest();
        System.out.println("Data loaded ......");
        double a = ClassifierTools.singleTrainTestSplitAccuracy(rf, train, test);
        System.out.println("Trees =10 acc = " + a);
        for (int trees = 50; trees <= 500; trees += 50) {
            rf.setNumIterations(trees);
            a = ClassifierTools.singleTrainTestSplitAccuracy(rf, train, test);
            System.out.println("Trees =" + trees + " acc = " + a);
        }
    }
}

