/*
 * Decompiled with CFR 0.152.
 */
package applications;

import development.DataSets;
import fileIO.OutFile;
import java.io.File;
import java.text.DecimalFormat;
import timeseriesweka.classifiers.BOSS;
import timeseriesweka.classifiers.BagOfPatterns;
import timeseriesweka.classifiers.DD_DTW;
import timeseriesweka.classifiers.DTD_C;
import timeseriesweka.classifiers.ElasticEnsemble;
import timeseriesweka.classifiers.FastShapelets;
import timeseriesweka.classifiers.FlatCote;
import timeseriesweka.classifiers.HiveCote;
import timeseriesweka.classifiers.LPS;
import timeseriesweka.classifiers.LearnShapelets;
import timeseriesweka.classifiers.NN_CID;
import timeseriesweka.classifiers.ParameterSplittable;
import timeseriesweka.classifiers.RISE;
import timeseriesweka.classifiers.SAXVSM;
import timeseriesweka.classifiers.ST_HESCA;
import timeseriesweka.classifiers.TSBF;
import timeseriesweka.classifiers.TSF;
import timeseriesweka.classifiers.ensembles.SaveableEnsemble;
import timeseriesweka.classifiers.ensembles.elastic_ensemble.DTW1NN;
import timeseriesweka.classifiers.ensembles.elastic_ensemble.ED1NN;
import timeseriesweka.classifiers.ensembles.elastic_ensemble.MSM1NN;
import timeseriesweka.classifiers.ensembles.elastic_ensemble.WDTW1NN;
import utilities.ClassifierResults;
import utilities.ClassifierTools;
import utilities.CrossValidator;
import utilities.InstanceTools;
import utilities.SaveParameterInfo;
import utilities.TrainAccuracyEstimate;
import vector_classifiers.CAWPE;
import vector_classifiers.RotationForestLimitedAttributes;
import vector_classifiers.SaveEachParameter;
import vector_classifiers.TunedRandomForest;
import vector_classifiers.TunedRotationForest;
import vector_classifiers.TunedSVM;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.functions.Logistic;
import weka.classifiers.functions.MultilayerPerceptron;
import weka.classifiers.functions.SMO;
import weka.classifiers.functions.supportVector.PolyKernel;
import weka.classifiers.meta.RotationForest;
import weka.classifiers.trees.J48;
import weka.classifiers.trees.RandomForest;
import weka.core.Instances;

public class Football {
    public static int folds = 30;
    static boolean debug = true;
    static boolean checkpoint = false;
    static boolean generateTrainFiles = true;
    static Integer parameterNum = 0;

