/*
 * Decompiled with CFR 0.152.
 */
package utilities;

import fileIO.OutFile;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
import statistics.distributions.NormalDistribution;
import utilities.ClassifierResults;
import utilities.GenericTools;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.functions.SMO;
import weka.classifiers.functions.supportVector.PolyKernel;
import weka.classifiers.lazy.IBk;
import weka.classifiers.lazy.kNN;
import weka.classifiers.meta.RotationForest;
import weka.classifiers.trees.J48;
import weka.classifiers.trees.RandomForest;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffSaver;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

public class ClassifierTools {
    /**
     * Loads an ARFF file and marks its last attribute as the class.
     *
     * <p>A trailing ".arff" extension (any case) is optional. It is stripped and lower-case
     * ".arff" is appended before opening. On an I/O failure the error is printed and the JVM is
     * terminated with {@code System.exit(0)}.
     *
     * @param fullPath path to the ARFF file, with or without the ".arff" extension
     * @return the loaded instances with the class index set to the last attribute
     */
    public static Instances loadData(String fullPath) {
        // regionMatches returns false (rather than throwing) for paths shorter than the suffix.
        if (fullPath.regionMatches(true, fullPath.length() - 5, ".ARFF", 0, 5)) {
            fullPath = fullPath.substring(0, fullPath.length() - 5);
        }
        Instances d = null;
        // try-with-resources: the reader was previously never closed.
        try (FileReader r = new FileReader(fullPath + ".arff")) {
            d = new Instances(r);
            d.setClassIndex(d.numAttributes() - 1);
        }
        catch (IOException e) {
            System.out.println("Unable to load data on path " + fullPath + " Exception thrown =" + e);
            // NOTE(review): exits with status 0 even though this is a failure.
            System.exit(0);
        }
        return d;
    }

    /**
     * Loads an ARFF file and marks its last attribute as the class, propagating any failure.
     *
     * <p>A trailing ".arff" extension (any case) is optional. It is stripped and lower-case
     * ".arff" is appended before opening.
     *
     * @param fullPath path to the ARFF file, with or without the ".arff" extension
     * @return the loaded instances with the class index set to the last attribute
     * @throws Exception if the file cannot be opened or parsed
     */
    public static Instances loadDataThrowable(String fullPath) throws Exception {
        // regionMatches returns false (rather than throwing) for paths shorter than the suffix.
        if (fullPath.regionMatches(true, fullPath.length() - 5, ".ARFF", 0, 5)) {
            fullPath = fullPath.substring(0, fullPath.length() - 5);
        }
        // try-with-resources: the reader was previously never closed.
        try (FileReader r = new FileReader(fullPath + ".arff")) {
            Instances d = new Instances(r);
            d.setClassIndex(d.numAttributes() - 1);
            return d;
        }
    }

    /**
     * Loads an ARFF file and marks its last attribute as the class.
     *
     * @param file the ARFF file to read
     * @return the loaded instances with the class index set to the last attribute
     * @throws IOException if the file cannot be opened or parsed
     */
    public static Instances loadData(File file) throws IOException {
        // try-with-resources: the reader was previously never closed.
        try (FileReader reader = new FileReader(file)) {
            Instances inst = new Instances(reader);
            inst.setClassIndex(inst.numAttributes() - 1);
            return inst;
        }
    }

    /**
     * Writes a dataset to {@code fileName + ".arff"} using Weka's {@link ArffSaver}.
     *
     * <p>Best effort: an {@link IOException} is reported on stdout and otherwise ignored.
     *
     * @param dataSet  instances to save
     * @param fileName output path without the ".arff" extension
     */
    public static void saveDataset(Instances dataSet, String fileName) {
        final ArffSaver arffWriter = new ArffSaver();
        arffWriter.setInstances(dataSet);
        final File target = new File(fileName + ".arff");
        try {
            arffWriter.setFile(target);
            arffWriter.writeBatch();
        } catch (IOException ex) {
            System.out.println("Error saving transformed dataset" + ex);
        }
    }

