/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.classifiers.trees.J48;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.PrincipalComponents;
import weka.filters.unsupervised.attribute.RemoveUseless;
import weka.filters.unsupervised.instance.RemovePercentage;

public class RotationForest
extends RandomizableIteratedSingleClassifierEnhancer
implements WeightedInstancesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = -3255631880798499936L;
    protected int m_MinGroup = 3;
    protected int m_MaxGroup = 3;
    protected boolean m_NumberOfGroups = false;
    protected int m_RemovedPercentage = 50;
    protected int[][][] m_Groups = null;
    protected Filter m_ProjectionFilter = null;
    protected Filter[][] m_ProjectionFilters = null;
    protected Instances[] m_Headers = null;
    protected Instances[][] m_ReducedHeaders = null;
    protected RemoveUseless m_RemoveUseless = null;
    protected Normalize m_Normalize = null;

    public RotationForest() {
        this.m_Classifier = new J48();
        this.m_ProjectionFilter = this.defaultFilter();
    }

    protected Filter defaultFilter() {
        PrincipalComponents filter = new PrincipalComponents();
        filter.setVarianceCovered(1.0);
        return filter;
    }

    public String globalInfo() {
        return "Class for construction a Rotation Forest. Can do classification and regression depending on the base learner. \n\nFor more information, see\n\n" + this.getTechnicalInformation().toString();
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Juan J. Rodriguez and Ludmila I. Kuncheva and Carlos J. Alonso");
        result.setValue(TechnicalInformation.Field.YEAR, "2006");
        result.setValue(TechnicalInformation.Field.TITLE, "Rotation Forest: A new classifier ensemble method");
        result.setValue(TechnicalInformation.Field.JOURNAL, "IEEE Transactions on Pattern Analysis and Machine Intelligence");
        result.setValue(TechnicalInformation.Field.VOLUME, "28");
        result.setValue(TechnicalInformation.Field.NUMBER, "10");
        result.setValue(TechnicalInformation.Field.PAGES, "1619-1630");
        result.setValue(TechnicalInformation.Field.ISSN, "0162-8828");
        result.setValue(TechnicalInformation.Field.URL, "http://doi.ieeecomputersociety.org/10.1109/TPAMI.2006.211");
        return result;
    }

    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.J48";
    }

    @Override
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>(5);
        newVector.addElement(new Option("\tWhether minGroup (-G) and maxGroup (-H) refer to\n\tthe number of groups or their size.\n\t(default: false)", "N", 0, "-N"));
        newVector.addElement(new Option("\tMinimum size of a group of attributes:\n\t\tif numberOfGroups is true, the minimum number\n\t\tof groups.\n\t\t(default: 3)", "G", 1, "-G <num>"));
        newVector.addElement(new Option("\tMaximum size of a group of attributes:\n\t\tif numberOfGroups is true, the maximum number\n\t\tof groups.\n\t\t(default: 3)", "H", 1, "-H <num>"));
        newVector.addElement(new Option("\tPercentage of instances to be removed.\n\t\t(default: 50)", "P", 1, "-P <num>"));
        newVector.addElement(new Option("\tFull class name of filter to use, followed\n\tby filter options.\n\teg: \"weka.filters.unsupervised.attribute.PrincipalComponents-R 1.0\"", "F", 1, "-F <filter specification>"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement((Option)enu.nextElement());
        }
        return newVector.elements();
    }

    @Override
    public void setOptions(String[] options) throws Exception {
        String filterString = Utils.getOption('F', options);
        if (filterString.length() > 0) {
            String[] filterSpec = Utils.splitOptions(filterString);
            if (filterSpec.length == 0) {
                throw new IllegalArgumentException("Invalid filter specification string");
            }
            String filterName = filterSpec[0];
            filterSpec[0] = "";
            this.setProjectionFilter((Filter)Utils.forName(Filter.class, filterName, filterSpec));
        } else {
            this.setProjectionFilter(this.defaultFilter());
        }
        String tmpStr = Utils.getOption('G', options);
        if (tmpStr.length() != 0) {
            this.setMinGroup(Integer.parseInt(tmpStr));
        } else {
            this.setMinGroup(3);
        }
        tmpStr = Utils.getOption('H', options);
        if (tmpStr.length() != 0) {
            this.setMaxGroup(Integer.parseInt(tmpStr));
        } else {
            this.setMaxGroup(3);
        }
        tmpStr = Utils.getOption('P', options);
        if (tmpStr.length() != 0) {
            this.setRemovedPercentage(Integer.parseInt(tmpStr));
        } else {
            this.setRemovedPercentage(50);
        }
        this.setNumberOfGroups(Utils.getFlag('N', options));
        super.setOptions(options);
    }

    @Override
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        String[] options = new String[superOptions.length + 9];
        int current = 0;
        if (this.getNumberOfGroups()) {
            options[current++] = "-N";
        }
        options[current++] = "-G";
        options[current++] = "" + this.getMinGroup();
        options[current++] = "-H";
        options[current++] = "" + this.getMaxGroup();
        options[current++] = "-P";
        options[current++] = "" + this.getRemovedPercentage();
        options[current++] = "-F";
        options[current++] = this.getProjectionFilterSpec();
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        current += superOptions.length;
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    public String numberOfGroupsTipText() {
        return "Whether minGroup and maxGroup refer to the number of groups or their size.";
    }

    public void setNumberOfGroups(boolean numberOfGroups) {
        this.m_NumberOfGroups = numberOfGroups;
    }

    public boolean getNumberOfGroups() {
        return this.m_NumberOfGroups;
    }

    public String minGroupTipText() {
        return "Minimum size of a group (if numberOfGrups is true, the minimum number of groups.";
    }

    public void setMinGroup(int minGroup) throws IllegalArgumentException {
        if (minGroup <= 0) {
            throw new IllegalArgumentException("MinGroup has to be positive.");
        }
        this.m_MinGroup = minGroup;
    }

    public int getMinGroup() {
        return this.m_MinGroup;
    }

    public String maxGroupTipText() {
        return "Maximum size of a group (if numberOfGrups is true, the maximum number of groups.";
    }

    public void setMaxGroup(int maxGroup) throws IllegalArgumentException {
        if (maxGroup <= 0) {
            throw new IllegalArgumentException("MaxGroup has to be positive.");
        }
        this.m_MaxGroup = maxGroup;
    }

    public int getMaxGroup() {
        return this.m_MaxGroup;
    }

    public String removedPercentageTipText() {
        return "The percentage of instances to be removed.";
    }

    public void setRemovedPercentage(int removedPercentage) throws IllegalArgumentException {
        if (removedPercentage < 0) {
            throw new IllegalArgumentException("RemovedPercentage has to be >=0.");
        }
        if (removedPercentage >= 100) {
            throw new IllegalArgumentException("RemovedPercentage has to be <100.");
        }
        this.m_RemovedPercentage = removedPercentage;
    }

    public int getRemovedPercentage() {
        return this.m_RemovedPercentage;
    }

    public String projectionFilterTipText() {
        return "The filter used to project the data (e.g., PrincipalComponents).";
    }

    public void setProjectionFilter(Filter projectionFilter) {
        this.m_ProjectionFilter = projectionFilter;
    }

    public Filter getProjectionFilter() {
        return this.m_ProjectionFilter;
    }

    protected String getProjectionFilterSpec() {
        Filter c = this.getProjectionFilter();
        if (c instanceof OptionHandler) {
            return c.getClass().getName() + " " + Utils.joinOptions(((OptionHandler)((Object)c)).getOptions());
        }
        return c.getClass().getName();
    }

    public String toString() {
        if (this.m_Classifiers == null) {
            return "RotationForest: No model built yet.";
        }
        StringBuffer text = new StringBuffer();
        text.append("All the base classifiers: \n\n");
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            text.append(this.m_Classifiers[i].toString() + "\n\n");
        }
        return text.toString();
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 7012 $");
    }

    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        super.buildClassifier(data);
        this.checkMinMax(data);
        Random random = data.numInstances() > 0 ? data.getRandomNumberGenerator(this.m_Seed) : new Random(this.m_Seed);
        this.m_RemoveUseless = new RemoveUseless();
        this.m_RemoveUseless.setInputFormat(data);
        data = Filter.useFilter(data, this.m_RemoveUseless);
        this.m_Normalize = new Normalize();
        this.m_Normalize.setInputFormat(data);
        data = Filter.useFilter(data, this.m_Normalize);
        if (this.m_NumberOfGroups) {
            this.generateGroupsFromNumbers(data, random);
        } else {
            this.generateGroupsFromSizes(data, random);
        }
        this.m_ProjectionFilters = new Filter[this.m_Groups.length][];
        for (int i = 0; i < this.m_ProjectionFilters.length; ++i) {
            this.m_ProjectionFilters[i] = Filter.makeCopies(this.m_ProjectionFilter, this.m_Groups[i].length);
        }
        int numClasses = data.numClasses();
        Instances[] instancesOfClass = new Instances[numClasses + 1];
        if (data.classAttribute().isNumeric()) {
            instancesOfClass = new Instances[numClasses];
            instancesOfClass[0] = data;
        } else {
            instancesOfClass = new Instances[numClasses + 1];
            for (int i = 0; i < instancesOfClass.length; ++i) {
                instancesOfClass[i] = new Instances(data, 0);
            }
            Enumeration enu = data.enumerateInstances();
            while (enu.hasMoreElements()) {
                Instance instance = (Instance)enu.nextElement();
                if (instance.classIsMissing()) {
                    instancesOfClass[numClasses].add(instance);
                    continue;
                }
                int c = (int)instance.classValue();
                instancesOfClass[c].add(instance);
            }
            if (instancesOfClass[numClasses].numInstances() == 0) {
                Instances[] tmp = instancesOfClass;
                instancesOfClass = new Instances[numClasses];
                System.arraycopy(tmp, 0, instancesOfClass, 0, numClasses);
            }
        }
        this.m_Headers = new Instances[this.m_Classifiers.length];
        this.m_ReducedHeaders = new Instances[this.m_Classifiers.length][];
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            this.m_ReducedHeaders[i] = new Instances[this.m_Groups[i].length];
            FastVector<Attribute> transformedAttributes = new FastVector<Attribute>(data.numAttributes());
            for (int j = 0; j < this.m_Groups[i].length; ++j) {
                FastVector<Attribute> fv = new FastVector<Attribute>(this.m_Groups[i][j].length + 1);
                for (int k = 0; k < this.m_Groups[i][j].length; ++k) {
                    String newName = data.attribute(this.m_Groups[i][j][k]).name() + "_" + k;
                    fv.addElement(data.attribute(this.m_Groups[i][j][k]).copy(newName));
                }
                fv.addElement((Attribute)data.classAttribute().copy());
                Instances dataSubSet = new Instances("rotated-" + i + "-" + j + "-", fv, 0);
                dataSubSet.setClassIndex(dataSubSet.numAttributes() - 1);
                this.m_ReducedHeaders[i][j] = new Instances(dataSubSet, 0);
                boolean[] selectedClasses = this.selectClasses(instancesOfClass.length, random);
                for (int c = 0; c < selectedClasses.length; ++c) {
                    if (!selectedClasses[c]) continue;
                    Enumeration enu = instancesOfClass[c].enumerateInstances();
                    while (enu.hasMoreElements()) {
                        Instance instance = (Instance)enu.nextElement();
                        DenseInstance newInstance = new DenseInstance(dataSubSet.numAttributes());
                        newInstance.setDataset(dataSubSet);
                        for (int k = 0; k < this.m_Groups[i][j].length; ++k) {
                            newInstance.setValue(k, instance.value(this.m_Groups[i][j][k]));
                        }
                        newInstance.setClassValue(instance.classValue());
                        dataSubSet.add(newInstance);
                    }
                }
                dataSubSet.randomize(random);
                Instances originalDataSubSet = dataSubSet;
                dataSubSet.randomize(random);
                RemovePercentage rp = new RemovePercentage();
                rp.setPercentage(this.m_RemovedPercentage);
                rp.setInputFormat(dataSubSet);
                dataSubSet = Filter.useFilter(dataSubSet, rp);
                if (dataSubSet.numInstances() < 2) {
                    dataSubSet = originalDataSubSet;
                }
                this.m_ProjectionFilters[i][j].setInputFormat(dataSubSet);
                Instances projectedData = null;
                do {
                    try {
                        projectedData = Filter.useFilter(dataSubSet, this.m_ProjectionFilters[i][j]);
                    }
                    catch (Exception e) {
                        this.addRandomInstances(dataSubSet, 10, random);
                    }
                } while (projectedData == null);
                for (int a = 0; a < projectedData.numAttributes() - 1; ++a) {
                    String newName = projectedData.attribute(a).name() + "_" + j;
                    transformedAttributes.addElement(projectedData.attribute(a).copy(newName));
                }
            }
            transformedAttributes.addElement((Attribute)data.classAttribute().copy());
            Instances buildClas = new Instances("rotated-" + i + "-", transformedAttributes, 0);
            buildClas.setClassIndex(buildClas.numAttributes() - 1);
            this.m_Headers[i] = new Instances(buildClas, 0);
            Enumeration enu = data.enumerateInstances();
            while (enu.hasMoreElements()) {
                Instance instance = (Instance)enu.nextElement();
                Instance newInstance = this.convertInstance(instance, i);
                buildClas.add(newInstance);
            }
            if (this.m_Classifier instanceof Randomizable) {
                ((Randomizable)((Object)this.m_Classifiers[i])).setSeed(random.nextInt());
            }
            this.m_Classifiers[i].buildClassifier(buildClas);
        }
        if (this.m_Debug) {
            this.printGroups();
        }
    }

    protected void addRandomInstances(Instances dataset, int numInstances, Random random) {
        int n = dataset.numAttributes();
        double[] v = new double[n];
        for (int i = 0; i < numInstances; ++i) {
            for (int j = 0; j < n; ++j) {
                Attribute att = dataset.attribute(j);
                if (att.isNumeric()) {
                    v[j] = random.nextDouble();
                    continue;
                }
                if (!att.isNominal()) continue;
                v[j] = random.nextInt(att.numValues());
            }
            dataset.add(new DenseInstance(1.0, v));
        }
    }

    protected void checkMinMax(Instances data) {
        int n;
        if (this.m_MinGroup > this.m_MaxGroup) {
            int tmp = this.m_MaxGroup;
            this.m_MaxGroup = this.m_MinGroup;
            this.m_MinGroup = tmp;
        }
        if (this.m_MaxGroup >= (n = data.numAttributes())) {
            this.m_MaxGroup = n - 1;
        }
        if (this.m_MinGroup >= n) {
            this.m_MinGroup = n - 1;
        }
    }

    protected boolean[] selectClasses(int numClasses, Random random) {
        int numSelected = 0;
        boolean[] selected = new boolean[numClasses];
        for (int i = 0; i < selected.length; ++i) {
            if (!random.nextBoolean()) continue;
            selected[i] = true;
            ++numSelected;
        }
        if (numSelected == 0) {
            selected[random.nextInt((int)selected.length)] = true;
        }
        return selected;
    }

    protected void generateGroupsFromSizes(Instances data, Random random) {
        this.m_Groups = new int[this.m_Classifiers.length][][];
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            int[] permutation = this.attributesPermutation(data.numAttributes(), data.classIndex(), random);
            int[] numGroupsOfSize = new int[this.m_MaxGroup - this.m_MinGroup + 1];
            int numAttributes = 0;
            int numGroups = 0;
            while (numAttributes < permutation.length) {
                int n;
                int n2 = n = random.nextInt(numGroupsOfSize.length);
                numGroupsOfSize[n2] = numGroupsOfSize[n2] + 1;
                numAttributes += this.m_MinGroup + n;
                ++numGroups;
            }
            this.m_Groups[i] = new int[numGroups][];
            int currentAttribute = 0;
            int currentSize = 0;
            for (int j = 0; j < numGroups; ++j) {
                while (numGroupsOfSize[currentSize] == 0) {
                    ++currentSize;
                }
                int n = currentSize;
                numGroupsOfSize[n] = numGroupsOfSize[n] - 1;
                int n3 = this.m_MinGroup + currentSize;
                this.m_Groups[i][j] = new int[n3];
                for (int k = 0; k < n3; ++k) {
                    this.m_Groups[i][j][k] = currentAttribute < permutation.length ? permutation[currentAttribute] : permutation[random.nextInt(permutation.length)];
                    ++currentAttribute;
                }
            }
        }
    }

    protected void generateGroupsFromNumbers(Instances data, Random random) {
        this.m_Groups = new int[this.m_Classifiers.length][][];
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            int[] permutation = this.attributesPermutation(data.numAttributes(), data.classIndex(), random);
            int numGroups = this.m_MinGroup + random.nextInt(this.m_MaxGroup - this.m_MinGroup + 1);
            this.m_Groups[i] = new int[numGroups][];
            int groupSize = permutation.length / numGroups;
            int numBiggerGroups = permutation.length % numGroups;
            int currentAttribute = 0;
            for (int j = 0; j < numGroups; ++j) {
                this.m_Groups[i][j] = j < numBiggerGroups ? new int[groupSize + 1] : new int[groupSize];
                for (int k = 0; k < this.m_Groups[i][j].length; ++k) {
                    this.m_Groups[i][j][k] = permutation[currentAttribute++];
                }
            }
        }
    }

    protected int[] attributesPermutation(int numAttributes, int classAttribute, Random random) {
        int i;
        int[] permutation = new int[numAttributes - 1];
        for (i = 0; i < classAttribute; ++i) {
            permutation[i] = i;
        }
        while (i < permutation.length) {
            permutation[i] = i + 1;
            ++i;
        }
        this.permute(permutation, random);
        return permutation;
    }

    protected void permute(int[] v, Random random) {
        for (int i = v.length - 1; i > 0; --i) {
            int j = random.nextInt(i + 1);
            if (i == j) continue;
            int tmp = v[i];
            v[i] = v[j];
            v[j] = tmp;
        }
    }

    protected void printGroups() {
        for (int i = 0; i < this.m_Groups.length; ++i) {
            for (int j = 0; j < this.m_Groups[i].length; ++j) {
                System.err.print("( ");
                for (int k = 0; k < this.m_Groups[i][j].length; ++k) {
                    System.err.print(this.m_Groups[i][j][k]);
                    System.err.print(" ");
                }
                System.err.print(") ");
            }
            System.err.println();
        }
    }

    protected Instance convertInstance(Instance instance, int i) throws Exception {
        DenseInstance newInstance = new DenseInstance(this.m_Headers[i].numAttributes());
        newInstance.setWeight(instance.weight());
        newInstance.setDataset(this.m_Headers[i]);
        int currentAttribute = 0;
        for (int j = 0; j < this.m_Groups[i].length; ++j) {
            int k;
            Instance auxInstance = new DenseInstance(this.m_Groups[i][j].length + 1);
            for (k = 0; k < this.m_Groups[i][j].length; ++k) {
                auxInstance.setValue(k, instance.value(this.m_Groups[i][j][k]));
            }
            auxInstance.setValue(k, instance.classValue());
            auxInstance.setDataset(this.m_ReducedHeaders[i][j]);
            this.m_ProjectionFilters[i][j].input(auxInstance);
            auxInstance = this.m_ProjectionFilters[i][j].output();
            this.m_ProjectionFilters[i][j].batchFinished();
            for (int a = 0; a < auxInstance.numAttributes() - 1; ++a) {
                newInstance.setValue(currentAttribute++, auxInstance.value(a));
            }
        }
        newInstance.setClassValue(instance.classValue());
        return newInstance;
    }

    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        this.m_RemoveUseless.input(instance);
        instance = this.m_RemoveUseless.output();
        this.m_RemoveUseless.batchFinished();
        this.m_Normalize.input(instance);
        instance = this.m_Normalize.output();
        this.m_Normalize.batchFinished();
        double[] sums = new double[instance.numClasses()];
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            Instance convertedInstance = this.convertInstance(instance, i);
            if (instance.classAttribute().isNumeric()) {
                sums[0] = sums[0] + this.m_Classifiers[i].classifyInstance(convertedInstance);
                continue;
            }
            double[] newProbs = this.m_Classifiers[i].distributionForInstance(convertedInstance);
            for (int j = 0; j < newProbs.length; ++j) {
                int n = j;
                sums[n] = sums[n] + newProbs[j];
            }
        }
        if (instance.classAttribute().isNumeric()) {
            sums[0] = sums[0] / (double)this.m_NumIterations;
            return sums;
        }
        if (Utils.eq(Utils.sum(sums), 0.0)) {
            return sums;
        }
        Utils.normalize(sums);
        return sums;
    }

    public static void main(String[] argv) {
        RotationForest.runClassifier(new RotationForest(), argv);
    }
}

