/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers;

import fileIO.OutFile;
import java.util.ArrayList;
import java.util.Random;
import timeseriesweka.classifiers.AbstractClassifierWithTrainingData;
import utilities.ClassifierResults;
import utilities.ClassifierTools;
import utilities.CrossValidator;
import utilities.SaveParameterInfo;
import utilities.TrainAccuracyEstimate;
import weka.classifiers.trees.RandomForest;
import weka.classifiers.trees.RandomTree;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;

/**
 * Time Series Forest (TSF) classifier (Deng et al., 2013).
 *
 * <p>Building the classifier works as follows:
 * <ul>
 *   <li>For each of {@code numTrees} trees, {@code numFeatures = floor(sqrt(seriesLength))} random
 *       intervals of the series are drawn.</li>
 *   <li>Every interval is summarised by three features (see {@link FeatureSet}), giving
 *       {@code 3 * numFeatures} attributes per tree.</li>
 *   <li>A Weka {@link RandomTree} is trained on those attributes.</li>
 * </ul>
 * Prediction is a majority vote over the trees.
 *
 * <p>NOTE(review): intervals are drawn over attribute indices {@code 0 .. numAttributes() - 2}
 * of the raw instance array, so the class attribute appears to be assumed to be the last
 * attribute. Confirm this with callers.
 *
 * <p>Not thread-safe: classification reuses the single-row {@link #testHolder} buffer.
 */