    /**
     * Computes the fraction of {@code test} instances whose predicted class equals the actual
     * class. The classifier is expected to be already built.
     *
     * <p>If classification of any instance throws, the error is printed and the JVM exits via
     * {@code System.exit(0)}.
     *
     * @param test instances to score
     * @param c    trained classifier
     * @return correct predictions divided by {@code test.numInstances()}
     */
    public static double accuracy(Instances test, Classifier c) {
        final int total = test.numInstances();
        int hits = 0;
        for (int idx = 0; idx < total; idx++) {
            final Instance inst = test.instance(idx);
            try {
                if (c.classifyInstance(inst) == inst.classValue()) {
                    hits++;
                }
            } catch (Exception e) {
                System.out.println(" Error with instance " + idx + " with Classifier " + c.getClass().getName() + " Exception =" + e);
                e.printStackTrace();
                System.exit(0);
            }
        }
        return (double) hits / (double) total;
    }

    /**
     * Builds the default set of untrained base classifiers.
     *
     * <p>For every classifier a display name is appended to {@code names}, in the same order:
     * NN (1-NN), NB, C45 (J48), SVML (SMO, linear poly kernel), SVMQ (SMO, quadratic poly
     * kernel), RandF100 (random forest with 100 trees), RotF30 (rotation forest, default
     * settings).
     *
     * @param names list that receives one display name per classifier (mutated)
     * @return the classifiers, index-aligned with the names appended
     */
    public static Classifier[] setDefaultSingleClassifiers(ArrayList<String> names) {
        ArrayList<Classifier> sc2 = new ArrayList<Classifier>();
        sc2.add(new kNN(1));
        names.add("NN");
        sc2.add(new NaiveBayes());
        names.add("NB");
        sc2.add(new J48());
        names.add("C45");
        // Typed locals: setKernel is defined on SMO, not on AbstractClassifier.
        SMO linearSvm = new SMO();
        PolyKernel linearKernel = new PolyKernel();
        linearKernel.setExponent(1.0);
        linearSvm.setKernel(linearKernel);
        sc2.add(linearSvm);
        names.add("SVML");
        SMO quadraticSvm = new SMO();
        PolyKernel quadraticKernel = new PolyKernel();
        quadraticKernel.setExponent(2.0);
        quadraticSvm.setKernel(quadraticKernel);
        sc2.add(quadraticSvm);
        names.add("SVMQ");
        RandomForest randF = new RandomForest();
        randF.setNumTrees(100);
        sc2.add(randF);
        names.add("RandF100");
        sc2.add(new RotationForest());
        names.add("RotF30");
        return sc2.toArray(new Classifier[0]);
    }

    /**
     * Trains {@code c} on {@code trainData} and returns its class distribution for every test
     * instance.
     *
     * <p>Best effort: if training or prediction throws, the error is printed. The rows that were
     * not filled remain {@code null}.
     *
     * @param trainData training instances
     * @param testData  instances to predict
     * @param c         classifier to (re)build
     * @return one distribution array per test instance, in test order
     */
    public static double[][] predict(Instances trainData, Instances testData, Classifier c) {
        double[][] results = new double[testData.numInstances()][];
        try {
            c.buildClassifier(trainData);
            for (int i = 0; i < testData.numInstances(); ++i) {
                results[i] = c.distributionForInstance(testData.instance(i));
            }
        }
        catch (Exception e) {
            // Previously a copy-pasted "manual cross val" message that also hid the cause.
            System.out.println(" Error in predict =" + e);
        }
        return results;
    }