    public static Classifier setClassifier(String classifier, int fold) {
        AbstractClassifier c = null;
        TunedSVM svm = null;
        switch (classifier) {
            case "ED": {
                c = new ED1NN();
                break;
            }
            case "C45": {
                c = new J48();
                break;
            }
            case "NB": {
                c = new NaiveBayes();
                break;
            }
            case "SVML": {
                c = new SMO();
                PolyKernel p = new PolyKernel();
                p.setExponent(1.0);
                ((SMO)c).setKernel(p);
                break;
            }
            case "SVMQ": {
                c = new SMO();
                PolyKernel p2 = new PolyKernel();
                p2.setExponent(2.0);
                ((SMO)c).setKernel(p2);
                break;
            }
            case "BN": {
                c = new BayesNet();
                break;
            }
            case "MLP": {
                c = new MultilayerPerceptron();
                break;
            }
            case "RandFOOB": {
                c = new TunedRandomForest();
                ((RandomForest)c).setNumTrees(500);
                ((TunedRandomForest)c).tuneParameters(false);
                ((TunedRandomForest)c).setCrossValidate(false);
                ((TunedRandomForest)c).setSeed(fold);
                break;
            }
            case "RandF": {
                c = new TunedRandomForest();
                ((RandomForest)c).setNumTrees(500);
                ((TunedRandomForest)c).tuneParameters(false);
                ((TunedRandomForest)c).setCrossValidate(true);
                ((TunedRandomForest)c).setSeed(fold);
                break;
            }
            case "RotF": {
                c = new TunedRotationForest();
                ((RotationForest)c).setNumIterations(200);
                ((TunedRotationForest)c).tuneParameters(false);
                ((TunedRotationForest)c).setSeed(fold);
                break;
            }
            case "TunedRandF": {
                c = new TunedRandomForest();
                ((TunedRandomForest)c).tuneParameters(true);
                ((TunedRandomForest)c).setCrossValidate(true);
                ((TunedRandomForest)c).setSeed(fold);
                break;
            }
            case "TunedRandFOOB": {
                c = new TunedRandomForest();
                ((TunedRandomForest)c).tuneParameters(true);
                ((TunedRandomForest)c).setCrossValidate(false);
                ((TunedRotationForest)c).setSeed(fold);
                break;
            }
            case "TunedRotF": {
                c = new TunedRotationForest();
                ((TunedRotationForest)c).tuneParameters(true);
                ((TunedRotationForest)c).setSeed(fold);
                break;
            }
            case "TunedSVMRBF": {
                svm = new TunedSVM();
                svm.setKernelType(TunedSVM.KernelType.RBF);
                svm.optimiseParas(true);
                svm.optimiseKernel(false);
                svm.setBuildLogisticModels(true);
                svm.setSeed(fold);
                c = svm;
                break;
            }
            case "TunedSVMQuad": {
                svm = new TunedSVM();
                svm.setKernelType(TunedSVM.KernelType.QUADRATIC);
                svm.optimiseParas(true);
                svm.optimiseKernel(false);
                svm.setBuildLogisticModels(true);
                svm.setSeed(fold);
                c = svm;
                break;
            }
            case "TunedSVMLinear": {
                svm = new TunedSVM();
                svm.setKernelType(TunedSVM.KernelType.LINEAR);
                svm.optimiseParas(true);
                svm.optimiseKernel(false);
                svm.setBuildLogisticModels(true);
                svm.setSeed(fold);
                c = svm;
                break;
            }
            case "TunedSVMKernel": {
                svm = new TunedSVM();
                svm.optimiseParas(true);
                svm.optimiseKernel(true);
                svm.setBuildLogisticModels(true);
                svm.setSeed(fold);
                c = svm;
                break;
            }
            case "RandomRotationForest1": {
                c = new RotationForestLimitedAttributes();
                ((RotationForestLimitedAttributes)c).setNumIterations(200);
                ((RotationForestLimitedAttributes)c).setMaxNumAttributes(100);
                break;
            }
            case "Logistic": {
                c = new Logistic();
                break;
            }
            case "HESCA": {
                c = new CAWPE();
                break;
            }
            case "EE": 
            case "ElasticEnsemble": {
                c = new ElasticEnsemble();
                break;
            }
            case "DTW": {
                c = new DTW1NN();
                ((DTW1NN)c).setWindow(1.0);
                break;
            }
            case "DTWCV": {
                c = new DTW1NN();
                break;
            }
            case "DD_DTW": {
                c = new DD_DTW();
                break;
            }
            case "DTD_C": {
                c = new DTD_C();
                break;
            }
            case "CID_DTW": {
                c = new NN_CID();
                ((NN_CID)c).useDTW();
                break;
            }
            case "MSM": {
                c = new MSM1NN();
                break;
            }
            case "TWE": {
                c = new MSM1NN();
                break;
            }
            case "WDTW": {
                c = new WDTW1NN();
                break;
            }
            case "LearnShapelets": 
            case "LS": {
                c = new LearnShapelets();
                break;
            }
            case "FastShapelets": 
            case "FS": {
                c = new FastShapelets();
                break;
            }
            case "ShapeletTransform": 
            case "ST": 
            case "ST_Ensemble": {
                c = new ST_HESCA();
                ((ST_HESCA)c).setOneDayLimit();
                break;
            }
            case "TSF": {
                c = new TSF();
                break;
            }
            case "RISE": {
                c = new RISE();
                break;
            }
            case "TSBF": {
                c = new TSBF();
                break;
            }
            case "BOP": 
            case "BoP": 
            case "BagOfPatterns": {
                c = new BagOfPatterns();
                break;
            }
            case "BOSS": 
            case "BOSSEnsemble": {
                c = new BOSS();
                break;
            }
            case "SAXVSM": 
            case "SAX": {
                c = new SAXVSM();
                break;
            }
            case "LPS": {
                c = new LPS();
                break;
            }
            case "FlatCOTE": {
                c = new FlatCote();
                break;
            }
            case "HiveCOTE": {
                c = new HiveCote();
                break;
            }
            default: {
                System.out.println("UNKNOWN CLASSIFIER " + classifier);
                System.exit(0);
            }
        }
        return c;
    }

