/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.classifiers;

import java.io.File;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.MathContext;
import timeseriesweka.classifiers.ContractClassifier;
import timeseriesweka.classifiers.cote.HiveCoteModule;
import timeseriesweka.filters.shapelet_transforms.Shapelet;
import timeseriesweka.filters.shapelet_transforms.ShapeletTransform;
import timeseriesweka.filters.shapelet_transforms.ShapeletTransformFactory;
import timeseriesweka.filters.shapelet_transforms.ShapeletTransformFactoryOptions;
import timeseriesweka.filters.shapelet_transforms.ShapeletTransformTimingUtilities;
import timeseriesweka.filters.shapelet_transforms.distance_functions.SubSeqDistance;
import timeseriesweka.filters.shapelet_transforms.quality_measures.ShapeletQuality;
import timeseriesweka.filters.shapelet_transforms.search_functions.ShapeletSearch;
import timeseriesweka.filters.shapelet_transforms.search_functions.ShapeletSearchOptions;
import utilities.ClassifierResults;
import utilities.ClassifierTools;
import utilities.InstanceTools;
import utilities.SaveParameterInfo;
import utilities.TrainAccuracyEstimate;
import vector_classifiers.HESCA;
import weka.classifiers.AbstractClassifier;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Classifier that builds a shapelet transform of the training data (via
 * {@link ShapeletTransformFactory}) and trains a {@link HESCA} ensemble on the transformed data.
 * Test instances are pushed through the same transform before being handed to the ensemble.
 *
 * <p>The transform can be disabled with {@link #doSTransform(boolean)}, in which case HESCA is
 * trained directly on (a copy of) the raw data.
 *
 * <p>NOTE(review): prediction uses a shared one-instance buffer ({@code format}), so a single
 * instance should not be used to classify from several threads concurrently.
 */
public class ST_HESCA
extends AbstractClassifier
implements HiveCoteModule,
SaveParameterInfo,
TrainAccuracyEstimate,
ContractClassifier {
    public static final int minimumRepresentation = 25;

    /** Upper bound on the number of shapelets kept by the transform. */
    private static final int MAX_K_SHAPELETS = 2000;

    /** Shortest shapelet length considered by the search. */
    private static final int MIN_SHAPELET_LENGTH = 3;

    private boolean preferShortShapelets = false;
    private String shapeletOutputPath;
    private HESCA hesca;
    private ShapeletTransform transform;

    /**
     * After training, an empty copy of the training header.
     * During prediction it is a one-instance buffer fed to the transform.
     */
    private Instances format;

    /** Attribute indices removed from the training data, replayed on every test instance. */
    int[] redundantFeatures;
    private boolean doTransform = true;
    private ShapeletSearch.SearchType searchType = ShapeletSearch.SearchType.IMP_RANDOM;

    /** Number of shapelets to search; 0 means "derive from the time limit". */
    private long numShapelets = 0L;
    private long seed = 0L;
    private long timeLimit = Long.MAX_VALUE;

    /** Path requested via {@link #writeCVTrainToFile(String)}; applied to each HESCA built. */
    private String trainCVPath;

    /**
     * Sets the seed used for the shapelet search and (truncated to int) for HESCA.
     *
     * @param sd the random seed
     */
    public void setSeed(long sd) {
        this.seed = sd;
    }

    /**
     * Sets the shapelet search strategy.
     *
     * <p>The strategy is only used when the full search is estimated to exceed the time limit.
     *
     * @param type the search type
     */
    public void setSearchType(ShapeletSearch.SearchType type) {
        this.searchType = type;
    }

    /**
     * Requests that HESCA write its training cross-validation results to {@code train}.
     *
     * <p>The path is remembered and passed to the HESCA instance created in
     * {@link #buildClassifier(Instances)}, because a fresh HESCA is created on every build. If a
     * HESCA already exists, the request is also forwarded to it directly.
     *
     * @param train output file path
     */
    @Override
    public void writeCVTrainToFile(String train) {
        this.trainCVPath = train;
        if (this.hesca != null) {
            this.hesca.writeCVTrainToFile(train);
        }
    }

    /** @return the training results reported by the underlying HESCA ensemble */
    @Override
    public ClassifierResults getTrainResults() {
        return this.hesca.getTrainResults();
    }

    /**
     * Returns a comma-separated parameter description.
     *
     * <p>The string is: the transform's parameters (or {@code noTransform} when no transform was
     * built), then the time limit as a {@code timeLimit,<value>} pair, then HESCA's parameters.
     *
     * @return the parameter string
     */
    @Override
    public String getParameters() {
        String paras = this.transform == null ? "noTransform" : this.transform.getParameters();
        String ensemble = this.hesca.getParameters();
        return paras + ",timeLimit," + this.timeLimit + "," + ensemble;
    }

    /** @return HESCA's ensemble cross-validation accuracy */
    @Override
    public double getEnsembleCvAcc() {
        return this.hesca.getEnsembleCvAcc();
    }

    /** @return HESCA's ensemble cross-validation predictions */
    @Override
    public double[] getEnsembleCvPreds() {
        return this.hesca.getEnsembleCvPreds();
    }

    /**
     * Enables or disables the shapelet transform.
     *
     * @param b {@code true} to transform the data before HESCA, {@code false} to use raw data
     */
    public void doSTransform(boolean b) {
        this.doTransform = b;
    }

    /** @return the operation count recorded by the shapelet transform */
    public long getTransformOpCount() {
        return this.transform.getCount();
    }

    /**
     * Applies the already-fitted shapelet transform to {@code data}.
     *
     * @param data data to transform
     * @return the transformed data, or {@code null} if no transform has been fitted yet
     */
    public Instances transformDataset(Instances data) {
        if (this.transform != null && this.transform.isFirstBatchDone()) {
            return this.transform.process(data);
        }
        return null;
    }

    /**
     * Sets the raw time limit.
     *
     * <p>Presumably in nanoseconds, matching {@link #setTimeLimit(ContractClassifier.TimeLimit, int)}
     * — confirm.
     *
     * @param time the time limit
     */
    @Override
    public void setTimeLimit(long time) {
        this.timeLimit = time;
    }

    /**
     * Sets the time limit to {@code amount} units of {@code time}.
     *
     * @param time the unit of time
     * @param amount the number of units
     */
    @Override
    public void setTimeLimit(ContractClassifier.TimeLimit time, int amount) {
        // Nanoseconds in one minute, one hour and one day, indexed by the enum's declaration order.
        // NOTE(review): assumes TimeLimit is declared in the order MINUTE, HOUR, DAY — confirm.
        long[] nanosPerUnit = new long[]{60_000_000_000L, 3_600_000_000_000L, 86_400_000_000_000L};
        this.timeLimit = nanosPerUnit[time.ordinal()] * (long) amount;
    }

    /**
     * Fixes the number of shapelets to search when the time limit forces a sampled search.
     *
     * @param numS number of shapelets; 0 means "derive from the time limit"
     */
    public void setNumberOfShapelets(long numS) {
        this.numShapelets = numS;
    }

    /**
     * Trains the classifier.
     *
     * <p>The data is optionally shapelet-transformed. Redundant attributes are then removed and
     * HESCA is trained on the result. {@code data} itself is never modified.
     *
     * @param data training data
     * @throws Exception if the transform or HESCA fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        // removeRedundantTrainAttributes deletes attributes in place, so never hand it the
        // caller's data: copy it when no transform produces a fresh dataset.
        this.format = this.doTransform
                ? this.createTransformData(data, this.timeLimit)
                : new Instances(data);
        this.hesca = new HESCA();
        this.hesca.setRandSeed((int) this.seed);
        if (this.trainCVPath != null) {
            this.hesca.writeCVTrainToFile(this.trainCVPath);
        }
        this.redundantFeatures = InstanceTools.removeRedundantTrainAttributes(this.format);
        this.hesca.buildClassifier(this.format);
        // Keep an empty header of the untransformed data to buffer test instances.
        this.format = new Instances(data, 0);
    }

    /**
     * Classifies a single instance with HESCA after transforming it.
     *
     * @param ins the raw test instance
     * @return the predicted class value
     * @throws Exception if transformation or classification fails
     */
    @Override
    public double classifyInstance(Instance ins) throws Exception {
        return this.hesca.classifyInstance(this.prepareTestInstance(ins));
    }

    /**
     * Returns HESCA's class distribution for a single instance after transforming it.
     *
     * @param ins the raw test instance
     * @return the class probability distribution
     * @throws Exception if transformation or classification fails
     */
    @Override
    public double[] distributionForInstance(Instance ins) throws Exception {
        return this.hesca.distributionForInstance(this.prepareTestInstance(ins));
    }

    /**
     * Maps a raw test instance into the attribute space HESCA was trained on.
     *
     * <p>Applies the transform (if enabled), then drops the attributes removed at training time.
     * Works on a copy so the shared header is never modified. The buffered instance is removed
     * from the header even if the transform throws.
     */
    private Instance prepareTestInstance(Instance ins) {
        this.format.add(ins);
        try {
            Instances temp = this.doTransform
                    ? this.transform.process(this.format)
                    : new Instances(this.format);
            for (int del : this.redundantFeatures) {
                temp.deleteAttributeAt(del);
            }
            return temp.get(0);
        } finally {
            this.format.remove(0);
        }
    }

    /**
     * Sets a file the shapelet transform logs its output to.
     *
     * @param path log file path
     */
    public void setShapeletOutputFilePath(String path) {
        this.shapeletOutputPath = path;
    }

    /** Makes the transform order shapelets with {@link Shapelet.ShortOrder}. */
    public void preferShortShapelets() {
        this.preferShortShapelets = true;
    }

    /**
     * Configures and fits a shapelet transform on {@code train}, then returns the transformed data.
     *
     * <p>Setup:
     * <ul>
     *   <li>information-gain quality measure;</li>
     *   <li>binary class values and class balancing for more than two classes;</li>
     *   <li>round-robin ordering and candidate pruning.</li>
     * </ul>
     *
     * <p>If the estimated operation count of a full search exceeds {@code time / 10}, the search
     * is restricted:
     * <ul>
     *   <li>it uses the configured search type and seed;</li>
     *   <li>the number of shapelets searched is scaled down proportionally, unless it was set
     *       explicitly via {@link #setNumberOfShapelets(long)}.</li>
     * </ul>
     *
     * @param train training data, class attribute assumed last — TODO confirm
     * @param time time limit used to size the search
     * @return the transformed training data
     */
    public Instances createTransformData(Instances train, long time) {
        int n = train.numInstances();
        int m = train.numAttributes() - 1;
        ShapeletTransformFactoryOptions.Builder optionsBuilder = new ShapeletTransformFactoryOptions.Builder();
        optionsBuilder.setDistanceType(SubSeqDistance.DistanceType.IMP_ONLINE);
        optionsBuilder.setQualityMeasure(ShapeletQuality.ShapeletQualityChoice.INFORMATION_GAIN);
        if (train.numClasses() > 2) {
            optionsBuilder.useBinaryClassValue();
            optionsBuilder.useClassBalancing();
        }
        optionsBuilder.useRoundRobin();
        optionsBuilder.useCandidatePruning();
        ShapeletSearchOptions.Builder searchBuilder = new ShapeletSearchOptions.Builder();
        searchBuilder.setMin(MIN_SHAPELET_LENGTH);
        searchBuilder.setMax(m);
        int k = Math.min(n, MAX_K_SHAPELETS);
        BigInteger opCountTarget = BigInteger.valueOf(time / 10L);
        BigInteger opCount = ShapeletTransformTimingUtilities.calculateOps(n, m, 1, 1);
        if (opCount.compareTo(opCountTarget) > 0) {
            BigDecimal prop = new BigDecimal(opCountTarget)
                    .divide(new BigDecimal(opCount), MathContext.DECIMAL64);
            // Use a local so that the "0 = auto" setting survives repeated builds on other data.
            long shapeletsToSearch = this.numShapelets;
            if (shapeletsToSearch == 0L) {
                shapeletsToSearch = ShapeletTransformTimingUtilities.calculateNumberOfShapelets(n, m, MIN_SHAPELET_LENGTH, m);
                shapeletsToSearch = (long) ((double) shapeletsToSearch * prop.doubleValue());
            }
            searchBuilder.setSeed(this.seed);
            searchBuilder.setSearchType(this.searchType);
            searchBuilder.setNumShapelets(shapeletsToSearch);
            // Cannot keep more shapelets than are searched.
            k = (int) Math.min((long) k, shapeletsToSearch);
        }
        optionsBuilder.setKShapelets(k);
        optionsBuilder.setSearchOptions(searchBuilder.build());
        this.transform = new ShapeletTransformFactory(optionsBuilder.build()).getTransform();
        this.transform.supressOutput();
        if (this.shapeletOutputPath != null) {
            this.transform.setLogOutputFile(this.shapeletOutputPath);
        }
        if (this.preferShortShapelets) {
            this.transform.setShapeletComparator(new Shapelet.ShortOrder());
        }
        return this.transform.process(train);
    }

    /**
     * Ad-hoc driver.
     *
     * <p>Trains on one dataset with a one-minute limit and prints test accuracy. Paths are
     * machine-specific.
     */
    public static void main(String[] args) throws Exception {
        String dataLocation = "C:\\LocalData\\Dropbox\\TSC Problems\\";
        String datasetName = "Earthquakes";
        Instances train = ClassifierTools.loadData(dataLocation + datasetName + File.separator + datasetName + "_TRAIN");
        Instances test = ClassifierTools.loadData(dataLocation + datasetName + File.separator + datasetName + "_TEST");
        ST_HESCA st = new ST_HESCA();
        st.doSTransform(true);
        st.setOneMinuteLimit();
        st.buildClassifier(train);
        double accuracy = ClassifierTools.accuracy(test, st);
        System.out.println("accuracy: " + accuracy);
    }
}