    /**
     * Runs m-fold cross validation with {@link EvaluationUtils} (seed 10).
     *
     * <p>On any exception the error and the relation name are printed and the JVM exits via
     * {@code System.exit(0)}.
     *
     * @param c       classifier to evaluate
     * @param allData full dataset
     * @param m       number of folds
     * @return {@code [0][i]} actual class and {@code [1][i]} predicted class of the i-th
     *         prediction, in the order returned by {@code getCVPredictions}
     */
    public static double[][] crossValidation(Classifier c, Instances allData, int m) {
        final int size = allData.numInstances();
        final double[] actual = new double[size];
        final double[] predicted = new double[size];
        final double[][] preds = {actual, predicted};
        try {
            EvaluationUtils evaluator = new EvaluationUtils();
            evaluator.setSeed(10);
            Object[] predictions = evaluator.getCVPredictions(c, allData, m).toArray();
            int slot = 0;
            for (Object raw : predictions) {
                NominalPrediction nom = (NominalPrediction) raw;
                predicted[slot] = nom.predicted();
                actual[slot] = nom.actual();
                slot++;
            }
        } catch (Exception e) {
            System.out.println(" Error =" + e + " in method Cross Validate Experiment");
            e.printStackTrace();
            System.out.println(allData.relationName());
            System.exit(0);
        }
        return preds;
    }

    /**
     * Runs m-fold cross validation with {@link EvaluationUtils} and returns the predictions
     * together with summary statistics.
     *
     * <p>Layout of the returned {@code double[2][numInstances + 1]} array:
     * <ul>
     *   <li>{@code preds[0][0]}: overall accuracy (fraction of correct predictions);</li>
     *   <li>{@code preds[1][0]}: population standard deviation of the per-fold accuracies;</li>
     *   <li>{@code preds[0][i + 1]} / {@code preds[1][i + 1]}: actual / predicted class of the
     *       i-th prediction, in the order returned by {@code getCVPredictions}.</li>
     * </ul>
     * On any exception the error and the relation name are printed and the JVM exits via
     * {@code System.exit(0)}.
     *
     * @param c       classifier to evaluate
     * @param allData full dataset
     * @param m       number of folds
     * @return predictions plus summary statistics as described above
     */
    public static double[][] crossValidationWithStats(Classifier c, Instances allData, int m) {
        double[][] preds = new double[2][allData.numInstances() + 1];
        try {
            EvaluationUtils evalU = new EvaluationUtils();
            FastVector f = evalU.getCVPredictions(c, allData, m);
            Object[] p = f.toArray();
            int numPreds = p.length;
            int correct = 0;
            for (int i = 0; i < numPreds; ++i) {
                NominalPrediction nom = (NominalPrediction) p[i];
                preds[1][i + 1] = nom.predicted();
                preds[0][i + 1] = nom.actual();
                if (preds[0][i + 1] == preds[1][i + 1]) {
                    ++correct;
                }
            }
            preds[0][0] = (double) correct / (double) numPreds;

            // Per-fold accuracies over ALL m folds. The previous version counted foldSize+1
            // predictions in the first fold and never added the last fold to the sums.
            // Assumes getCVPredictions appends predictions fold by fold with Instances.testCV
            // sizing (the first numPreds % m folds hold one extra instance). TODO confirm for
            // the Weka version in use.
            int baseSize = numPreds / m;
            int extra = numPreds % m;
            double sum = 0.0;
            double sumsq = 0.0;
            int offset = 1; // column 0 holds the summary statistics
            for (int fold = 0; fold < m; ++fold) {
                int size = baseSize + (fold < extra ? 1 : 0);
                int foldCorrect = 0;
                for (int i = offset; i < offset + size; ++i) {
                    if (preds[0][i] == preds[1][i]) {
                        ++foldCorrect;
                    }
                }
                offset += size;
                double foldAcc = (double) foldCorrect / (double) size;
                sum += foldAcc;
                sumsq += foldAcc * foldAcc;
            }
            double variance = (sumsq - sum * sum / (double) m) / (double) m;
            // Clamp tiny negative values caused by floating-point rounding before sqrt.
            preds[1][0] = Math.sqrt(Math.max(0.0, variance));
        }
        catch (Exception e) {
            System.out.println(" Error =" + e + " in method Cross Validate Experiment");
            e.printStackTrace();
            System.out.println(allData.relationName());
            System.exit(0);
        }
        return preds;
    }