    public static void singleClassifierAndFoldTrainTestSplit(String[] args) throws Exception {
        String predictions;
        String classifier = args[0];
        String problem = args[1];
        int fold = Integer.parseInt(args[2]) - 1;
        Classifier c = Football.setClassifier(classifier, fold);
        Instances train = ClassifierTools.loadData(DataSets.problemPath + problem + "/" + problem + "_TRAIN");
        Instances test = ClassifierTools.loadData(DataSets.problemPath + problem + "/" + problem + "_TEST");
        File f = new File(DataSets.resultsPath + classifier);
        if (!f.exists()) {
            f.mkdir();
        }
        if (!(f = new File(predictions = DataSets.resultsPath + classifier + "/Predictions")).exists()) {
            f.mkdir();
        }
        if (!(f = new File(predictions = predictions + "/" + problem)).exists()) {
            f.mkdir();
        }
        if (!(f = new File(predictions + "/testFold" + fold + ".csv")).exists() || f.length() == 0L) {
            if (parameterNum > 0 && c instanceof ParameterSplittable) {
                checkpoint = false;
                f = new File(predictions + "/fold" + fold + "_" + parameterNum + ".csv");
                if (f.exists() && f.length() > 0L) {
                    return;
                }
            } else if (generateTrainFiles) {
                if (c instanceof TrainAccuracyEstimate) {
                    ((TrainAccuracyEstimate)((Object)c)).writeCVTrainToFile(predictions + "/trainFold" + fold + ".csv");
                } else {
                    int numFolds = train.numInstances() >= 10 ? 10 : train.numInstances();
                    CrossValidator cv = new CrossValidator();
                    cv.setSeed(fold);
                    cv.setNumFolds(numFolds);
                    ClassifierResults res = cv.crossValidateWithStats(c, train);
                    OutFile of = new OutFile(predictions + "/trainFold" + fold + ".csv");
                    of.writeLine(train.relationName() + "," + c.getClass().getName() + ",train");
                    if (c instanceof SaveParameterInfo) {
                        of.writeLine(((SaveParameterInfo)((Object)c)).getParameters());
                    } else {
                        of.writeLine("No Parameter Info");
                    }
                    of.writeLine(res.acc + "");
                    if (res.numInstances() > 0) {
                        double[] trueClassVals = res.getTrueClassVals();
                        double[] predClassVals = res.getPredClassVals();
                        DecimalFormat df = new DecimalFormat("###.###");
                        for (int i = 0; i < train.numInstances(); ++i) {
                            double[] distForInst;
                            if (train.instance(i).classValue() != trueClassVals[i]) {
                                throw new Exception("ERROR in TSF cross validation, class mismatch!");
                            }
                            of.writeString((int)trueClassVals[i] + "," + (int)predClassVals[i] + ",");
                            for (double d : distForInst = res.getDistributionForInstance(i)) {
                                of.writeString("," + df.format(d));
                            }
                            if (i >= train.numInstances() - 1) continue;
                            of.writeString("\n");
                        }
                    }
                }
            }
            double acc = Football.singleClassifierAndFoldTrainTestSplit(train, test, c, fold, predictions);
            System.out.println(classifier + "," + problem + "," + fold + "," + acc);
        }
    }