public class TSF
extends AbstractClassifierWithTrainingData
implements SaveParameterInfo,
TrainAccuracyEstimate {
    /** True once a seed has been supplied via the constructor or {@link #setSeed(int)}. */
    boolean setSeed = false;
    /** Seed for interval selection and for the optional internal cross-validation. */
    int seed = 0;
    /** The ensemble members, one per set of intervals. */
    RandomTree[] trees;
    /** Number of trees in the ensemble. */
    int numTrees = 500;
    /** Intervals per tree; set to floor(sqrt(series length)) by buildClassifier. */
    int numFeatures;
    /** {@code intervals[tree][interval] = {start, end}}, both ends inclusive. */
    int[][][] intervals;
    Random rand = new Random();
    /** One-row dataset with the transformed header, reused for every prediction. */
    Instances testHolder;
    /** Whether buildClassifier also runs a cross-validation to estimate train accuracy. */
    boolean trainCV = false;
    /** Destination of the train cross-validation predictions ("" = none). */
    String trainCVPath = "";
    /** NOTE(review): these appear to be published error rates, index-aligned with {@link #problems}. */
    static double[] reportedResults = new double[]{0.2659, 0.2302, 0.2333, 0.0256, 0.2537, 0.0391, 0.0357, 0.2897, 0.2, 0.2436, 0.049, 0.08, 0.0557, 0.2325, 0.0227, 0.101, 0.1543, 0.0467, 0.552, 0.6818, 0.0301, 0.1803, 0.2603, 0.0448, 0.2237, 0.119, 0.0987, 0.0865, 0.0667, 0.4339, 0.233, 0.1868, 0.0357, 0.1056, 0.1116, 0.0267, 0.02, 0.1177, 0.0543, 0.2102, 0.2876, 0.2624, 0.0054, 0.3793, 0.1513};
    /** Problem names matching {@link #reportedResults} by index. */
    static String[] problems = new String[]{"FiftyWords", "Adiac", "Beef", "CBF", "ChlorineConcentration", "CinCECGtorso", "Coffee", "CricketX", "CricketY", "CricketZ", "DiatomSizeReduction", "ECG", "ECGFiveDays", "FaceAll", "FaceFour", "FacesUCR", "Fish", "GunPoint", "Haptics", "InlineSkate", "ItalyPowerDemand", "Lightning2", "Lightning7", "Mallat", "MedicalImages", "MoteStrain", "NonInvasiveFetalECGThorax1", "NonInvasiveFetalECGThorax2", "OliveOil", "OSULeaf", "SonyAIBORobotSurface1", "SonyAIBORobotSurface2", "StarLightCurves", "SwedishLeaf", "Symbols", "SyntheticControl", "Trace", "TwoLeadECG", "TwoPatterns", "UWaveGestureLibraryX", "UWaveGestureLibraryY", "UWaveGestureLibraryZ", "Wafer", "WordsSynonyms", "Yoga"};

    /** Creates an unseeded TSF with the default number of trees. */
    public TSF() {
    }

    /**
     * Creates a TSF whose interval selection is seeded with {@code s}.
     *
     * @param s random seed
     */
    public TSF(int s) {
        this.seed = s;
        this.rand.setSeed(this.seed);
        this.setSeed = true;
    }

    /**
     * Seeds interval selection (and the internal cross-validation, if enabled).
     *
     * @param s random seed
     */
    public void setSeed(int s) {
        this.setSeed = true;
        this.seed = s;
        this.rand = new Random();
        this.rand.setSeed(this.seed);
    }

    /**
     * Enables a train-set cross-validation during buildClassifier and records where its
     * predictions are written.
     *
     * @param train output file path
     */
    @Override
    public void writeCVTrainToFile(String train) {
        this.trainCVPath = train;
        this.trainCV = true;
    }

    @Override
    public boolean findsTrainAccuracyEstimate() {
        return this.trainCV;
    }

    @Override
    public ClassifierResults getTrainResults() {
        return this.trainResults;
    }

    @Override
    public String getParameters() {
        return super.getParameters() + ",numTrees," + this.numTrees + ",numFeatures," + this.numFeatures;
    }

    /**
     * Sets the ensemble size.
     *
     * @param t number of trees
     */
    public void setNumTrees(int t) {
        this.numTrees = t;
    }

    /** Returns the bibliographic reference for TSF. */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "H. Deng, G. Runger, E. Tuv and M. Vladimir");
        result.setValue(TechnicalInformation.Field.YEAR, "2013");
        result.setValue(TechnicalInformation.Field.TITLE, "A time series forest for classification and feature extraction");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Information Sciences");
        result.setValue(TechnicalInformation.Field.VOLUME, "239");
        result.setValue(TechnicalInformation.Field.PAGES, "142-153");
        return result;
    }

    /**
     * Trains the forest and, if enabled, first estimates train accuracy by cross-validation.
     *
     * @param data training data; the class attribute is expected to be last (see class doc)
     * @throws Exception if cross-validation, tree building or writing the train file fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        long t1 = System.currentTimeMillis();
        int seriesLength = data.numAttributes() - 1;
        this.numFeatures = (int)Math.sqrt(seriesLength);
        if (this.trainCV) {
            this.trainResults = this.estimateTrainAccuracy(data);
        }
        this.intervals = new int[this.numTrees][][];
        this.trees = new RandomTree[this.numTrees];
        Instances result = this.createTransformedHeader(data);
        // One transformed row per training instance, pre-labelled; features are filled per tree.
        for (int i = 0; i < data.numInstances(); ++i) {
            DenseInstance in = new DenseInstance(result.numAttributes());
            in.setValue(result.numAttributes() - 1, data.instance(i).classValue());
            result.add(in);
        }
        this.testHolder = new Instances(result, 0);
        this.testHolder.add(new DenseInstance(result.numAttributes()));
        // Convert each training instance once instead of once per (tree, interval) pair.
        double[][] series = new double[data.numInstances()][];
        for (int k = 0; k < series.length; ++k) {
            series[k] = data.instance(k).toDoubleArray();
        }
        FeatureSet f = new FeatureSet();
        for (int i = 0; i < this.numTrees; ++i) {
            this.intervals[i] = new int[this.numFeatures][2];
            for (int j = 0; j < this.numFeatures; ++j) {
                // start in [0, seriesLength-1]; end = start + length stays <= seriesLength-1.
                this.intervals[i][j][0] = this.rand.nextInt(seriesLength);
                int length = this.rand.nextInt(seriesLength - this.intervals[i][j][0]);
                this.intervals[i][j][1] = this.intervals[i][j][0] + length;
            }
            for (int k = 0; k < series.length; ++k) {
                this.setIntervalFeatures(series[k], i, result.instance(k), f);
            }
            this.trees[i] = new RandomTree();
            this.trees[i].setKValue(this.numFeatures);
            this.trees[i].buildClassifier(result);
        }
        long t2 = System.currentTimeMillis();
        // NOTE(review): assumes the superclass initialises trainResults when no CV is run.
        this.trainResults.buildTime = t2 - t1;
        if (this.trainCV && !this.trainCVPath.isEmpty()) {
            this.writeTrainResultsToFile(data);
        }
    }

    /** Cross-validates a TSF configured like this one (same size and seed) on {@code data}. */
    private ClassifierResults estimateTrainAccuracy(Instances data) throws Exception {
        int numFolds = this.setNumberOfFolds(data);
        CrossValidator cv = new CrossValidator();
        if (this.setSeed) {
            cv.setSeed(this.seed);
        }
        cv.setNumFolds(numFolds);
        TSF tsf = new TSF();
        tsf.setNumTrees(this.numTrees);
        if (this.setSeed) {
            tsf.setSeed(this.seed);
        }
        return cv.crossValidateWithStats(tsf, data);
    }

    /** Builds the empty header: 3 numeric features per interval followed by the nominal class. */
    private Instances createTransformedHeader(Instances data) {
        ArrayList<Attribute> atts = new ArrayList<Attribute>();
        for (int j = 0; j < this.numFeatures * 3; ++j) {
            atts.add(new Attribute("F" + j));
        }
        Attribute target = data.attribute(data.classIndex());
        ArrayList<String> vals = new ArrayList<String>(target.numValues());
        for (int j = 0; j < target.numValues(); ++j) {
            vals.add(target.value(j));
        }
        atts.add(new Attribute(target.name(), vals));
        Instances result = new Instances("Tree", atts, data.numInstances());
        result.setClassIndex(result.numAttributes() - 1);
        return result;
    }

    /** Writes mean / stDev / slope of every interval of {@code tree} over {@code series} into {@code row}. */
    private void setIntervalFeatures(double[] series, int tree, Instance row, FeatureSet f) {
        for (int j = 0; j < this.numFeatures; ++j) {
            f.setFeatures(series, this.intervals[tree][j][0], this.intervals[tree][j][1]);
            row.setValue(j * 3, f.mean);
            row.setValue(j * 3 + 1, f.stDev);
            row.setValue(j * 3 + 2, f.slope);
        }
    }

    /** Writes the cross-validation predictions for every training instance to trainCVPath. */
    private void writeTrainResultsToFile(Instances data) throws Exception {
        OutFile of = new OutFile(this.trainCVPath);
        try {
            of.writeLine(data.relationName() + ",TSF,train");
            of.writeLine(this.getParameters());
            of.writeLine(this.trainResults.acc + "");
            double[] trueClassVals = this.trainResults.getTrueClassVals();
            double[] predClassVals = this.trainResults.getPredClassVals();
            for (int i = 0; i < data.numInstances(); ++i) {
                if (data.instance(i).classValue() != trueClassVals[i]) {
                    throw new Exception("ERROR in TSF cross validation, class mismatch!");
                }
                // Row format: actual,predicted,,p(class0),p(class1),...
                of.writeString((int)trueClassVals[i] + "," + (int)predClassVals[i] + ",");
                for (double d : this.trainResults.getDistributionForInstance(i)) {
                    of.writeString("," + d);
                }
                of.writeString("\n");
            }
        } finally {
            // Flush and release the file even if the consistency check above fails.
            of.closeFile();
        }
    }

    /** Counts, per class index, how many trees predict that class for {@code ins}. */
    private int[] countVotes(Instance ins) throws Exception {
        int[] votes = new int[ins.numClasses()];
        double[] series = ins.toDoubleArray();
        Instance row = this.testHolder.instance(0);
        FeatureSet f = new FeatureSet();
        for (int i = 0; i < this.trees.length; ++i) {
            this.setIntervalFeatures(series, i, row, f);
            int c = (int)this.trees[i].classifyInstance(row);
            votes[c]++;
        }
        return votes;
    }

    /**
     * Returns the fraction of trees voting for each class.
     *
     * @param ins instance to classify
     * @return vote proportions; all zeros if the ensemble has no trees
     * @throws Exception if a tree fails to classify
     */
    @Override
    public double[] distributionForInstance(Instance ins) throws Exception {
        int[] votes = this.countVotes(ins);
        double[] d = new double[votes.length];
        double sum = 0.0;
        for (int x : votes) {
            sum += (double)x;
        }
        if (sum == 0.0) {
            return d;
        }
        for (int i = 0; i < d.length; ++i) {
            d[i] = (double)votes[i] / sum;
        }
        return d;
    }

    /**
     * Returns the majority-vote class; ties go to the lowest class index.
     *
     * @param ins instance to classify
     * @return predicted class index
     * @throws Exception if a tree fails to classify
     */
    @Override
    public double classifyInstance(Instance ins) throws Exception {
        int[] votes = this.countVotes(ins);
        int maxVote = 0;
        for (int i = 1; i < votes.length; ++i) {
            if (votes[i] > votes[maxVote]) {
                maxVote = i;
            }
        }
        return maxVote;
    }

    /** Ad-hoc driver: feature sanity check, then train/test on ItalyPowerDemand from local paths. */
    public static void main(String[] arg) throws Exception {
        FeatureSet f = new FeatureSet();
        double[] y = new double[]{0.0, 4.0, 8.0, 12.0, 16.0};
        f.setFeatures(y);
        System.out.println(f + "");
        Instances train = ClassifierTools.loadData("C:\\Users\\ajb\\Dropbox\\TSC Problems\\ItalyPowerDemand\\ItalyPowerDemand_TRAIN");
        Instances test = ClassifierTools.loadData("C:\\Users\\ajb\\Dropbox\\TSC Problems\\ItalyPowerDemand\\ItalyPowerDemand_TEST");
        TSF tsf = new TSF();
        tsf.writeCVTrainToFile("C:\\Users\\ajb\\Dropbox\\Spectral Interval Experiments\\RIF\\Predictions\\InternalCV0.csv");
        tsf.buildClassifier(train);
        System.out.println("build ok: original atts=" + train.numAttributes() + " new atts =" + tsf.testHolder.numAttributes());
        double a = ClassifierTools.accuracy(test, tsf);
        System.out.println(" Accuracy =" + a);
    }

    /** Summary features of one interval of a series: mean, spread and least-squares slope. */
    public static class FeatureSet {
        double mean;
        /** NOTE: despite the name this holds the population variance; no square root is taken. */
        double stDev;
        /** Least-squares slope against the position within the interval. */
        double slope;
        /** NOTE(review): not used anywhere in this file. */
        RandomForest r;

        /**
         * Computes the features over {@code data[start..end]} (both inclusive).
         *
         * <p>Two quirks of the original algorithm are preserved:
         * <ul>
         *   <li>a flat interval (variance 0) gets slope 0;</li>
         *   <li>an interval whose slope is exactly 0 gets variance 0, even if it is not flat,
         *       e.g. {0, 1, 0}.</li>
         * </ul>
         * If {@code end < start}, the features are NaN.
         *
         * @param data series values
         * @param start first index of the interval
         * @param end last index of the interval
         */
        public void setFeatures(double[] data, int start, int end) {
            double sumX = 0.0;
            double sumYY = 0.0;
            double sumY = 0.0;
            double sumXY = 0.0;
            double sumXX = 0.0;
            int length = end - start + 1;
            for (int i = start; i <= end; ++i) {
                // Square in double to avoid int overflow on very long intervals.
                double x = i - start;
                double y = data[i];
                sumY += y;
                sumYY += y * y;
                sumX += x;
                sumXX += x * x;
                sumXY += y * x;
            }
            this.mean = sumY / (double)length;
            double sxx = sumXX - sumX * sumX / (double)length;
            double sxy = sumXY - sumX * sumY / (double)length;
            this.slope = sxx != 0.0 ? sxy / sxx : 0.0;
            this.stDev = (sumYY - sumY * sumY / (double)length) / (double)length;
            if (this.stDev == 0.0) {
                this.slope = 0.0;
            }
            if (this.slope == 0.0) {
                this.stDev = 0.0;
            }
        }

        /** Computes the features over the whole of {@code data}. */
        public void setFeatures(double[] data) {
            this.setFeatures(data, 0, data.length - 1);
        }

        @Override
        public String toString() {
            return "mean=" + this.mean + " stdev = " + this.stDev + " slope =" + this.slope;
        }
    }
}