    /**
     * Estimates accuracy by stratified k-fold cross validation.
     *
     * <p>A copy of {@code data} is shuffled with {@code new Random(seed)} and stratified. The
     * classifier is then rebuilt on each training fold and scored on the matching test fold. If
     * any fold fails, the error is printed to stderr and the JVM exits with status 1.
     *
     * @param data  dataset (not modified)
     * @param c     classifier to evaluate
     * @param folds number of folds
     * @param seed  shuffling seed
     * @return total correct predictions divided by {@code data.numInstances()}
     */
    public static double stratifiedCrossValidation(Instances data, Classifier c, int folds, int seed) {
        final Random shuffler = new Random(seed);
        final Instances shuffled = new Instances(data);
        shuffled.randomize(shuffler);
        shuffled.stratify(folds);
        final int total = data.numInstances();
        int hits = 0;
        for (int fold = 0; fold < folds; fold++) {
            final Instances trainFold = shuffled.trainCV(folds, fold);
            final Instances testFold = shuffled.testCV(folds, fold);
            try {
                c.buildClassifier(trainFold);
                for (Instance inst : testFold) {
                    final int predicted = (int) c.classifyInstance(inst);
                    if ((double) predicted == inst.classValue()) {
                        hits++;
                    }
                }
            } catch (Exception e) {
                System.err.println("ERROR BUILDING FOLD " + fold + " for data set " + data.relationName());
                e.printStackTrace();
                System.exit(1);
            }
        }
        return (double) hits / (double) total;
    }

    /**
     * Performs unshuffled, contiguous k-fold cross validation and returns the predicted class
     * distribution for every instance.
     *
     * <p>Folds are consecutive blocks of {@code numInstances / numFolds} instances, and the last
     * fold also takes the leftover instances. Row {@code i} of the result is the distribution
     * for {@code data.instance(i)}.
     *
     * <p>Best effort: on an exception the error is printed and the rows not yet filled stay at
     * zero.
     *
     * @param data     dataset (not shuffled)
     * @param c        classifier, rebuilt once per fold
     * @param numFolds number of folds
     * @return one distribution per instance, in data order
     */
    public static double[][] performManualCrossValidation(Instances data, Classifier c, int numFolds) {
        int n = data.numInstances();
        double[][] results = new double[n][data.numClasses()];
        int interval = n / numFolds;
        int start = 0;
        int testCount = 0;
        try {
            for (int f = 0; f < numFolds; ++f) {
                // The last fold absorbs the n % numFolds leftovers, which were previously never
                // predicted.
                int end = (f == numFolds - 1) ? n : start + interval;
                Instances train = new Instances(data, 0);
                Instances test = new Instances(data, 0);
                for (int i = 0; i < n; ++i) {
                    if (i >= start && i < end) {
                        test.add(data.instance(i));
                    } else {
                        train.add(data.instance(i));
                    }
                }
                // Train on the training fold only. The previous code trained on all of `data`,
                // leaking the test fold into the model.
                c.buildClassifier(train);
                for (int i = 0; i < test.numInstances(); ++i) {
                    results[testCount] = c.distributionForInstance(test.instance(i));
                    ++testCount;
                }
                start = end;
            }
        }
        catch (Exception e) {
            System.out.println(" Error in manual cross val =" + e);
        }
        return results;
    }

    /**
     * Writes one CSV line per instance to {@code path + ".csv"}, in the format
     * {@code index,actualClassValue,predictedClassValue}.
     *
     * <p>Best effort: any exception is printed and writing stops. The output file is closed in
     * all cases.
     *
     * @param model trained classifier
     * @param data  instances to predict
     * @param path  output path without the ".csv" extension
     */
    public static void makePredictions(Classifier model, Instances data, String path) {
        OutFile f1 = new OutFile(path + ".csv");
        try {
            for (int i = 0; i < data.numInstances(); ++i) {
                Instance t = data.instance(i);
                double actual = t.classValue();
                double pred = model.classifyInstance(t);
                f1.writeLine(i + "," + actual + "," + pred);
            }
        }
        catch (Exception e) {
            System.out.println("Exception in makePredictions" + e);
        }
        finally {
            // The output was previously never closed, so buffered lines may not have been
            // flushed. NOTE(review): assumes fileIO.OutFile exposes closeFile(); confirm.
            f1.closeFile();
        }
    }

