/*
 * Decompiled with CFR 0.152.
 */
package vector_classifiers;

import development.CollateResults;
import fileIO.OutFile;
import java.io.File;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Random;
import timeseriesweka.classifiers.ParameterSplittable;
import utilities.ClassifierResults;
import utilities.ClassifierTools;
import utilities.CrossValidator;
import utilities.InstanceTools;
import utilities.SaveParameterInfo;
import utilities.TrainAccuracyEstimate;
import vector_classifiers.SaveEachParameter;
import weka.classifiers.functions.MultilayerPerceptron;
import weka.core.Instances;

/**
 * A {@link MultilayerPerceptron} whose hidden-layer specification, learning rate, momentum and
 * decay flag can be chosen by an exhaustive grid search scored with cross-validated accuracy on
 * the training data.
 *
 * <p>Grid search runs in {@link #buildClassifier(Instances)} when {@code tuneParameters} is true.
 * When {@code saveEachParaAcc} is set, every grid point's cross-validation result is written to
 * {@code resultsPath + index + ".csv"} (1-based index). The best setting is selected only once
 * all of those files exist; the files are deleted after selection.
 */
public class TunedTwoLayerMLP
        extends MultilayerPerceptron
        implements SaveParameterInfo,
        TrainAccuracyEstimate,
        SaveEachParameter,
        ParameterSplittable {
    protected boolean tuneParameters = true;
    /** Candidate hidden-layer specification strings passed to {@code setHiddenLayers}. */
    protected String[] paraSpace1;
    /** Candidate learning rates. */
    protected double[] paraSpace2;
    /** Candidate momentum values. */
    protected double[] paraSpace3;
    /** Candidate decay flags. */
    protected boolean[] paraSpace4;
    protected String trainPath = "";
    protected boolean debug = false;
    protected boolean findTrainAcc = true;
    protected int seed = 0;
    protected Random rng;
    /** Cross-validated accuracy of every grid point evaluated in this run, in evaluation order. */
    protected ArrayList<Double> accuracy;
    protected ArrayList<Double> buildTimes;
    protected ClassifierResults res = new ClassifierResults();
    protected long combinedBuildTime;
    protected static int MAX_FOLDS = 10;
    protected String resultsPath;
    protected boolean saveEachParaAcc = false;
    private static int MAX_PER_PARA = 10;

    public TunedTwoLayerMLP() {
        this.rng = new Random();
        this.accuracy = new ArrayList<>();
    }

    /**
     * Returns a CSV fragment describing the build time, CV accuracy, chosen parameters, and then
     * the accuracy of every grid point evaluated.
     */
    @Override
    public String getParameters() {
        StringBuilder sb = new StringBuilder();
        sb.append("BuildTime,").append(this.res.buildTime)
                .append(",CVAcc,").append(this.res.acc)
                .append(",Nodes,").append(this.getHiddenLayers())
                .append(",LearningRate,").append(this.getLearningRate())
                .append(",Momentum,").append(this.getMomentum())
                .append(",Decay,").append(this.getDecay());
        for (double d : this.accuracy) {
            sb.append(',').append(d);
        }
        return sb.toString();
    }

    @Override
    public void setParamSearch(boolean b) {
        this.tuneParameters = b;
    }

    /** Sets the prefix for per-parameter result files and enables writing them. */
    @Override
    public void setPathToSaveParameters(String r) {
        this.resultsPath = r;
        this.setSaveEachParaAcc(true);
    }

    @Override
    public void setSaveEachParaAcc(boolean b) {
        this.saveEachParaAcc = b;
    }

    @Override
    public String getParas() {
        return this.getParameters();
    }

    @Override
    public double getAcc() {
        return this.res.acc;
    }

    /**
     * Configures this classifier from a single 1-based parameter index (intended range 1..1000).
     *
     * <p>Blocks of 200 indices select the hidden-layer spec ("0,0", "a,a", "i,i", "o,o", "t").
     * Within a block, the first 100 indices use decay=false and the next 100 decay=true. Within
     * each 100, the last digit selects learning rate {@code 1/2^d} and the tens digit selects
     * momentum {@code d/10}.
     *
     * <p>Side effect: also (re)initialises the {@code paraSpace} arrays.
     *
     * @param x 1-based parameter index
     */
    @Override
    public void setParametersFromIndex(int x) {
        this.initParaSpace();
        String p1;
        if (x <= 200) {
            p1 = "0,0";
        } else if (x <= 400) {
            p1 = "a,a";
        } else if (x <= 600) {
            p1 = "i,i";
        } else if (x <= 800) {
            p1 = "o,o";
        } else {
            p1 = "t";
        }
        // Position within the current 200-index block: 0..99 -> no decay, 100..199 -> decay.
        int posInBlock = (x - 1) % 200;
        boolean p4 = posInBlock >= 100;
        int t = (x - 1) % 100;
        double p2 = 1.0 / Math.pow(2.0, t % 10);
        // Integer division is intentional: the tens digit selects the momentum step.
        double p3 = (t / 10) / 10.0;
        this.setHiddenLayers(p1);
        this.setLearningRate(p2);
        this.setMomentum(p3);
        this.setDecay(p4);
        if (this.debug) {
            System.out.println("input =" + x + " Paras =" + p1 + "," + p2 + "," + p3 + "," + p4);
        }
    }

    /** Sets the seed on the base classifier and reseeds the tuning random source. */
    @Override
    public void setSeed(int s) {
        super.setSeed(s);
        this.seed = s;
        this.rng = new Random(this.seed);
    }

    public void debug(boolean b) {
        this.debug = b;
    }

    /** Disables train-accuracy estimation, parameter tuning and debug output. */
    public void justBuildTheClassifier() {
        this.estimateAccFromTrain(false);
        this.tuneParameters(false);
        this.debug = false;
    }

    public void estimateAccFromTrain(boolean b) {
        this.findTrainAcc = b;
    }

    public void tuneParameters(boolean b) {
        this.tuneParameters = b;
    }

    /** Requests that training CV results be written to {@code train}; enables train-acc estimation. */
    @Override
    public void writeCVTrainToFile(String train) {
        this.trainPath = train;
        this.findTrainAcc = true;
    }

    @Override
    public boolean findsTrainAccuracyEstimate() {
        return this.findTrainAcc;
    }

    @Override
    public ClassifierResults getTrainResults() {
        return this.res;
    }

    /** Fills the four {@code paraSpace} arrays with the standard search grid. */
    private void initParaSpace() {
        this.paraSpace1 = new String[] {"a", "i", "o", "t", "a,a", "i,i", "o,o", "t,t"};
        this.paraSpace2 = new double[8];
        for (int i = 0; i < this.paraSpace2.length; ++i) {
            this.paraSpace2[i] = 1.0 / Math.pow(2.0, i);
        }
        this.paraSpace3 = new double[8];
        for (int i = 0; i < this.paraSpace3.length; ++i) {
            this.paraSpace3[i] = (double) i / 10.0;
        }
        this.paraSpace4 = new boolean[] {true, false};
    }

    /** Number of grid points in the current parameter space. */
    private int numParaCombinations() {
        return this.paraSpace1.length * this.paraSpace2.length
                * this.paraSpace3.length * this.paraSpace4.length;
    }

    /**
     * Installs the standard search grid and prints its size.
     *
     * @param m currently unused
     */
    protected final void setStandardParaSearchSpace(int m) {
        this.initParaSpace();
        System.out.println("Number of parameters for each =" + this.numParaCombinations());
    }

    /**
     * Runs the grid search over the {@code paraSpace} arrays and applies the best setting.
     *
     * <p>Ties on CV error are broken uniformly at random using {@code rng}. When per-parameter
     * saving is enabled, results are written to files. Already-valid files are skipped, so an
     * interrupted run can resume. Selection happens only if no file is missing.
     *
     * @param train training data
     * @throws Exception propagated from cross-validation or result-file loading
     */
    public void tuneMLP(Instances train) throws Exception {
        int folds = MAX_FOLDS;
        if (folds > train.numInstances()) {
            folds = train.numInstances();
        }
        this.setSeed(this.rng.nextInt());
        Instances trainCopy = new Instances(train);
        CrossValidator cv = new CrossValidator();
        cv.setSeed(this.seed);
        cv.setNumFolds(folds);
        cv.buildFolds(trainCopy);
        ArrayList<ResultsHolder> ties = new ArrayList<>();
        double minErr = 1.0;
        int count = 0;
        for (String p1 : this.paraSpace1) {
            for (double p2 : this.paraSpace2) {
                for (double p3 : this.paraSpace3) {
                    for (boolean p4 : this.paraSpace4) {
                        // Always advance the index so result names and file names line up.
                        ++count;
                        if (this.saveEachParaAcc) {
                            String fileName = this.resultsPath + count + ".csv";
                            File f = new File(fileName);
                            if (f.exists()) {
                                if (CollateResults.validateSingleFoldFile(fileName)) {
                                    continue; // already computed in an earlier run
                                }
                                System.out.println("Deleted file " + fileName + " because size =" + f.length());
                            }
                        }
                        TunedTwoLayerMLP model = new TunedTwoLayerMLP();
                        model.tuneParameters(false);
                        model.findTrainAcc = false;
                        model.setHiddenLayers(p1);
                        model.setLearningRate(p2);
                        model.setMomentum(p3);
                        model.setDecay(p4);
                        ClassifierResults tempResults = cv.crossValidateWithStats(model, trainCopy);
                        tempResults.setName("MLPPara" + count);
                        String paras = "HiddenNodes," + p1 + ",LearningRate," + p2 + ",Momentum," + p3 + ",Decay," + p4;
                        tempResults.setParas(paras);
                        if (this.debug) {
                            System.out.println(paras + ", Acc = " + tempResults.acc);
                        }
                        this.accuracy.add(tempResults.acc);
                        if (this.saveEachParaAcc) {
                            OutFile out = new OutFile(this.resultsPath + count + ".csv");
                            out.writeLine(tempResults.writeResultsFileToString());
                            out.closeFile();
                        } else {
                            minErr = recordCandidate(ties, minErr, new ResultsHolder(p1, p2, p3, p4, tempResults));
                        }
                    }
                }
            }
        }
        if (!this.saveEachParaAcc) {
            this.applyBest(ties);
            return;
        }
        int total = this.numParaCombinations();
        int missing = this.countMissingResultFiles(total);
        if (missing != 0) {
            System.out.println(this.resultsPath + " error: missing  =" + missing + " parameter values");
            return;
        }
        // Every grid point has a result file: reload them all and pick the best.
        ties.clear();
        minErr = 1.0;
        this.combinedBuildTime = 0L;
        count = 0;
        for (String p1 : this.paraSpace1) {
            for (double p2 : this.paraSpace2) {
                for (double p3 : this.paraSpace3) {
                    for (boolean p4 : this.paraSpace4) {
                        ClassifierResults loaded = new ClassifierResults();
                        loaded.loadFromFile(this.resultsPath + ++count + ".csv");
                        this.combinedBuildTime += loaded.buildTime;
                        minErr = recordCandidate(ties, minErr, new ResultsHolder(p1, p2, p3, p4, loaded));
                    }
                }
            }
        }
        this.applyBest(ties);
        this.deleteResultFiles(total);
    }

    /**
     * Adds {@code candidate} to {@code ties} if its error matches the current minimum, or
     * replaces the tie list if it is strictly better.
     *
     * @return the (possibly updated) minimum error
     */
    private static double recordCandidate(ArrayList<ResultsHolder> ties, double minErr, ResultsHolder candidate) {
        double e = 1.0 - candidate.res.acc;
        if (e < minErr) {
            ties.clear();
            ties.add(candidate);
            return e;
        }
        if (e == minErr) {
            ties.add(candidate);
        }
        return minErr;
    }

    /** Picks one of the tied best settings at random and applies it to this classifier. */
    private void applyBest(ArrayList<ResultsHolder> ties) {
        ResultsHolder best = ties.get(this.rng.nextInt(ties.size()));
        this.setHiddenLayers(best.nodes);
        this.setLearningRate(best.lRate);
        this.setMomentum(best.mRate);
        this.setDecay(best.decay);
        this.res = best.res;
    }

    /** Counts per-parameter result files 1..{@code total} that are absent or empty. */
    private int countMissingResultFiles(int total) {
        int missing = 0;
        for (int i = 1; i <= total; ++i) {
            File f = new File(this.resultsPath + i + ".csv");
            if (!(f.exists() && f.length() > 0L)) {
                ++missing;
            }
        }
        return missing;
    }

    /** Deletes per-parameter result files 1..{@code total}, retrying once after fixing permissions. */
    private void deleteResultFiles(int total) {
        for (int i = 1; i <= total; ++i) {
            String fileName = this.resultsPath + i + ".csv";
            File f = new File(fileName);
            if (f.delete()) {
                continue;
            }
            System.out.println("DELETE FAILED " + fileName);
            f.setReadable(true);
            f.setWritable(true);
            if (!f.delete()) {
                System.out.println("\t DELETE FAILED AGAIN" + fileName);
            }
        }
    }

    /**
     * Optionally tunes parameters (or estimates train accuracy by CV), then builds the network
     * on {@code data} with instances missing a class value removed.
     *
     * <p>If a train path was set via {@link #writeCVTrainToFile(String)}, the CV results are
     * written there.
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        long startTime = System.currentTimeMillis();
        int folds = MAX_FOLDS;
        if (folds > data.numInstances()) {
            folds = data.numInstances();
        }
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        super.setSeed(this.seed);
        if (this.tuneParameters) {
            if (this.paraSpace1 == null) {
                this.setStandardParaSearchSpace(data.numAttributes() - 1);
            }
            this.tuneMLP(data);
        } else if (this.findTrainAcc) {
            MultilayerPerceptron t = new MultilayerPerceptron();
            t.setHiddenLayers(this.getHiddenLayers());
            t.setLearningRate(this.getLearningRate());
            t.setMomentum(this.getMomentum());
            t.setDecay(this.getDecay());
            CrossValidator cv = new CrossValidator();
            cv.setSeed(this.seed);
            cv.setNumFolds(folds);
            cv.buildFolds(data);
            this.res = cv.crossValidateWithStats(t, data);
        }
        super.buildClassifier(data);
        this.res.buildTime = System.currentTimeMillis() - startTime;
        // Compare by content, not reference identity; "" means "do not write".
        if (this.trainPath != null && !this.trainPath.isEmpty()) {
            OutFile f = new OutFile(this.trainPath);
            f.writeLine(data.relationName() + ",TunedMLP,Train");
            f.writeLine(this.getParameters());
            f.writeLine(this.res.acc + "");
            f.writeString(this.res.writeInstancePredictions());
            f.closeFile();
        }
    }

    /**
     * Developer harness: prints the parameter mapping for indices 1..1000 and exits.
     * The dataset experiment below the exit call is retained but currently never reached.
     */
    public static void main(String[] args) {
        TunedTwoLayerMLP t = new TunedTwoLayerMLP();
        t.debug = true;
        for (int i = 1; i <= 1000; ++i) {
            t.setParametersFromIndex(i);
        }
        System.exit(0);
        try {
            String dset = "balloons";
            Instances all = ClassifierTools.loadData("C:\\Users\\ajb\\Dropbox\\UCI Problems\\" + dset + "\\" + dset);
            Instances[] split = InstanceTools.resampleInstances(all, 1L, 0.5);
            TunedTwoLayerMLP rf = new TunedTwoLayerMLP();
            rf.debug(true);
            rf.tuneParameters(true);
            rf.buildClassifier(split[0]);
        }
        catch (Exception e) {
            System.out.println("Exception " + e);
            e.printStackTrace();
            System.exit(0);
        }
    }

    /** One grid point's parameter values together with its cross-validation results. */
    static class ResultsHolder {
        final String nodes;
        final double lRate;
        final double mRate;
        final boolean decay;
        final ClassifierResults res;

        ResultsHolder(String a, double b, double c, boolean d, ClassifierResults r) {
            this.nodes = a;
            this.lRate = b;
            this.mRate = c; // was hard-coded to 0.0, discarding the tuned momentum
            this.decay = d;
            this.res = r;
        }
    }
}

