/*
 * Decompiled with CFR 0.152.
 */
package utilities;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import utilities.ClassifierResults;
import utilities.ClassifierTools;
import utilities.StatisticalUtilities;
import weka.classifiers.Classifier;
import weka.core.Instances;

/**
 * Stratified k-fold cross-validation helper for Weka classifiers.
 *
 * <p>Folds are built lazily by the first call to {@link #crossValidate} or
 * {@link #crossValidateWithStats} (or explicitly via {@link #buildFolds}) and then cached, so that
 * several classifiers, or several calls, are evaluated on exactly the same partition. The cache is
 * discarded whenever the seed or the number of folds changes. It is also rebuilt when a data set with
 * a different number of instances is passed in. NOTE: the cache is not otherwise keyed on the data
 * set, so two different data sets of equal size will share a partition.
 */
public class CrossValidator {
    /** Seed for the fold shuffle; {@code null} means an unseeded {@link Random}. */
    private Integer seed = null;
    private int numFolds = 10;
    /** Cached folds, or {@code null} until folds have been built. */
    private ArrayList<Instances> folds = null;
    /** {@code foldIndexing.get(f).get(i)} is the index in the original data of instance i of fold f. */
    private ArrayList<ArrayList<Integer>> foldIndexing = null;

    /**
     * Returns the live fold-to-original-index mapping, or {@code null} if no folds have been built.
     * Callers must not mutate the returned lists, because the mapping is used by later runs.
     */
    public ArrayList<ArrayList<Integer>> getFoldIndices() {
        return this.foldIndexing;
    }

    public Integer getSeed() {
        return this.seed;
    }

    /** Sets the shuffle seed; any cached folds built with a different seed are discarded. */
    public void setSeed(Integer seed) {
        boolean changed = seed == null ? this.seed != null : !seed.equals(this.seed);
        this.seed = seed;
        if (changed) {
            this.invalidateFolds();
        }
    }

    public int getNumFolds() {
        return this.numFolds;
    }

    /** Sets the number of folds; any cached folds built with a different count are discarded. */
    public void setNumFolds(int numFolds) {
        if (numFolds != this.numFolds) {
            this.invalidateFolds();
        }
        this.numFolds = numFolds;
    }

    /** Maps position {@code indexInFold} within fold {@code fold} back to its index in the original data. */
    public int getOriginalInstIndex(int fold, int indexInFold) {
        return this.foldIndexing.get(fold).get(indexInFold);
    }

    public ClassifierResults crossValidateWithStats(Classifier classifier, Instances train) throws Exception {
        return this.crossValidateWithStats(new Classifier[]{classifier}, train)[0];
    }

    /**
     * Cross-validates each classifier on the (possibly cached) folds of {@code train}.
     *
     * <p>Predictions are the arg-max of {@code distributionForInstance}, with ties broken by
     * {@link #indexOfMax}. Overall accuracy is computed over all instances. The reported standard
     * deviation is taken over the per-fold accuracies, using the overall accuracy as the mean.
     *
     * @param classifiers classifiers to evaluate; each is rebuilt once per fold
     * @param train the full data set, with its class index set
     * @return one result per classifier, in the same order; {@code buildTime} is the wall-clock
     *     time of the whole run, including fold construction
     * @throws Exception if fold construction or any classifier fails
     */
    public ClassifierResults[] crossValidateWithStats(Classifier[] classifiers, Instances train) throws Exception {
        long startTime = System.currentTimeMillis();
        this.ensureFolds(train);
        int numInsts = train.numInstances();
        double[][] predictions = new double[classifiers.length][numInsts];
        double[][][] distsForInsts = new double[classifiers.length][numInsts][];
        double[][] foldAccs = new double[classifiers.length][this.numFolds];
        double[] classifierAccs = new double[classifiers.length];
        for (int testFold = 0; testFold < this.numFolds; ++testFold) {
            Instances[] trainTest = this.buildTrainTestSet(testFold);
            Instances testSet = trainTest[1];
            for (int c = 0; c < classifiers.length; ++c) {
                classifiers[c].buildClassifier(trainTest[0]);
                int correct = 0;
                for (int i = 0; i < testSet.numInstances(); ++i) {
                    int instIndex = this.getOriginalInstIndex(testFold, i);
                    double[] dist = classifiers[c].distributionForInstance(testSet.instance(i));
                    double pred = this.indexOfMax(dist);
                    distsForInsts[c][instIndex] = dist;
                    predictions[c][instIndex] = pred;
                    if (pred == testSet.instance(i).classValue()) {
                        ++correct;
                    }
                }
                // An empty fold yields 0/0 = NaN here, exactly as before.
                foldAccs[c][testFold] = (double) correct / (double) testSet.numInstances();
                classifierAccs[c] += correct;
            }
        }
        ClassifierResults[] results = new ClassifierResults[classifiers.length];
        double[] classVals = train.attributeToDoubleArray(train.classIndex());
        long buildTime = System.currentTimeMillis() - startTime;
        for (int c = 0; c < classifiers.length; ++c) {
            classifierAccs[c] /= (double) predictions[c].length;
            double stddevOverFolds = StatisticalUtilities.standardDeviation(foldAccs[c], false, classifierAccs[c]);
            results[c] = new ClassifierResults(classifierAccs[c], classVals, predictions[c], distsForInsts[c], stddevOverFolds, train.numClasses());
            results[c].buildTime = buildTime;
        }
        return results;
    }