    /**
     * Builds a set of untrained base classifiers.
     *
     * <p>For every classifier a display name is appended to {@code names}, in the same order:
     * kNN (IBk with k up to 50, k chosen by cross validation), NB, C45 (J48), SVML (SMO, linear
     * poly kernel), SVMQ (SMO, quadratic poly kernel), RandF100 (random forest with 100 trees),
     * RotF30 (rotation forest, default settings).
     *
     * @param names list that receives one display name per classifier (mutated)
     * @return the classifiers, index-aligned with the names appended
     */
    public static Classifier[] setSingleClassifiers(ArrayList<String> names) {
        ArrayList<Classifier> sc2 = new ArrayList<Classifier>();
        IBk k = new IBk(50);
        k.setCrossValidate(true);
        sc2.add(k);
        names.add("kNN");
        sc2.add(new NaiveBayes());
        names.add("NB");
        sc2.add(new J48());
        names.add("C45");
        // Typed locals: setKernel is defined on SMO, not on AbstractClassifier.
        SMO linearSvm = new SMO();
        PolyKernel linearKernel = new PolyKernel();
        linearKernel.setExponent(1.0);
        linearSvm.setKernel(linearKernel);
        sc2.add(linearSvm);
        names.add("SVML");
        SMO quadraticSvm = new SMO();
        PolyKernel quadraticKernel = new PolyKernel();
        quadraticKernel.setExponent(2.0);
        quadraticSvm.setKernel(quadraticKernel);
        sc2.add(quadraticSvm);
        names.add("SVMQ");
        RandomForest randF = new RandomForest();
        randF.setNumTrees(100);
        sc2.add(randF);
        names.add("RandF100");
        sc2.add(new RotationForest());
        names.add("RotF30");
        return sc2.toArray(new Classifier[0]);
    }

    /**
     * Builds {@code c} on {@code train} and returns its accuracy on {@code test}.
     *
     * <p>Predictions and class values are compared after truncation to {@code int}. On any
     * exception the error is printed and the JVM exits via {@code System.exit(0)}.
     *
     * @param c     classifier to build
     * @param train training instances
     * @param test  test instances
     * @return correct predictions divided by {@code test.numInstances()}
     */
    public static double singleTrainTestSplitAccuracy(Classifier c, Instances train, Instances test) {
        double acc = 0.0;
        try {
            c.buildClassifier(train);
            int correct = 0;
            for (Instance ins : test) {
                int pred = (int) c.classifyInstance(ins);
                if (pred == (int) ins.classValue()) {
                    ++correct;
                }
            }
            acc = (double) correct / (double) test.numInstances();
        }
        catch (Exception e) {
            // The message previously appended the exception twice.
            System.out.println(" Error =" + e + " in method singleTrainTestSplitAccuracy");
            e.printStackTrace();
            System.exit(0);
        }
        return acc;
    }

