/*
 * Decompiled with CFR 0.152.
 */
package timeseriesweka.filters;

import java.util.ArrayList;
import java.util.Collections;
import timeseriesweka.filters.shapelet_transforms.OrderLineObj;
import utilities.class_distributions.TreeSetClassDistribution;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.FastVector;
import weka.core.Instances;
import weka.filters.SimpleBatchFilter;

/**
 * Batch filter that turns every numeric, non-class attribute into a two-valued
 * nominal attribute with labels {@code "0"} and {@code "1"}.
 *
 * <p>For each attribute a split value is chosen by {@link #findSplitValue}; a value
 * strictly below the split maps to {@code "0"}, anything else maps to {@code "1"}.
 * Split values are learnt on the first call to {@link #process(Instances)} and then
 * reused for later calls until {@link #findNewSplits()} is invoked.
 *
 * <p>In the output format the class attribute (when the input has one) is always the
 * last attribute.
 */
public class BinaryTransform
extends SimpleBatchFilter {
    /** When {@code true}, the next call to {@code process} recomputes {@link #splits}. */
    private boolean findNewSplits = true;

    /** Learnt split value per input attribute, indexed by input attribute index. */
    private double[] splits;

    /**
     * Requests that the split values be recomputed on the next call to
     * {@link #process(Instances)}.
     */
    public void findNewSplits() {
        this.findNewSplits = true;
    }

    /**
     * Builds the header of the transformed data: one nominal {@code {0,1}} attribute
     * named {@code Binary_i} per non-class input attribute, followed by a copy of the
     * class attribute (same name and labels) when the input has a class index set.
     *
     * @param inputFormat header of the incoming data
     * @return an empty {@code Instances} with the output structure
     * @throws Exception if any non-class attribute is not numeric
     */
    @Override
    protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
        int classIdx = inputFormat.classIndex();
        int numAtts = inputFormat.numAttributes();
        for (int i = 0; i < numAtts; ++i) {
            if (i != classIdx && !inputFormat.attribute(i).isNumeric()) {
                throw new Exception("Non numeric attribute not allowed in BinaryTransform");
            }
        }

        boolean hasClass = classIdx >= 0;
        int length = hasClass ? numAtts - 1 : numAtts;

        FastVector<Attribute> atts = new FastVector<Attribute>();
        FastVector<String> attributeValues = new FastVector<String>();
        attributeValues.addElement("0");
        attributeValues.addElement("1");
        for (int i = 0; i < length; ++i) {
            atts.addElement(new Attribute("Binary_" + i, attributeValues));
        }

        if (hasClass) {
            // Re-create the class attribute with the same name and label set.
            Attribute target = inputFormat.attribute(classIdx);
            FastVector<String> vals = new FastVector<String>(target.numValues());
            for (int i = 0; i < target.numValues(); ++i) {
                vals.addElement(target.value(i));
            }
            atts.addElement(new Attribute(target.name(), vals));
        }

        Instances result = new Instances("Binary" + inputFormat.relationName(), atts, inputFormat.numInstances());
        if (hasClass) {
            result.setClassIndex(result.numAttributes() - 1);
        }
        return result;
    }

    /**
     * Binarises every instance of {@code data}.
     *
     * <p>If new splits have been requested (initially, or via {@link #findNewSplits()}),
     * they are learnt from {@code data} first; this step reads the class value of every
     * instance, so {@code data} must have a class attribute at that point. Later calls
     * reuse the stored splits and therefore expect the same attribute layout.
     *
     * <p>Non-class attributes are written to consecutive output positions in their
     * input order and the class value is written to the last output position, matching
     * the header produced by {@link #determineOutputFormat(Instances)}.
     *
     * @param data the instances to transform
     * @return the transformed instances
     * @throws Exception if the output format cannot be determined
     */
    @Override
    public Instances process(Instances data) throws Exception {
        Instances output = this.determineOutputFormat(data);
        if (this.findNewSplits) {
            this.learnSplits(data);
            this.findNewSplits = false;
        }

        int classIdx = data.classIndex();
        int numAtts = data.numAttributes();
        for (int i = 0; i < data.numInstances(); ++i) {
            DenseInstance newInst = new DenseInstance(numAtts);
            // Output position advances only for non-class attributes, so the class
            // attribute never occupies a binary slot regardless of where it sits in
            // the input.
            int outIdx = 0;
            for (int j = 0; j < numAtts; ++j) {
                if (j == classIdx) {
                    continue;
                }
                double bit = data.instance(i).value(j) < this.splits[j] ? 0.0 : 1.0;
                newInst.setValue(outIdx++, bit);
            }
            if (classIdx >= 0) {
                // The output header always places the class attribute last.
                newInst.setValue(numAtts - 1, data.instance(i).classValue());
            }
            output.add(newInst);
        }
        return output;
    }

    /** Learns one split value per non-class attribute of {@code data} into {@link #splits}. */
    private void learnSplits(Instances data) {
        int numInstances = data.numInstances();
        this.splits = new double[data.numAttributes()];

        double[] classes = new double[numInstances];
        for (int i = 0; i < numInstances; ++i) {
            classes[i] = data.instance(i).classValue();
        }

        for (int j = 0; j < data.numAttributes(); ++j) {
            if (j == data.classIndex()) {
                continue;
            }
            double[] vals = new double[numInstances];
            for (int i = 0; i < numInstances; ++i) {
                vals[i] = data.instance(i).value(j);
            }
            this.splits[j] = this.findSplitValue(data, vals, classes);
        }
    }

    /**
     * Finds the split value for a single attribute.
     *
     * <p>Pairs each value with its class, sorts the pairs with
     * {@link Collections#sort(java.util.List)} and delegates to
     * {@link #infoGainThreshold(ArrayList, TreeSetClassDistribution)}.
     *
     * @param data    the data set used to construct the overall class distribution
     * @param vals    attribute values, one per instance
     * @param classes class values, parallel to {@code vals}
     * @return the chosen split value
     */
    public double findSplitValue(Instances data, double[] vals, double[] classes) {
        ArrayList<OrderLineObj> list = new ArrayList<OrderLineObj>(vals.length);
        for (int i = 0; i < vals.length; ++i) {
            list.add(new OrderLineObj(vals[i], classes[i]));
        }
        TreeSetClassDistribution tree = new TreeSetClassDistribution(data);
        // NOTE(review): the threshold search assumes this sort orders by value
        // ascending - confirm against OrderLineObj.compareTo.
        Collections.sort(list);
        return BinaryTransform.infoGainThreshold(list, tree);
    }

    /**
     * Computes the base-2 entropy of a class distribution.
     *
     * @param classDistributions class value to count mapping
     * @return the entropy in bits; {@code 0.0} when only one class key is present
     */
    private static double entropy(TreeSetClassDistribution classDistributions) {
        if (classDistributions.size() == 1) {
            return 0.0;
        }
        int total = 0;
        for (Double key : classDistributions.keySet()) {
            total += classDistributions.get(key);
        }
        double sum = 0.0;
        for (Double key : classDistributions.keySet()) {
            int count = classDistributions.get(key);
            double thisPart = (double) count / (double) total;
            double toAdd = -thisPart * Math.log10(thisPart) / Math.log10(2.0);
            // A zero count gives 0 * log(0) = NaN; such a class contributes nothing.
            if (Double.isNaN(toAdd)) {
                toAdd = 0.0;
            }
            sum += toAdd;
        }
        return sum;
    }

    /** Adds {@code delta} to the count stored for {@code classVal} in {@code dist}. */
    private static void shiftCount(TreeSetClassDistribution dist, double classVal, int delta) {
        int updated = dist.get(classVal);
        updated += delta;
        dist.put(classVal, updated);
    }

    /**
     * Scans an order line for the split with the highest information gain.
     *
     * <p>A candidate split is evaluated before position {@code i} whenever {@code i == 1}
     * or the value at {@code i} differs from the value at {@code i - 1}. The returned
     * threshold is the midpoint between those two neighbouring values. On ties in gain
     * the earliest candidate wins.
     *
     * @param orderline         value/class pairs, expected to be sorted by value
     * @param classDistribution class distribution of the whole order line; its keys
     *                          define the set of classes counted on each side
     * @return the best threshold, or {@code -1.0} when the order line has fewer than
     *         two entries
     */
    public static double infoGainThreshold(ArrayList<OrderLineObj> orderline, TreeSetClassDistribution classDistribution) {
        double threshold = -1.0;
        int total = orderline.size();
        if (total < 2) {
            // No split position exists between fewer than two points.
            return threshold;
        }

        // The parent entropy does not depend on the split position.
        double parentEntropy = BinaryTransform.entropy(classDistribution);

        // Start with everything on the "greater" side and move one entry across per
        // step instead of recounting both sides for every candidate.
        TreeSetClassDistribution lessClasses = new TreeSetClassDistribution();
        TreeSetClassDistribution greaterClasses = new TreeSetClassDistribution();
        for (double classVal : classDistribution.keySet()) {
            lessClasses.put(classVal, 0);
            greaterClasses.put(classVal, 0);
        }
        for (OrderLineObj entry : orderline) {
            BinaryTransform.shiftCount(greaterClasses, entry.getClassVal(), 1);
        }

        double bsfGain = -1.0;
        double lastDist = orderline.get(0).getDistance();
        for (int i = 1; i < total; ++i) {
            double movedClass = orderline.get(i - 1).getClassVal();
            BinaryTransform.shiftCount(lessClasses, movedClass, 1);
            BinaryTransform.shiftCount(greaterClasses, movedClass, -1);

            double thisDist = orderline.get(i).getDistance();
            if (i == 1 || thisDist != lastDist) {
                double lessFrac = (double) i / (double) total;
                double greaterFrac = (double) (total - i) / (double) total;
                double entropyLess = BinaryTransform.entropy(lessClasses);
                double entropyGreater = BinaryTransform.entropy(greaterClasses);
                double gain = parentEntropy - lessFrac * entropyLess - greaterFrac * entropyGreater;
                if (gain > bsfGain) {
                    bsfGain = gain;
                    threshold = (thisDist - lastDist) / 2.0 + lastDist;
                }
            }
            lastDist = thisDist;
        }
        return threshold;
    }

    /**
     * Returns a short description of this filter.
     *
     * @return human-readable description of the transform
     */
    @Override
    public String globalInfo() {
        return "Transforms each numeric attribute into a binary nominal attribute with values 0 and 1 "
            + "by thresholding it at the split point that gives the highest information gain "
            + "with respect to the class.";
    }
}