    public double[] crossValidate(Classifier classifier, Instances train) throws Exception {
        return this.crossValidate(new Classifier[]{classifier}, train)[0];
    }

    /**
     * Cross-validates each classifier and returns its out-of-fold {@code classifyInstance}
     * predictions, indexed by position in the original {@code train}.
     *
     * @return {@code predictions[c][i]} is classifier c's prediction for instance i of {@code train}
     * @throws Exception if fold construction or any classifier fails
     */
    public double[][] crossValidate(Classifier[] classifiers, Instances train) throws Exception {
        this.ensureFolds(train);
        double[][] predictions = new double[classifiers.length][train.numInstances()];
        for (int testFold = 0; testFold < this.numFolds; ++testFold) {
            Instances[] trainTest = this.buildTrainTestSet(testFold);
            Instances testSet = trainTest[1];
            for (int c = 0; c < classifiers.length; ++c) {
                classifiers[c].buildClassifier(trainTest[0]);
                for (int i = 0; i < testSet.numInstances(); ++i) {
                    predictions[c][this.getOriginalInstIndex(testFold, i)] = classifiers[c].classifyInstance(testSet.instance(i));
                }
            }
        }
        return predictions;
    }

    /**
     * Builds fresh copies of the training set (all folds except {@code testFold}, in fold order)
     * and the test set (fold {@code testFold}).
     *
     * @return {@code {train, test}}; the training set is empty rather than {@code null} when
     *     there is only one fold
     */
    public Instances[] buildTrainTestSet(int testFold) {
        Instances testSet = new Instances(this.folds.get(testFold));
        Instances trainSet = new Instances(this.folds.get(testFold), 0);
        for (int f = 0; f < this.folds.size(); ++f) {
            if (f != testFold) {
                trainSet.addAll(this.folds.get(f));
            }
        }
        return new Instances[]{trainSet, testSet};
    }

    /**
     * Partitions {@code train} into {@code numFolds} stratified folds and caches them.
     *
     * <p>Instance ids are shuffled with the configured seed and grouped by class value (keeping the
     * shuffled order within each class). The concatenated list is then dealt round-robin, so
     * position p goes to fold {@code p % numFolds}. The new folds replace the cache only after they
     * are fully built.
     *
     * @throws IllegalArgumentException if {@code numFolds < 2}
     */
    public void buildFolds(Instances train) throws Exception {
        if (this.numFolds < 2) {
            throw new IllegalArgumentException("numFolds must be at least 2, got " + this.numFolds);
        }
        train = new Instances(train);
        Random rng = this.seed != null ? new Random(this.seed.longValue()) : new Random();
        int numInsts = train.numInstances();

        ArrayList<Instances> newFolds = new ArrayList<>(this.numFolds);
        ArrayList<ArrayList<Integer>> newIndexing = new ArrayList<>(this.numFolds);
        for (int f = 0; f < this.numFolds; ++f) {
            newFolds.add(new Instances(train, 0));
            newIndexing.add(new ArrayList<Integer>());
        }

        List<Integer> instanceIds = new ArrayList<>(numInsts);
        for (int i = 0; i < numInsts; ++i) {
            instanceIds.add(i);
        }
        Collections.shuffle(instanceIds, rng);

        // Group the shuffled ids by class value so that the round-robin deal below stratifies.
        List<List<Integer>> idsByClass = new ArrayList<>(train.numClasses());
        for (int c = 0; c < train.numClasses(); ++c) {
            idsByClass.add(new ArrayList<Integer>());
        }
        for (int id : instanceIds) {
            idsByClass.get((int) train.instance(id).classValue()).add(id);
        }
        List<Integer> sortedByClass = new ArrayList<>(numInsts);
        for (List<Integer> ids : idsByClass) {
            sortedByClass.addAll(ids);
        }

        for (int pos = 0; pos < sortedByClass.size(); ++pos) {
            int id = sortedByClass.get(pos);
            int fold = pos % this.numFolds;
            newFolds.get(fold).add(train.instance(id));
            newIndexing.get(fold).add(id);
        }
        this.folds = newFolds;
        this.foldIndexing = newIndexing;
    }