    public static double singleClassifierAndFoldTrainTestSplit(Instances train, Instances test, Classifier c, int fold, String resultsPath) {
        Instances[] data = InstanceTools.resampleTrainAndTestInstances(train, test, fold);
        double acc = 0.0;
        String testFoldPath = "/testFold" + fold + ".csv";
        if (parameterNum > 0 && c instanceof ParameterSplittable) {
            checkpoint = false;
            ((ParameterSplittable)((Object)c)).setParametersFromIndex(parameterNum);
            testFoldPath = "/fold" + fold + "_" + parameterNum + ".csv";
        } else {
            if (c instanceof SaveableEnsemble) {
                ((SaveableEnsemble)((Object)c)).saveResults(resultsPath + "/internalCV_" + fold + ".csv", resultsPath + "/internalTestPreds_" + fold + ".csv");
            }
            if (checkpoint && c instanceof SaveEachParameter) {
                ((SaveEachParameter)((Object)c)).setPathToSaveParameters(resultsPath + "/fold" + fold + "_");
            }
        }
        try {
            c.buildClassifier(data[0]);
            if (debug && c instanceof RandomForest) {
                System.out.println(" Number of features in MAIN=" + ((RandomForest)c).getNumFeatures());
            }
            StringBuilder str = new StringBuilder();
            DecimalFormat df = new DecimalFormat("##.######");
            for (int j = 0; j < data[1].numInstances(); ++j) {
                int act = (int)data[1].instance(j).classValue();
                data[1].instance(j).setClassMissing();
                double[] probs = c.distributionForInstance(data[1].instance(j));
                int pred = 0;
                for (int i = 1; i < probs.length; ++i) {
                    if (!(probs[i] > probs[pred])) continue;
                    pred = i;
                }
                if (act == pred) {
                    acc += 1.0;
                }
                str.append(act);
                str.append(",");
                str.append(pred);
                str.append(",");
                for (double d : probs) {
                    str.append(",");
                    str.append(df.format(d));
                }
                if (j >= data[1].numInstances() - 1) continue;
                str.append("\n");
            }
            acc /= (double)data[1].numInstances();
            OutFile p = new OutFile(resultsPath + testFoldPath);
            p.writeLine(train.relationName() + "," + c.getClass().getName() + ",test");
            if (c instanceof SaveParameterInfo) {
                p.writeLine(((SaveParameterInfo)((Object)c)).getParameters());
            } else {
                p.writeLine("No parameter info");
            }
            p.writeLine(acc + "");
            p.writeString(str.toString());
        }
        catch (Exception e) {
            System.out.println(" Error =" + e + " in method simpleExperiment" + e);
            e.printStackTrace();
            System.out.println(" TRAIN " + train.relationName() + " has " + train.numAttributes() + " attributes and " + train.numInstances() + " instances");
            System.out.println(" TEST " + test.relationName() + " has " + test.numAttributes() + " attributes" + test.numInstances() + " instances");
            System.exit(0);
        }
        return acc;
    }

    public static void main(String[] args) throws Exception {
        double[] cd;
        Instances data = ClassifierTools.loadData("C:\\Users\\ajb\\Dropbox\\Temp\\FootballPlayer");
        for (double d : cd = InstanceTools.findClassDistributions(data)) {
            System.out.println(d);
        }
        RotationForest rf = new RotationForest();
        rf.setNumIterations(500);
        double[][] a = ClassifierTools.crossValidationWithStats(rf, data, 10);
        System.out.println("ROTF ACC = " + a[0][0]);
        TunedSVM svm = new TunedSVM();
        svm.setKernelType(TunedSVM.KernelType.QUADRATIC);
        svm.optimiseParas(true);
        svm.optimiseKernel(false);
        svm.setBuildLogisticModels(true);
        a = ClassifierTools.crossValidationWithStats(svm, data, 10);
        System.out.println("SVM ACC = " + a[0][0]);
    }
}