    /**
     * Performs unshuffled, contiguous k-fold cross validation and returns the predicted class
     * distribution for every instance.
     *
     * <p>Fold {@code f} covers indices {@code [f * w, (f + 1) * w)} with
     * {@code w = numInstances / numFolds}, and the last fold extends to the end of the data.
     * Row {@code i} of the result is the distribution for {@code data.instance(i)}.
     *
     * <p>Best effort: on an exception a message is printed and the rows not yet filled stay at
     * zero.
     *
     * @param c        classifier, rebuilt on each training fold
     * @param data     dataset (not shuffled)
     * @param numFolds number of folds
     * @return one distribution per instance, in data order
     */
    public static double[][] crossValidate(Classifier c, Instances data, int numFolds) {
        final int total = data.numInstances();
        final double[][] results = new double[total][data.numClasses()];
        final int foldWidth = total / numFolds;
        int next = 0;
        try {
            for (int fold = 0; fold < numFolds; fold++) {
                final int lo = fold * foldWidth;
                final int hi = (fold == numFolds - 1) ? total : lo + foldWidth;
                final Instances trainPart = new Instances(data, 0);
                final Instances testPart = new Instances(data, 0);
                for (int idx = 0; idx < total; idx++) {
                    final Instances target = (idx >= lo && idx < hi) ? testPart : trainPart;
                    target.add(data.instance(idx));
                }
                c.buildClassifier(trainPart);
                for (Instance held : testPart) {
                    results[next++] = c.distributionForInstance(held);
                }
            }
        } catch (Exception e) {
            System.out.println(" Error in manual cross val");
        }
        return results;
    }

    /**
     * Scores an already-built classifier on {@code test} and packages the outcome as a
     * {@link ClassifierResults}.
     *
     * <p>Each prediction is the index of the largest entry in the classifier's distribution.
     *
     * @param classifier trained classifier
     * @param test       instances to score
     * @return results holding accuracy, true classes, predictions and distributions
     * @throws Exception if the classifier fails to produce a distribution
     */
    public static ClassifierResults constructClassifierResults(Classifier classifier, Instances test) throws Exception {
        final int numTest = test.numInstances();
        final double[] preds = new double[numTest];
        final double[][] dists = new double[numTest][];
        int hits = 0;
        for (int row = 0; row < numTest; row++) {
            final Instance inst = test.get(row);
            final double[] dist = classifier.distributionForInstance(inst);
            dists[row] = dist;
            preds[row] = GenericTools.indexOfMax(dist);
            if (preds[row] == inst.classValue()) {
                hits++;
            }
        }
        final double accuracy = (double) hits / (double) numTest;
        final double[] classVals = test.attributeToDoubleArray(test.classIndex());
        final ClassifierResults results = new ClassifierResults(accuracy, classVals, preds, dists, test.numClasses());
        results.setNumInstances(numTest);
        results.setNumClasses(test.numClasses());
        return results;
    }

    /**
     * Evaluates every classifier in {@code sc}.
     *
     * <ul>
     *   <li>If {@code folds > 1}, train and test are merged, shuffled once with seed 100, and each
     *       classifier is cross-validated on the merged data via {@link #crossValidation}. Fold
     *       statistics are then summarised by {@link ResultsStats#find}.</li>
     *   <li>Otherwise each classifier is built on {@code train} and only its accuracy on
     *       {@code test} is recorded.</li>
     * </ul>
     *
     * @param test  test instances
     * @param train training instances
     * @param folds number of CV folds; values of 1 or less mean a single train/test split
     * @param sc    classifiers to evaluate
     * @return one statistics object per classifier, index-aligned with {@code sc}
     * @throws Exception if building a classifier fails in the train/test-split mode
     */
    public static ResultsStats[] evalClassifiers(Instances test, Instances train, int folds, Classifier[] sc) throws Exception {
        int nosClassifiers = sc.length;
        ResultsStats[] mean = new ResultsStats[nosClassifiers];
        int seed = 100;
        // Loop-invariant: merge and shuffle once instead of once per classifier. Each
        // per-classifier copy was identical because the same seed was used every time.
        Instances full = null;
        if (folds > 1 && nosClassifiers > 0) {
            full = new Instances(train);
            for (int j = 0; j < test.numInstances(); ++j) {
                full.add(test.instance(j));
            }
            full.randomize(new Random(seed));
        }
        for (int i = 0; i < nosClassifiers; ++i) {
            if (folds > 1) {
                double[][] preds = ClassifierTools.crossValidation(sc[i], full, folds);
                // Pass the fold count. The instance count was previously passed here, which made
                // the sd/min/max per-instance rather than per-fold.
                mean[i] = ResultsStats.find(preds, folds);
            } else {
                sc[i].buildClassifier(train);
                mean[i] = new ResultsStats();
                mean[i].accuracy = ClassifierTools.accuracy(test, sc[i]);
            }
        }
        return mean;
    }