    /** Builds folds unless the cached ones match the current fold count and the size of {@code train}. */
    private void ensureFolds(Instances train) throws Exception {
        if (this.folds == null
                || this.folds.size() != this.numFolds
                || this.countFoldedInstances() != train.numInstances()) {
            this.buildFolds(train);
        }
    }

    /** Total number of instances across all cached folds. */
    private int countFoldedInstances() {
        int total = 0;
        for (ArrayList<Integer> indices : this.foldIndexing) {
            total += indices.size();
        }
        return total;
    }

    /** Drops the cached folds so that they are rebuilt on the next cross-validation run. */
    private void invalidateFolds() {
        this.folds = null;
        this.foldIndexing = null;
    }

    /**
     * Returns the index of the largest entry of {@code dist}, ignoring NaN entries.
     *
     * <p>Ties are broken by {@code new Random(0).nextInt(numTied)}. This is deterministic: for a
     * given number of tied classes, the same position among them is always chosen.
     *
     * @throws IllegalArgumentException if {@code dist} has no non-NaN entry
     */
    private double indexOfMax(double[] dist) {
        double bestWeight = Double.NEGATIVE_INFINITY;
        List<Integer> bestClasses = new ArrayList<>();
        for (int c = 0; c < dist.length; ++c) {
            if (Double.isNaN(dist[c])) {
                continue;
            }
            if (bestClasses.isEmpty() || dist[c] > bestWeight) {
                bestWeight = dist[c];
                bestClasses.clear();
                bestClasses.add(c);
            } else if (dist[c] == bestWeight) {
                bestClasses.add(c);
            }
        }
        if (bestClasses.isEmpty()) {
            throw new IllegalArgumentException("distribution has no non-NaN entries: " + Arrays.toString(dist));
        }
        int chosen = bestClasses.size() > 1
                ? bestClasses.get(new Random(0L).nextInt(bestClasses.size()))
                : bestClasses.get(0);
        return chosen;
    }

    /** Runs {@link #buildFoldsTest}; an optional first argument overrides the data path. */
    public static void main(String[] args) throws Exception {
        if (args.length > 0) {
            CrossValidator.buildFoldsTest(args[0]);
        } else {
            CrossValidator.buildFoldsTest();
        }
    }

    /** Prints the fold composition for the default "lenses" data set. */
    public static void buildFoldsTest() throws Exception {
        String dset = "lenses";
        CrossValidator.buildFoldsTest("C:/UCI Problems/" + dset + "/" + dset);
    }

    /**
     * Builds 3 seeded folds of the data at {@code dataPath} (as understood by
     * {@code ClassifierTools.loadData}) and prints the class distribution of the full data and of
     * each fold, together with each fold's sorted original indices.
     */
    public static void buildFoldsTest(String dataPath) throws Exception {
        CrossValidator cv = new CrossValidator();
        cv.setNumFolds(3);
        cv.setSeed(0);
        Instances insts = ClassifierTools.loadData(dataPath);
        System.out.println("Full data:");
        CrossValidator.printClassDistribution(insts);
        cv.buildFolds(insts);
        for (int i = 0; i < cv.numFolds; ++i) {
            Instances fold = cv.folds.get(i);
            System.out.println("\nFold " + i);
            CrossValidator.printClassDistribution(fold);
            // Sort a copy: sorting the live list would break the fold -> original index mapping.
            List<Integer> sortedIndices = new ArrayList<>(cv.foldIndexing.get(i));
            Collections.sort(sortedIndices);
            System.out.println("(sorted) orginal indices: " + sortedIndices);
            System.out.println("");
        }
    }

    /** Prints the instance count, per-class counts and per-class proportions of {@code data}. */
    private static void printClassDistribution(Instances data) {
        System.out.println("numinsts=" + data.numInstances());
        int[] classCounts = new int[data.numClasses()];
        double[] classDists = new double[data.numClasses()];
        for (int j = 0; j < data.numInstances(); ++j) {
            classCounts[(int) data.get(j).classValue()]++;
        }
        for (int j = 0; j < data.numClasses(); ++j) {
            classDists[j] = (double) classCounts[j] / (double) data.numInstances();
        }
        System.out.println("classcounts= " + Arrays.toString(classCounts));
        System.out.println("classdist=   " + Arrays.toString(classDists));
    }
}

