/*
 * Decompiled with CFR 0.152.
 */
package utilities;

import fileIO.InFile;
import fileIO.OutFile;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import utilities.DebugPrinting;
import utilities.GenericTools;
import utilities.InstanceTools;

/**
 * Holds the predictions made by a classifier on a set of instances together
 * with summary statistics derived from them (accuracy, balanced accuracy, F1,
 * MCC, negative log likelihood, AUROC, confusion matrix).
 *
 * <p>Predictions can be supplied in bulk (constructors, {@link #addAllResults}),
 * one at a time ({@link #storeSingleResult(double[])}), or read from a results
 * file ({@link #loadFromFile(String)}). Class values are stored as doubles and
 * are truncated to {@code int} whenever they are used as array indices.
 *
 * <p>The private {@code numInstances} field exposed through
 * {@link #getNumInstances()} is only written by {@link #loadFromFile(String)}
 * and {@link #setNumInstances(int)}; {@link #numInstances()} reports the number
 * of predictions actually stored, and the statistics use the stored lists.
 */
public class ClassifierResults
implements DebugPrinting {
    public long buildTime;
    public long memory;
    private int numClasses;
    private int numInstances;
    private String name;
    private String paras;
    public double acc;
    public double balancedAcc;
    public double sensitivity;
    public double specificity;
    public double precision;
    public double recall;
    public double f1;
    public double mcc;
    public double nll;
    public double meanAUROC;
    public double stddev;
    public double[][] confusionMatrix;
    public double[] countPerClass;
    /** Value added in place of log2(p) when the true class was given probability 0. */
    public static double NLL_PENALTY = -6.64;
    public ArrayList<Double> actualClassValues;
    public ArrayList<Double> predictedClassValues;
    public ArrayList<double[]> predictedClassProbabilities;
    private boolean finalised = false;
    private boolean allStatsFound = false;

    /** Creates an empty, non-finalised result set. */
    public ClassifierResults() {
        this.actualClassValues = new ArrayList<>();
        this.predictedClassValues = new ArrayList<>();
        this.predictedClassProbabilities = new ArrayList<>();
        this.finalised = false;
    }

    /**
     * Creates a result set by reading a results file.
     *
     * @param filePathAndName path of the results file
     * @throws FileNotFoundException if the file does not exist or is empty
     */
    public ClassifierResults(String filePathAndName) throws FileNotFoundException {
        this.loadFromFile(filePathAndName);
    }

    /**
     * Creates an empty, non-finalised result set for a known number of classes.
     *
     * @param numClasses number of classes in the problem
     */
    public ClassifierResults(int numClasses) {
        this();
        this.numClasses = numClasses;
    }

    /**
     * Creates a non-finalised result set that only records an accuracy.
     *
     * @param cvacc accuracy to store
     * @param numClasses number of classes in the problem
     */
    public ClassifierResults(double cvacc, int numClasses) {
        this();
        this.acc = cvacc;
        this.numClasses = numClasses;
        this.finalised = false;
    }

    /**
     * Creates a finalised result set from complete prediction arrays. The
     * confusion matrix is built immediately and {@code stddev} is set to -1.
     *
     * @param acc accuracy to store (not recomputed here)
     * @param classVals actual class value per instance
     * @param preds predicted class value per instance
     * @param distsForInsts predicted class distribution per instance
     * @param numClasses number of classes in the problem
     */
    public ClassifierResults(double acc, double[] classVals, double[] preds, double[][] distsForInsts, int numClasses) {
        this.acc = acc;
        this.replacePredictions(classVals, preds, distsForInsts, numClasses);
        this.finalised = true;
    }

    /**
     * As {@link #ClassifierResults(double, double[], double[], double[][], int)}
     * but also records a standard deviation.
     *
     * @param acc accuracy to store (not recomputed here)
     * @param classVals actual class value per instance
     * @param preds predicted class value per instance
     * @param distsForInsts predicted class distribution per instance
     * @param stddev standard deviation to store
     * @param numClasses number of classes in the problem
     */
    public ClassifierResults(double acc, double[] classVals, double[] preds, double[][] distsForInsts, double stddev, int numClasses) {
        this(acc, classVals, preds, distsForInsts, numClasses);
        this.stddev = stddev;
        this.finalised = true;
    }

    public int getNumClasses() {
        return this.numClasses;
    }

    public void setNumClasses(int numClasses) {
        this.numClasses = numClasses;
    }

    /**
     * Returns the value of the {@code numInstances} field, which is only set by
     * {@link #loadFromFile(String)} and {@link #setNumInstances(int)}. Use
     * {@link #numInstances()} for the number of stored predictions.
     */
    public int getNumInstances() {
        return this.numInstances;
    }

    public void setNumInstances(int numInstances) {
        this.numInstances = numInstances;
    }

    public String getName() {
        return this.name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getParas() {
        return this.paras;
    }

    public void setParas(String paras) {
        this.paras = paras;
    }

    /**
     * Releases the per-instance prediction data by nulling the three lists.
     * Previously computed summary fields are left untouched.
     */
    public void cleanPredictionInfo() {
        this.predictedClassProbabilities = null;
        this.predictedClassValues = null;
        this.actualClassValues = null;
    }

    /**
     * Replaces the three prediction lists with the supplied arrays, then
     * rebuilds the confusion matrix and resets {@code stddev} to -1.
     */
    private void replacePredictions(double[] classVals, double[] preds, double[][] distsForInsts, int numClasses) {
        this.actualClassValues = new ArrayList<>();
        this.predictedClassValues = new ArrayList<>();
        this.predictedClassProbabilities = new ArrayList<>();
        for (double prediction : preds) {
            this.predictedClassValues.add(prediction);
        }
        for (double[] distribution : distsForInsts) {
            this.predictedClassProbabilities.add(distribution);
        }
        this.numClasses = numClasses;
        for (double actual : classVals) {
            this.actualClassValues.add(actual);
        }
        this.confusionMatrix = this.buildConfusionMatrix();
        this.stddev = -1.0;
    }

    /**
     * Builds a {@code numClasses x numClasses} matrix where cell
     * {@code [actual][predicted]} counts the instances with that combination.
     */
    private double[][] buildConfusionMatrix() {
        double[][] matrix = new double[this.numClasses][this.numClasses];
        int count = this.predictedClassValues.size();
        for (int i = 0; i < count; ++i) {
            int actual = (int) this.actualClassValues.get(i).doubleValue();
            int predicted = (int) this.predictedClassValues.get(i).doubleValue();
            matrix[actual][predicted] += 1.0;
        }
        return matrix;
    }

    /**
     * Discards any stored predictions and replaces them with the supplied
     * arrays, rebuilding the confusion matrix and setting {@code stddev} to -1.
     * The stored accuracy is not modified.
     *
     * @param classVals actual class value per instance
     * @param preds predicted class value per instance
     * @param distsForInsts predicted class distribution per instance
     * @param numClasses number of classes in the problem
     */
    public void addAllResults(double[] classVals, double[] preds, double[][] distsForInsts, int numClasses) {
        this.replacePredictions(classVals, preds, distsForInsts, numClasses);
    }

    /**
     * Returns the index of the first largest entry of {@code dist}, as a double.
     * Throws {@link ArrayIndexOutOfBoundsException} for an empty array.
     */
    private static double indexOfMax(double[] dist) {
        double best = dist[0];
        double bestIndex = 0.0;
        for (int i = 1; i < dist.length; ++i) {
            if (dist[i] > best) {
                best = dist[i];
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Stores one predicted distribution; the predicted class is the index of
     * its first maximum.
     *
     * @param dist predicted class distribution for one instance
     */
    public void storeSingleResult(double[] dist) {
        // Resolve the prediction first so a bad distribution leaves the lists untouched.
        double predicted = indexOfMax(dist);
        this.predictedClassProbabilities.add(dist);
        this.predictedClassValues.add(predicted);
    }

    /**
     * Stores one predicted distribution together with the actual class value;
     * the predicted class is the index of the distribution's first maximum.
     *
     * @param actual actual class value of the instance
     * @param dist predicted class distribution for the instance
     */
    public void storeSingleResult(double actual, double[] dist) {
        double predicted = indexOfMax(dist);
        this.predictedClassProbabilities.add(dist);
        this.predictedClassValues.add(predicted);
        this.actualClassValues.add(actual);
    }

    /**
     * Attaches the actual class values to predictions stored so far, computes
     * the accuracy and marks the results as finalised. Does nothing (other than
     * a debug message) if the results are already finalised.
     *
     * @param testClassVals actual class value per stored prediction
     * @throws Exception if no predictions are stored, or their number differs
     *         from {@code testClassVals.length}
     */
    public void finaliseResults(double[] testClassVals) throws Exception {
        if (this.finalised) {
            this.printlnDebug("Results already finalised, skipping re-finalisation");
            return;
        }
        if (this.predictedClassProbabilities == null || this.predictedClassValues == null || this.predictedClassProbabilities.isEmpty() || this.predictedClassValues.isEmpty()) {
            throw new Exception("finaliseTestResults(): no test predictions stored for this module");
        }
        if (testClassVals.length != this.predictedClassValues.size()) {
            throw new Exception("finaliseTestResults(): Number of test predictions made and number of test cases do not match");
        }
        double correct = 0.0;
        for (int inst = 0; inst < testClassVals.length; ++inst) {
            this.actualClassValues.add(testClassVals[inst]);
            if (testClassVals[inst] == this.predictedClassValues.get(inst).doubleValue()) {
                correct += 1.0;
            }
        }
        this.acc = correct / (double)testClassVals.length;
        this.finalised = true;
    }

    /**
     * Returns the number of predictions currently stored, or 0 once the
     * prediction data has been released by {@link #cleanPredictionInfo()}.
     */
    public int numInstances() {
        return this.predictedClassValues == null ? 0 : this.predictedClassValues.size();
    }

    /** Returns a copy of the stored actual class values. */
    public double[] getTrueClassVals() {
        double[] values = new double[this.actualClassValues.size()];
        for (int i = 0; i < values.length; ++i) {
            values[i] = this.actualClassValues.get(i);
        }
        return values;
    }

    /** Returns a copy of the stored predicted class values. */
    public double[] getPredClassVals() {
        double[] values = new double[this.predictedClassValues.size()];
        for (int i = 0; i < values.length; ++i) {
            values[i] = this.predictedClassValues.get(i);
        }
        return values;
    }

    public double getPredClassValue(int index) {
        return this.predictedClassValues.get(index);
    }

    public double getTrueClassValue(int index) {
        return this.actualClassValues.get(index);
    }

    /**
     * Returns the stored distribution for instance {@code i}, or {@code null}
     * when {@code i} is at or beyond the number of stored distributions.
     */
    public double[] getDistributionForInstance(int i) {
        if (i < this.predictedClassProbabilities.size()) {
            return this.predictedClassProbabilities.get(i);
        }
        return null;
    }

    /**
     * Formats the stored predictions one instance per line as
     * {@code actual,predicted,,p0,p1,...}. The empty third column is what
     * {@link #loadFromFile(String)} expects (probabilities start at index 3).
     *
     * @return the formatted predictions, or a fixed message when there are no
     *         predictions or the three lists differ in length
     */
    public String writeInstancePredictions() {
        int count = this.numInstances();
        boolean consistent = count > 0
                && this.predictedClassProbabilities.size() == this.actualClassValues.size()
                && this.predictedClassProbabilities.size() == this.predictedClassValues.size();
        if (!consistent) {
            return "No Instance Prediction Information";
        }
        StringBuilder sb = new StringBuilder("");
        for (int i = 0; i < count; ++i) {
            sb.append(this.actualClassValues.get(i).intValue()).append(",");
            sb.append(this.predictedClassValues.get(i).intValue()).append(",");
            for (double probability : this.predictedClassProbabilities.get(i)) {
                sb.append(",").append(GenericTools.RESULTS_DECIMAL_FORMAT.format(probability));
            }
            if (i < count - 1) {
                sb.append("\n");
            }
        }
        return sb.toString();
    }

    /**
     * Renders the results in the file layout read by
     * {@link #loadFromFile(String)}: name, {@code BuildTime,<ms>,<paras>},
     * accuracy, then the per-instance predictions.
     */
    public String writeResultsFileToString() throws IOException {
        StringBuilder st = new StringBuilder();
        st.append(this.name).append("\n");
        st.append("BuildTime,").append(this.buildTime).append(",").append(this.paras).append("\n");
        st.append(this.acc).append("\n");
        st.append(this.writeInstancePredictions());
        return st.toString();
    }

    /**
     * Replaces the contents of this object with the results stored in a file.
     * Line 1 is the name, line 2 the parameters (optionally prefixed with
     * {@code BuildTime,<value>}), line 3 an accuracy that is read but ignored
     * (accuracy is recomputed from the predictions), and the remaining lines
     * hold {@code actual,predicted[,,p0,p1,...]} until a blank line or the end
     * of the file. The number of classes is taken from the first instance line
     * when that line carries probabilities.
     *
     * @param path path of the results file
     * @throws FileNotFoundException if the file does not exist or is empty
     */
    public void loadFromFile(String path) throws FileNotFoundException {
        File f = new File(path);
        if (!f.exists() || f.length() <= 0L) {
            throw new FileNotFoundException("File " + path + " NOT FOUND");
        }
        // NOTE(review): the reader is never closed here; InFile's API is not visible
        // from this file, so confirm whether it needs an explicit close.
        InFile inf = new InFile(path);
        this.name = inf.readLine();
        this.paras = inf.readLine();
        String[] parts = this.paras.split(",");
        // Require the value column as well, so a bare "BuildTime" header cannot index past the array.
        if (parts.length > 1 && parts[0].contains("BuildTime")) {
            this.buildTime = (long)Double.parseDouble(parts[1].trim());
            if (parts.length > 2) {
                StringBuilder remaining = new StringBuilder(parts[2]);
                for (int i = 3; i < parts.length; ++i) {
                    remaining.append(",").append(parts[i]);
                }
                this.paras = remaining.toString();
            }
        }
        // The accuracy stored in the file is consumed but not used; it is recomputed below.
        inf.readDouble();
        String line = inf.readLine();
        this.actualClassValues = new ArrayList<>();
        this.predictedClassValues = new ArrayList<>();
        this.predictedClassProbabilities = new ArrayList<>();
        this.numInstances = 0;
        this.acc = 0.0;
        boolean firstLoop = true;
        while (line != null && !line.equals("")) {
            String[] split = line.split(",");
            boolean hasProbabilities = split.length > 3;
            if (!hasProbabilities && firstLoop) {
                this.printlnDebug("WARNING: Results file does not contain probabilities, " + path);
            }
            double actual = Double.parseDouble(split[0]);
            double predicted = Double.parseDouble(split[1]);
            this.actualClassValues.add(actual);
            this.predictedClassValues.add(predicted);
            if (actual == predicted) {
                this.acc += 1.0;
            }
            if (hasProbabilities) {
                if (this.numInstances == 0) {
                    this.numClasses = split.length - 3;
                }
                double[] probs = new double[this.numClasses];
                for (int i = 0; i < probs.length; ++i) {
                    probs[i] = Double.parseDouble(split[3 + i].trim());
                }
                this.predictedClassProbabilities.add(probs);
            }
            ++this.numInstances;
            line = inf.readLine();
            firstLoop = false;
        }
        this.acc /= (double)this.numInstances;
        this.finalised = true;
    }

    /**
     * Recomputes the confusion matrix, per-class counts, accuracy, balanced
     * accuracy, F1, NLL, mean AUROC and MCC from the stored predictions.
     * Requires actual values, predictions and distributions to be present.
     */
    public void findAllStats() {
        this.confusionMatrix = this.buildConfusionMatrix();
        this.countPerClass = new double[this.confusionMatrix.length];
        for (Double actual : this.actualClassValues) {
            this.countPerClass[actual.intValue()] += 1.0;
        }
        double correct = 0.0;
        for (int i = 0; i < this.numClasses; ++i) {
            correct += this.confusionMatrix[i][i];
        }
        // Divide by the stored instance count: the numInstances field is only
        // populated by loadFromFile/setNumInstances and would otherwise be 0.
        this.acc = correct / (double)this.actualClassValues.size();
        this.balancedAcc = this.findBalancedAcc(this.confusionMatrix);
        this.f1 = this.findF1(this.confusionMatrix);
        this.nll = this.findNLL();
        this.meanAUROC = this.findMeanAUROC();
        this.mcc = this.computeMCC(this.confusionMatrix);
        this.allStatsFound = true;
    }

    /** Calls {@link #findAllStats()} unless it has already completed once. */
    public void findAllStatsOnce() {
        if (this.allStatsFound) {
            this.printlnDebug("Stats already found, ignoring findAllStatsOnce()");
            return;
        }
        this.findAllStats();
    }

    /**
     * Returns the mean negative base-2 log of the probability assigned to the
     * true class, substituting {@link #NLL_PENALTY} for the log when that
     * probability is exactly 0.
     */
    public double findNLL() {
        double logSum = 0.0;
        int count = this.actualClassValues.size();
        for (int i = 0; i < count; ++i) {
            double[] dist = this.getDistributionForInstance(i);
            int trueClass = this.actualClassValues.get(i).intValue();
            if (dist[trueClass] == 0.0) {
                logSum += NLL_PENALTY;
            } else {
                logSum += Math.log(dist[trueClass]) / Math.log(2.0);
            }
        }
        return -logSum / (double)count;
    }

    /**
     * Returns the AUROC of class 1 for two-class problems; otherwise the sum of
     * per-class AUROCs weighted by the values returned from
     * {@code InstanceTools.findClassDistributions} (presumably the class
     * proportions - confirm against that helper).
     */
    public double findMeanAUROC() {
        if (this.numClasses == 2) {
            return this.findAUROC(1);
        }
        double weighted = 0.0;
        double[] classDist = InstanceTools.findClassDistributions(this.actualClassValues, this.numClasses);
        for (int i = 0; i < this.numClasses; ++i) {
            weighted += this.findAUROC(i) * classDist[i];
        }
        return weighted;
    }

    /**
     * Computes the multi-class Matthews correlation coefficient from a square
     * confusion matrix. Returns 0 when the numerator is 0.
     *
     * @param confusionMatrix square matrix indexed {@code [actual][predicted]}
     */
    public double computeMCC(double[][] confusionMatrix) {
        int size = confusionMatrix.length;
        double num = 0.0;
        for (int k = 0; k < size; ++k) {
            for (int l = 0; l < size; ++l) {
                for (int m = 0; m < size; ++m) {
                    num += confusionMatrix[k][k] * confusionMatrix[m][l] - confusionMatrix[l][k] * confusionMatrix[k][m];
                }
            }
        }
        if (num == 0.0) {
            return 0.0;
        }
        double den1 = 0.0;
        double den2 = 0.0;
        for (int k = 0; k < size; ++k) {
            // Column k total (den1) and row k total (den2).
            double den1Part1 = 0.0;
            double den2Part1 = 0.0;
            for (int l = 0; l < size; ++l) {
                den1Part1 += confusionMatrix[l][k];
                den2Part1 += confusionMatrix[k][l];
            }
            // Totals of every other column (den1) and every other row (den2).
            double den1Part2 = 0.0;
            double den2Part2 = 0.0;
            for (int kp = 0; kp < size; ++kp) {
                if (kp == k) {
                    continue;
                }
                for (int lp = 0; lp < size; ++lp) {
                    den1Part2 += confusionMatrix[lp][kp];
                    den2Part2 += confusionMatrix[kp][lp];
                }
            }
            den1 += den1Part1 * den1Part2;
            den2 += den2Part1 * den2Part2;
        }
        return num / (Math.sqrt(den1) * Math.sqrt(den2));
    }

    /**
     * Returns the mean of the per-class accuracies
     * {@code cm[i][i] / countPerClass[i]}. Requires {@code countPerClass} to be
     * populated; a class with a zero count makes the result NaN.
     */
    public double findBalancedAcc(double[][] cm) {
        double[] accPerClass = new double[cm.length];
        for (int i = 0; i < cm.length; ++i) {
            accPerClass[i] = cm[i][i] / this.countPerClass[i];
        }
        double total = accPerClass[0];
        for (int i = 1; i < cm.length; ++i) {
            total += accPerClass[i];
        }
        return total / (double)cm.length;
    }

    /**
     * Returns the F1 score. For two classes this is the F1 of the class with
     * fewer instances (class 1 on a tie); otherwise the unweighted mean of the
     * per-class F1 scores.
     */
    public double findF1(double[][] cm) {
        if (this.numClasses == 2) {
            int minorityClass = this.countPerClass[0] < this.countPerClass[1] ? 0 : 1;
            return this.findConfusionMatrixStats(cm, minorityClass, 1.0);
        }
        double total = 0.0;
        for (int i = 0; i < this.numClasses; ++i) {
            total += this.findConfusionMatrixStats(cm, i, 1.0);
        }
        return total / (double)this.numClasses;
    }

    /**
     * Computes the F-beta score for class {@code c} and, as a side effect,
     * overwrites the {@code precision}, {@code recall}, {@code sensitivity} and
     * {@code specificity} fields with that class's values. When the class has
     * no true positives, returns 1.0E-7 and leaves those fields untouched.
     *
     * @param confMat confusion matrix indexed {@code [actual][predicted]}
     * @param c class treated as positive
     * @param beta F-measure weighting (1 gives F1)
     */
    protected double findConfusionMatrixStats(double[][] confMat, int c, double beta) {
        double tp = confMat[c][c];
        if (tp == 0.0) {
            return 1.0E-7;
        }
        double fp = 0.0;
        double fn = 0.0;
        double tn = 0.0;
        for (int i = 0; i < confMat.length; ++i) {
            if (i == c) {
                continue;
            }
            fp += confMat[i][c];
            fn += confMat[c][i];
            // NOTE(review): only the other diagonal cells are counted as true
            // negatives; with more than two classes this omits off-diagonal cells
            // that involve neither row c nor column c. Confirm this is intended.
            tn += confMat[i][i];
        }
        this.precision = tp / (tp + fp);
        this.sensitivity = this.recall = tp / (tp + fn);
        this.specificity = tn / (fp + tn);
        if (Double.isNaN(this.precision)) {
            this.precision = 0.0;
        }
        if (Double.isNaN(this.recall)) {
            this.recall = 0.0;
        }
        if (Double.isNaN(this.sensitivity)) {
            this.sensitivity = 0.0;
        }
        if (Double.isNaN(this.specificity)) {
            this.specificity = 0.0;
        }
        double betaSquared = beta * beta;
        return (1.0 + betaSquared) * (this.precision * this.recall) / (betaSquared * this.precision + this.recall);
    }

    /**
     * Computes the area under the ROC curve for class {@code c} treated as the
     * positive class. Instances are ordered by descending predicted probability
     * of {@code c}; each positive advances {@code x} by 1/positives and each
     * negative advances {@code y} by 1/negatives, a curve point being recorded
     * whenever the run switches between positives and negatives. The area is
     * the sum of {@code (y[i+1] - y[i]) * x[i+1]} over consecutive points.
     *
     * @param c class treated as positive
     */
    protected double findAUROC(int c) {
        class Pair
        implements Comparable<Pair> {
            Double x;
            Double y;

            public Pair(Double a, Double b) {
                this.x = a;
                this.y = b;
            }

            // Reversed comparison so that sorting yields descending x.
            @Override
            public int compareTo(Pair p) {
                return p.x.compareTo(this.x);
            }

            @Override
            public String toString() {
                return "(" + this.x + "," + this.y + ")";
            }
        }
        // Iterate over what is actually stored rather than the numInstances field,
        // which is 0 unless the results came from loadFromFile/setNumInstances.
        int count = this.actualClassValues.size();
        ArrayList<Pair> p = new ArrayList<Pair>();
        double nosPositive = 0.0;
        for (int i = 0; i < count; ++i) {
            Double actual = this.actualClassValues.get(i);
            if ((double)c == actual.doubleValue()) {
                nosPositive += 1.0;
            }
            p.add(new Pair(this.predictedClassProbabilities.get(i)[c], actual));
        }
        double nosNegative = (double)count - nosPositive;
        Collections.sort(p);
        ArrayList<Pair> roc = new ArrayList<Pair>();
        double x = 0.0;
        double y = 0.0;
        int xAdd = 0;
        int yAdd = 0;
        boolean xLast = false;
        boolean yLast = false;
        roc.add(new Pair(x, y));
        for (int i = 0; i < count; ++i) {
            if (p.get(i).y.doubleValue() == (double)c) {
                if (yLast) {
                    roc.add(new Pair(x, y));
                }
                xLast = true;
                yLast = false;
                x += 1.0 / nosPositive;
                // Snap to exactly 1 on the last positive to avoid accumulated rounding error.
                if ((double)(++xAdd) == nosPositive) {
                    x = 1.0;
                }
            } else {
                if (xLast) {
                    roc.add(new Pair(x, y));
                }
                yLast = true;
                xLast = false;
                y += 1.0 / nosNegative;
                // Snap to exactly 1 on the last negative, as above.
                if ((double)(++yAdd) == nosNegative) {
                    y = 1.0;
                }
            }
        }
        roc.add(new Pair(1.0, 1.0));
        double auroc = 0.0;
        for (int i = 0; i < roc.size() - 1; ++i) {
            Pair current = roc.get(i);
            Pair next = roc.get(i + 1);
            auroc += (next.y - current.y) * next.x;
        }
        return auroc;
    }

    /**
     * Formats every summary statistic, the per-class counts and the confusion
     * matrix as comma-separated text. Requires {@link #findAllStats()} (or
     * equivalent) to have populated {@code countPerClass} and
     * {@code confusionMatrix}.
     */
    public String writeAllStats() {
        StringBuilder out = new StringBuilder();
        out.append("Acc,").append(this.acc).append("\n");
        out.append("BalancedAcc,").append(this.balancedAcc).append("\n");
        out.append("sensitivity,").append(this.sensitivity).append("\n");
        out.append("precision,").append(this.precision).append("\n");
        out.append("recall,").append(this.recall).append("\n");
        out.append("specificity,").append(this.specificity).append("\n");
        out.append("f1,").append(this.f1).append("\n");
        out.append("mcc,").append(this.mcc).append("\n");
        out.append("nll,").append(this.nll).append("\n");
        out.append("meanAUROC,").append(this.meanAUROC).append("\n");
        out.append("stddev,").append(this.stddev).append("\n");
        out.append("Count per class:\n");
        for (int i = 0; i < this.countPerClass.length; ++i) {
            out.append("Class ").append(i).append(",").append(this.countPerClass[i]).append("\n");
        }
        out.append("Confusion Matrix:\n");
        for (double[] row : this.confusionMatrix) {
            for (double cell : row) {
                out.append(cell).append(",");
            }
            out.append("\n");
        }
        return out.toString();
    }

    /** Returns whether any predictions are currently stored. */
    boolean hasInstanceData() {
        return this.numInstances() != 0;
    }

    /**
     * Ad-hoc manual check: loads a results file from a hard-coded path, prints
     * its statistics and writes them to a second hard-coded path.
     */
    public static void main(String[] args) throws FileNotFoundException {
        String path = "C:\\JamesLPHD\\testFold1.csv";
        ClassifierResults cr = new ClassifierResults();
        cr.loadFromFile(path);
        cr.findAllStats();
        System.out.println("AUROC = " + cr.meanAUROC);
        System.out.println("FILE TEST =\n" + cr.writeAllStats());
        OutFile out = new OutFile("C:\\JamesLPHD\\testFold1stats.csv");
        out.writeLine(cr.writeAllStats());
    }
}