    /**
     * Replaces missing values using Weka's unsupervised {@link ReplaceMissingValues} filter.
     *
     * <p>All instances are pushed through the filter, the batch is finished, and one output
     * instance per input is collected into the filter's output format. Progress messages are
     * printed to stdout. On any exception the error is printed and the JVM exits via
     * {@code System.exit(0)}.
     *
     * @param data instances possibly containing missing values
     * @return a new dataset with missing values replaced
     */
    public static Instances estimateMissing(Instances data) {
        final ReplaceMissingValues filter = new ReplaceMissingValues();
        Instances result = null;
        try {
            filter.setInputFormat(data);
            final int count = data.numInstances();
            for (int k = 0; k < count; k++) {
                filter.input(data.instance(k));
            }
            System.out.println(" Instances input");
            System.out.println(" Output format retrieved");
            final boolean hasPending = filter.batchFinished();
            if (hasPending) {
                System.out.println(" batch finished ");
            }
            result = filter.getOutputFormat();
            for (int k = 0; k < count; k++) {
                result.add(filter.output());
            }
        } catch (Exception e) {
            System.out.println("Error in estimateMissing  = " + e.toString());
            result = data;
            System.exit(0);
        }
        return result;
    }

    /**
     * Converts nominal attributes to binary ones using Weka's supervised
     * {@link NominalToBinary} filter.
     *
     * <p>On any exception the error is printed and the JVM exits via {@code System.exit(0)}.
     *
     * @param data instances to convert (a class attribute must be set)
     * @return a new dataset in the filter's output format
     */
    public static Instances makeBinary(Instances data) {
        Instances nd;
        NominalToBinary nb = new NominalToBinary();
        try {
            nb.setInputFormat(data);
            int n = data.numInstances();
            for (int i = 0; i < n; ++i) {
                nb.input(data.instance(i));
            }
            // Per the Weka Filter contract, a filter that needs the whole batch only emits output
            // after batchFinished(). Without this call output() may return null, and the
            // previous code then added null instances (NPE).
            nb.batchFinished();
            nd = nb.getOutputFormat();
            Instance temp;
            while ((temp = nb.output()) != null) {
                nd.add(temp);
            }
        }
        catch (Exception e) {
            System.out.println("Error in NominalToBinary  = " + e.toString());
            nd = data;
            System.exit(0);
        }
        return nd;
    }

    /**
     * Generates a synthetic classification dataset with no signal.
     *
     * <p>It has {@code numAtts} numeric attributes named {@code Rand0..}, each value drawn from
     * {@code NormalDistribution(0, 1)}, followed by a nominal class "Response" with labels
     * {@code "0".."numClasses-1"} chosen uniformly by an unseeded {@link Random}. The relation is
     * named {@code Random<numAtts>_<numCases>_<numClasses>}.
     *
     * @param numAtts    number of numeric attributes
     * @param numCases   number of instances
     * @param numClasses number of class labels
     * @return the generated dataset with the class index set to the last attribute
     */
    public static Instances generateRandomProblem(int numAtts, int numCases, int numClasses) {
        final String relation = "Random" + numAtts + "_" + numCases + "_" + numClasses;
        final ArrayList<Attribute> attributes = new ArrayList<Attribute>(numAtts);
        for (int a = 0; a < numAtts; a++) {
            attributes.add(new Attribute("Rand" + a));
        }
        final ArrayList<String> labels = new ArrayList<String>(numClasses);
        for (int label = 0; label < numClasses; label++) {
            labels.add(Integer.toString(label));
        }
        attributes.add(new Attribute("Response", labels));
        final NormalDistribution gaussian = new NormalDistribution(0.0, 1.0);
        final Random classPicker = new Random();
        final Instances data = new Instances(relation, attributes, numCases);
        data.setClassIndex(numAtts);
        for (int row = 0; row < numCases; row++) {
            final DenseInstance inst = new DenseInstance(data.numAttributes());
            for (int a = 0; a < numAtts; a++) {
                final double value = gaussian.simulate();
                inst.setValue(a, value);
            }
            final double label = classPicker.nextInt(numClasses);
            inst.setValue(numAtts, label);
            data.add(inst);
        }
        return data;
    }

