/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.mi;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.core.Capabilities;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;

/**
 * Multi-instance classifier implementing EM-DD (Zhang and Goldman, 2001).
 *
 * <p>The model is a single hypothesis: one target point and one scale per instance attribute.
 * Training alternates an E-step, which picks from every bag the instance closest to the current
 * hypothesis under the scaled squared distance, with an M-step, which minimises the negative
 * log-likelihood of the bag labels over those selected instances using {@link Optimization}.
 * Every instance of up to three randomly chosen positive bags is tried as a starting hypothesis.
 * The resulting hypothesis with the fewest training-bag misclassifications is kept.
 *
 * <p>NOTE(review): the bag (relational) attribute is assumed to be attribute index 1 of the
 * bag-level data; this index is hard-coded throughout.
 */
public class MIEMDD
extends RandomizableClassifier
implements OptionHandler,
MultiInstanceCapabilitiesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = 3899547154866223734L;

    /** Filter type: normalize the instance-level training data. */
    public static final int FILTER_NORMALIZE = 0;
    /** Filter type: standardize the instance-level training data. */
    public static final int FILTER_STANDARDIZE = 1;
    /** Filter type: leave the instance-level training data untransformed. */
    public static final int FILTER_NONE = 2;
    /** Tags describing the available filter types, for {@link #setFilterType(SelectedTag)}. */
    public static final Tag[] TAGS_FILTER = new Tag[]{
        new Tag(FILTER_NORMALIZE, "Normalize training data"),
        new Tag(FILTER_STANDARDIZE, "Standardize training data"),
        new Tag(FILTER_NONE, "No normalization/standardization")};

    /** Maximum number of positive bags whose instances are used as EM starting points. */
    private static final int NUM_STARTING_BAGS = 3;
    /** Maximum number of EM iterations per starting hypothesis. */
    private static final int MAX_EM_ITERATIONS = 10;
    /**
     * Initial "current" NLL of the EM loop (roughly Double.MAX_VALUE / 10). It is strictly below
     * the initial "previous" NLL of Double.MAX_VALUE, so the EM loop always runs at least once.
     */
    private static final double INITIAL_NLL = 1.7976931348623158E307;

    /** Index of the class attribute of the bag-level training data; set by buildClassifier. */
    protected int m_ClassIndex;
    /** Learned hypothesis: m_Par[2k] is the target point and m_Par[2k + 1] the scale of attribute k. */
    protected double[] m_Par;
    /** Number of classes of the training data; set by buildClassifier. */
    protected int m_NumClasses;
    /** Class value (binary: 0 or 1) of each training bag. */
    protected int[] m_Classes;
    /** Filtered training data indexed as [bag][attribute][instance within the bag]. */
    protected double[][][] m_Data;
    /** String-free header of the instances inside the bags; used for attribute names. */
    protected Instances m_Attributes;
    /** E-step result: attribute values of the instance currently selected from each bag. */
    protected double[][] m_emData;
    /** Normalize/Standardize filter fitted on the training instances, or null for none. */
    protected Filter m_Filter = null;
    /** Selected filter type; one of the FILTER_* constants. */
    protected int m_filterType = FILTER_STANDARDIZE;
    /** Missing-value replacement, fitted on the (filtered) training instances. */
    protected ReplaceMissingValues m_Missing = new ReplaceMissingValues();

    /**
     * Returns a description of this classifier.
     *
     * @return the description, followed by the technical information
     */
    public String globalInfo() {
        return "EMDD model builds heavily upon Dietterich's Diverse Density (DD) algorithm.\nIt is a general framework for MI learning of converting the MI problem to a single-instance setting using EM. In this implementation, we use most-likely cause DD model and only use up to 3 randomly selected positive bags as initial starting points of EM.\n\nFor more information see:\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for EM-DD.
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Qi Zhang and Sally A. Goldman");
        result.setValue(TechnicalInformation.Field.TITLE, "EM-DD: An Improved Multiple-Instance Learning Technique");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Advances in Neural Information Processing Systems 14");
        result.setValue(TechnicalInformation.Field.YEAR, "2001");
        result.setValue(TechnicalInformation.Field.PAGES, "1073-1080");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "MIT Press");
        return result;
    }

    /**
     * Lists the -N filter option, followed by the options of the superclass.
     *
     * @return an enumeration of {@link Option}s
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> result = new Vector<Option>();
        result.addElement(new Option("\tWhether to 0=normalize/1=standardize/2=neither.\n\t(default 1=standardize)", "N", 1, "-N <num>"));
        Enumeration enm = super.listOptions();
        while (enm.hasMoreElements()) {
            result.addElement((Option) enm.nextElement());
        }
        return result.elements();
    }

    /**
     * Parses -N (filter type; defaults to standardize when absent), then the superclass options.
     *
     * @param options the option array
     * @throws Exception if an option is invalid
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String tmpStr = Utils.getOption('N', options);
        int filterType = tmpStr.length() != 0 ? Integer.parseInt(tmpStr) : FILTER_STANDARDIZE;
        setFilterType(new SelectedTag(filterType, TAGS_FILTER));
        super.setOptions(options);
    }

    /**
     * Returns the superclass options followed by "-N filterType".
     *
     * @return the current option settings
     */
    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        for (String option : super.getOptions()) {
            result.add(option);
        }
        result.add("-N");
        result.add("" + m_filterType);
        return result.toArray(new String[result.size()]);
    }

    /**
     * Returns the tip text for the filterType property.
     *
     * @return the tip text
     */
    public String filterTypeTipText() {
        return "The filter type for transforming the training data.";
    }

    /**
     * Returns the filter type as a tag from {@link #TAGS_FILTER}.
     *
     * @return the selected filter type
     */
    public SelectedTag getFilterType() {
        return new SelectedTag(m_filterType, TAGS_FILTER);
    }

    /**
     * Sets the filter type. Tags that do not come from {@link #TAGS_FILTER} are ignored.
     *
     * @param newType the new filter type
     */
    public void setFilterType(SelectedTag newType) {
        if (newType.getTags() == TAGS_FILTER) {
            m_filterType = newType.getSelectedTag().getID();
        }
    }

    /**
     * Bag-level capabilities: nominal and relational attributes, binary class, multi-instance only.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.BINARY_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return result;
    }

    /**
     * Capabilities of the instances inside a bag: nominal, numeric and date attributes, no class.
     *
     * @return the capabilities for the bag contents
     */
    @Override
    public Capabilities getMultiInstanceCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.disableAllClasses();
        result.enable(Capabilities.Capability.NO_CLASS);
        return result;
    }

    /**
     * Trains EM-DD.
     *
     * <p>Every instance of up to {@value #NUM_STARTING_BAGS} randomly chosen positive bags seeds an
     * EM run of at most {@value #MAX_EM_ITERATIONS} iterations. The resulting hypothesis with the
     * fewest training-bag misclassifications becomes {@link #m_Par}.
     *
     * @param train the bag-level training data
     * @throws IllegalArgumentException if the training data contains no positive bag
     * @throws Exception if the data fails the capability test or filtering/optimisation fails
     */
    @Override
    public void buildClassifier(Instances train) throws Exception {
        getCapabilities().testWithFail(train);
        train = new Instances(train);
        train.deleteWithMissingClass();
        m_ClassIndex = train.classIndex();
        m_NumClasses = train.numClasses();
        int nR = train.attribute(1).relation().numAttributes();
        int nC = train.numInstances();
        int nPar = 2 * nR;

        extractData(train, nR);

        if (m_Debug) {
            System.out.println("\n\nIteration History...");
        }
        m_emData = new double[nC][nR];
        m_Par = new double[nPar];

        // Lower (row 0) and upper (row 1) bounds passed to Optimization.findArgmin.
        // NOTE(review): NaN presumably means "unbounded" for weka.core.Optimization - confirm.
        double[][] b = new double[2][nPar];
        for (int t = 0; t < nPar; ++t) {
            b[0][t] = Double.NaN;
            b[1][t] = Double.NaN;
        }

        double[] bestHypothesis = new double[nPar];
        double bestnll = Double.MAX_VALUE;
        double minError = Double.MAX_VALUE;
        int[] startBags = selectStartingBags(new Random(getSeed()));
        for (int exIdx : startBags) {
            if (m_Debug) {
                System.out.println("\nH0 at " + exIdx);
            }
            for (int p = 0; p < m_Data[exIdx][0].length; ++p) {
                // Start at instance p of the bag with unit scales. Use a fresh array: arrays from
                // earlier starts may still be referenced as the stored best hypothesis.
                double[] x = new double[nPar];
                for (int q = 0; q < nR; ++q) {
                    x[2 * q] = m_Data[exIdx][q][p];
                    x[2 * q + 1] = 1.0;
                }
                double[] preX = x;
                double preNll = Double.MAX_VALUE;
                double nll = INITIAL_NLL;
                int iterationCount = 0;
                while (nll < preNll && iterationCount < MAX_EM_ITERATIONS) {
                    ++iterationCount;
                    preNll = nll;
                    if (m_Debug) {
                        System.out.println("\niteration: " + iterationCount);
                    }
                    eStep(x);
                    if (m_Debug) {
                        System.out.println("E-step for new H' finished");
                    }
                    // M-step: minimise the NLL over the instances selected in the E-step.
                    OptEng opt = new OptEng();
                    double[] tmp = opt.findArgmin(x, b);
                    while (tmp == null) {
                        // The optimiser stopped without converging; resume from where it stopped.
                        tmp = opt.getVarbValues();
                        if (m_Debug) {
                            System.out.println("200 iterations finished, not enough!");
                        }
                        tmp = opt.findArgmin(tmp, b);
                    }
                    nll = opt.getMinFunction();
                    preX = x;
                    x = tmp;
                }
                // If the last M-step increased the NLL, keep the hypothesis from before it.
                m_Par = nll > preNll ? preX : x;

                int error = countTrainingErrors(train);
                if (error < minError) {
                    // Copy so later work can never alter the stored best hypothesis.
                    bestHypothesis = m_Par.clone();
                    minError = error;
                    bestnll = nll > preNll ? preNll : nll;
                    if (m_Debug) {
                        System.out.println("error= " + error + "  nll= " + bestnll);
                    }
                }
            }
            if (m_Debug) {
                System.out.println(exIdx + ":  -------------<Converged>--------------");
                System.out.println("current minimum error= " + minError + "  nll= " + bestnll);
            }
        }
        m_Par = bestHypothesis;
    }

    /**
     * Prepares the training data and fills m_Classes, m_Attributes and m_Data.
     *
     * <p>All instances of all bags are pooled into one dataset. The configured Normalize/Standardize
     * filter and the missing-value replacement are fitted on that pooled dataset and applied to it.
     * The result is stored in {@link #m_Data} as [bag][attribute][instance].
     *
     * @param train the bag-level training data
     * @param nR number of attributes of the instances inside a bag
     * @throws Exception if filtering fails
     */
    private void extractData(Instances train, int nR) throws Exception {
        int nC = train.numInstances();
        int[] bagSize = new int[nC];
        Instances datasets = new Instances(train.attribute(1).relation(), 0);
        m_Data = new double[nC][nR][];
        m_Classes = new int[nC];
        m_Attributes = datasets.stringFreeStructure();
        if (m_Debug) {
            System.out.println("\n\nExtracting data...");
        }
        for (int h = 0; h < nC; ++h) {
            Instance current = train.instance(h);
            m_Classes[h] = (int) current.classValue();
            Instances currInsts = current.relationalValue(1);
            for (int i = 0; i < currInsts.numInstances(); ++i) {
                datasets.add(currInsts.instance(i));
            }
            bagSize[h] = currInsts.numInstances();
        }

        switch (m_filterType) {
            case FILTER_STANDARDIZE:
                m_Filter = new Standardize();
                break;
            case FILTER_NORMALIZE:
                m_Filter = new Normalize();
                break;
            default:
                m_Filter = null;
                break;
        }
        if (m_Filter != null) {
            m_Filter.setInputFormat(datasets);
            datasets = Filter.useFilter(datasets, m_Filter);
        }
        m_Missing.setInputFormat(datasets);
        datasets = Filter.useFilter(datasets, m_Missing);

        // Instances were appended bag by bag, so bag h occupies a contiguous run beginning at 'start'.
        int start = 0;
        for (int h = 0; h < nC; ++h) {
            for (int att = 0; att < nR; ++att) {
                double[] column = new double[bagSize[h]];
                for (int k = 0; k < bagSize[h]; ++k) {
                    column[k] = datasets.instance(start + k).value(att);
                }
                m_Data[h][att] = column;
            }
            start += bagSize[h];
        }
    }

    /**
     * Randomly picks distinct positive training bags whose instances seed EM.
     *
     * @param r the random number generator
     * @return indices of min({@value #NUM_STARTING_BAGS}, number of positive bags) distinct
     *     positive bags, in the order drawn
     * @throws IllegalArgumentException if the training data contains no positive bag
     */
    private int[] selectStartingBags(Random r) {
        int nC = m_Classes.length;
        int numPositive = 0;
        for (int c : m_Classes) {
            if (c != 0) {
                ++numPositive;
            }
        }
        if (numPositive == 0) {
            throw new IllegalArgumentException("MIEMDD needs at least one positive bag in the training data.");
        }
        int[] chosen = new int[Math.min(NUM_STARTING_BAGS, numPositive)];
        for (int s = 0; s < chosen.length; ++s) {
            int candidate;
            boolean usable;
            do {
                // nextInt's bound is exclusive, so every bag index 0..nC-1 can be drawn.
                candidate = r.nextInt(nC);
                usable = m_Classes[candidate] != 0;
                for (int prev = 0; usable && prev < s; ++prev) {
                    usable = chosen[prev] != candidate;
                }
            } while (!usable);
            chosen[s] = candidate;
        }
        return chosen;
    }

    /**
     * E-step: for every bag, copies into {@link #m_emData} the attribute values of the instance
     * closest to hypothesis x (see {@link #findInstance(int, double[])}).
     *
     * @param x the current hypothesis (point/scale pairs)
     */
    private void eStep(double[] x) {
        for (int bag = 0; bag < m_Data.length; ++bag) {
            int insIndex = findInstance(bag, x);
            for (int att = 0; att < m_Data[0].length; ++att) {
                m_emData[bag][att] = m_Data[bag][att][insIndex];
            }
        }
    }

    /**
     * Counts training bags misclassified by the current {@link #m_Par}.
     *
     * <p>A bag counts as predicted positive when P(positive) is at least 0.5.
     *
     * @param train the bag-level training data, aligned with {@link #m_Classes}
     * @return the number of misclassified bags
     * @throws Exception if classifying a bag fails
     */
    private int countTrainingErrors(Instances train) throws Exception {
        int errors = 0;
        for (int i = 0; i < train.numInstances(); ++i) {
            double probPositive = distributionForInstance(train.instance(i))[1];
            boolean falsePositive = probPositive >= 0.5 && m_Classes[i] == 0;
            boolean falseNegative = probPositive < 0.5 && m_Classes[i] == 1;
            if (falsePositive || falseNegative) {
                ++errors;
            }
        }
        return errors;
    }

    /**
     * Returns the index of the instance in bag i closest to hypothesis x.
     *
     * <p>The distance is the scaled squared distance sum_k (value_k - x[2k])^2 * x[2k+1]^2. Ties keep
     * the lowest index, and an empty bag yields 0.
     *
     * @param i the bag index into {@link #m_Data}
     * @param x the hypothesis (point/scale pairs)
     * @return index of the closest instance within the bag
     */
    protected int findInstance(int i, double[] x) {
        double min = Double.MAX_VALUE;
        int insIndex = 0;
        int nI = m_Data[i][0].length;
        for (int j = 0; j < nI; ++j) {
            double dist = 0.0;
            for (int k = 0; k < m_Data[i].length; ++k) {
                double diff = m_Data[i][k][j] - x[2 * k];
                dist += diff * diff * x[2 * k + 1] * x[2 * k + 1];
            }
            if (dist < min) {
                min = dist;
                insIndex = j;
            }
        }
        return insIndex;
    }

    /**
     * Computes the class distribution for a bag.
     *
     * <p>The bag's instances are passed through the training-time filters. P(positive) is then
     * exp(-d), where d is the smallest scaled squared distance from any instance to the hypothesis.
     *
     * @param exmp the bag (attribute 1 holds its instances)
     * @return {P(negative), P(positive)}; an empty bag gets {1, 0}
     * @throws Exception if filtering fails
     */
    @Override
    public double[] distributionForInstance(Instance exmp) throws Exception {
        Instances ins = exmp.relationalValue(1);
        if (m_Filter != null) {
            ins = Filter.useFilter(ins, m_Filter);
        }
        ins = Filter.useFilter(ins, m_Missing);
        int nA = ins.numAttributes();
        double minDist = Double.MAX_VALUE;
        // Start at 0 (not -1) so an empty bag, or one with no finite distance, still gets a valid
        // probability: exp(-infinity) == 0.
        double maxProb = 0.0;
        for (int j = 0; j < ins.numInstances(); ++j) {
            Instance inst = ins.instance(j);
            double dist = 0.0;
            for (int k = 0; k < nA; ++k) {
                double diff = inst.value(k) - m_Par[2 * k];
                dist += diff * diff * m_Par[2 * k + 1] * m_Par[2 * k + 1];
            }
            if (dist < minDist) {
                minDist = dist;
                maxProb = Math.exp(-dist);
            }
        }
        return new double[]{1.0 - maxProb, maxProb};
    }

    /**
     * Returns the learned point and scale of every instance attribute, or a notice if no model is built.
     *
     * @return a textual description of the model
     */
    @Override
    public String toString() {
        if (m_Par == null) {
            return "MIEMDD: No model built yet.";
        }
        StringBuilder result = new StringBuilder("MIEMDD\nCoefficients...\nVariable       Point       Scale\n");
        for (int j = 0; j < m_Par.length / 2; ++j) {
            result.append(m_Attributes.attribute(j).name())
                    .append(' ').append(Utils.doubleToString(m_Par[2 * j], 12, 4))
                    .append(' ').append(Utils.doubleToString(m_Par[2 * j + 1], 12, 4))
                    .append('\n');
        }
        return result.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5546 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        MIEMDD.runClassifier(new MIEMDD(), argv);
    }

    /**
     * Returns sum_k (values[k] - par[2k])^2 * par[2k+1]^2 over all k < values.length.
     *
     * @param values attribute values of one instance
     * @param par hypothesis (point/scale pairs)
     * @return the scaled squared distance
     */
    private static double scaledDistance(double[] values, double[] par) {
        double dist = 0.0;
        for (int k = 0; k < values.length; ++k) {
            double diff = values[k] - par[2 * k];
            dist += diff * diff * par[2 * k + 1] * par[2 * k + 1];
        }
        return dist;
    }

    /**
     * M-step objective: the negative log-likelihood of the bag labels, given the instances
     * selected in the E-step ({@link MIEMDD#m_emData}).
     */
    private class OptEng
    extends Optimization {
        /**
         * Computes -sum_i log P(label_i | x), where P(positive) = exp(-scaledDistance).
         *
         * <p>Each probability is clamped from below at m_Zero.
         *
         * @param x the hypothesis (point/scale pairs)
         * @return the negative log-likelihood
         */
        @Override
        protected double objectiveFunction(double[] x) {
            double nll = 0.0;
            for (int i = 0; i < MIEMDD.this.m_Classes.length; ++i) {
                double prob = Math.exp(-MIEMDD.scaledDistance(MIEMDD.this.m_emData[i], x));
                if (MIEMDD.this.m_Classes[i] != 1) {
                    prob = 1.0 - prob;
                }
                nll -= Math.log(prob <= m_Zero ? m_Zero : prob);
            }
            return nll;
        }

        /**
         * Computes the gradient of {@link #objectiveFunction(double[])} with respect to x.
         *
         * @param x the hypothesis (point/scale pairs)
         * @return the gradient, same length as x
         */
        @Override
        protected double[] evaluateGradient(double[] x) {
            double[] grad = new double[x.length];
            for (int i = 0; i < MIEMDD.this.m_Classes.length; ++i) {
                double[] row = MIEMDD.this.m_emData[i];
                double prob = Math.exp(-MIEMDD.scaledDistance(row, x));
                boolean positive = MIEMDD.this.m_Classes[i] == 1;
                // Clamp 1 - prob at m_Zero, as objectiveFunction does. Without this, a scaled distance
                // of 0 (prob == 1) yields 0/0 = NaN for a negative bag.
                double denom = Math.max(1.0 - prob, m_Zero);
                for (int q = 0; q < row.length; ++q) {
                    double diff = x[2 * q] - row[q];
                    double scale = x[2 * q + 1];
                    // Derivatives of the scaled distance w.r.t. the point and the scale of attribute q.
                    double dPoint = 2.0 * diff * scale * scale;
                    double dScale = 2.0 * diff * diff * scale;
                    if (positive) {
                        grad[2 * q] += dPoint;
                        grad[2 * q + 1] += dScale;
                    } else {
                        grad[2 * q] -= dPoint * prob / denom;
                        grad[2 * q + 1] -= dScale * prob / denom;
                    }
                }
            }
            return grad;
        }

        /**
         * Returns the revision string.
         *
         * @return the revision
         */
        @Override
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 5546 $");
        }
    }
}