    /**
     * Summary statistics of a cross-validation run computed from paired actual/predicted
     * class arrays.
     */
    public static class ResultsStats {
        /** Overall accuracy: fraction of all predictions equal to the actual class. */
        public double accuracy;
        /** Population standard deviation of the per-fold accuracies. */
        public double sd;
        /** Smallest per-fold count of correct predictions (a count, not a fraction). */
        public double min;
        /** Largest per-fold count of correct predictions (a count, not a fraction). */
        public double max;

        /** Creates empty statistics (all fields 0.0). */
        public ResultsStats() {
            this.accuracy = 0.0;
            this.sd = 0.0;
        }

        /**
         * Creates statistics from predictions.
         *
         * @param preds {@code preds[0]} actual classes, {@code preds[1]} predicted classes
         * @param folds number of folds to split the predictions into
         */
        public ResultsStats(double[][] preds, int folds) {
            this.findCVMeanSD(preds, folds);
        }

        /**
         * Static factory equivalent to {@link #ResultsStats(double[][], int)}.
         *
         * @param preds {@code preds[0]} actual classes, {@code preds[1]} predicted classes
         * @param folds number of folds to split the predictions into
         * @return the computed statistics
         */
        public static ResultsStats find(double[][] preds, int folds) {
            ResultsStats f = new ResultsStats();
            f.findCVMeanSD(preds, folds);
            return f;
        }

        /**
         * Fills in accuracy, sd, min and max from paired predictions.
         *
         * <p>The predictions are split into {@code folds} consecutive blocks. The first
         * {@code folds - 1} blocks hold {@code n / folds} entries each and the last block takes
         * the rest. NOTE(review): this layout may differ from Weka's testCV fold sizes (which
         * give the first {@code n % folds} folds one extra instance). If {@code folds > n}, the
         * empty folds make {@code sd} NaN.
         *
         * @param preds {@code preds[0]} actual classes, {@code preds[1]} predicted classes
         * @param folds number of folds (must be at least 1)
         */
        public void findCVMeanSD(double[][] preds, int folds) {
            double[] actual = preds[0];
            double[] predicted = preds[1];
            int n = actual.length;
            int window = n / folds;
            double[] correct = new double[folds];
            int[] sizes = new int[folds];
            for (int fold = 0; fold < folds; ++fold) {
                int start = fold * window;
                int end = (fold == folds - 1) ? n : start + window;
                sizes[fold] = end - start;
                for (int i = start; i < end; ++i) {
                    if (actual[i] == predicted[i]) {
                        correct[fold] += 1.0;
                    }
                }
            }
            // min and max both start from fold 0. max previously started at 0.0 and skipped
            // fold 0, so a best-scoring first fold was never reported.
            double total = correct[0];
            this.min = correct[0];
            this.max = correct[0];
            for (int fold = 1; fold < folds; ++fold) {
                total += correct[fold];
                this.min = Math.min(this.min, correct[fold]);
                this.max = Math.max(this.max, correct[fold]);
            }
            this.accuracy = total / (double) n;
            double squaredDeviation = 0.0;
            for (int fold = 0; fold < folds; ++fold) {
                double foldAcc = correct[fold] / (double) sizes[fold];
                squaredDeviation += (foldAcc - this.accuracy) * (foldAcc - this.accuracy);
            }
            this.sd = Math.sqrt(squaredDeviation / (double) folds);
        }

        @Override
        public String toString() {
            return "Accuracy = " + this.accuracy + " SD = " + this.sd + " Min = " + this.min + " Max = " + this.max;
        }
    }
}

