/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions.pace;

import java.util.Random;
import weka.classifiers.functions.pace.MixtureDistribution;
import weka.classifiers.functions.pace.PaceMatrix;
import weka.core.RevisionUtils;
import weka.core.matrix.DoubleVector;
import weka.core.matrix.Maths;

/**
 * A mixture of unit-variance normal distributions N(mu_j, 1) whose means are
 * governed by a discrete mixing distribution (the inherited
 * {@code mixingDistribution}, estimated by {@code fit}).
 *
 * <p>Besides the hooks required by {@link MixtureDistribution} for fitting,
 * this class provides three shrinkage estimators for a vector of observations:
 * the subset selector ({@link #subsetEstimate}), the nested model selector
 * ({@link #nestedEstimate}) and the empirical Bayes posterior mean
 * ({@link #empiricalBayesEstimate(DoubleVector)}). All three finish by zeroing
 * entries whose magnitude is at most {@link #getTrimingThreshold()}.
 */
public class NormalMixture
extends MixtureDistribution {
    /** Upper bound on the summed tail probability used by {@link #separable}. */
    protected double separatingThreshold = 0.05;
    /** Estimates with absolute value at most this are set to zero by {@link #trim}. */
    protected double trimingThreshold = 0.7;
    /** Width of each fitting interval produced by {@link #fittingIntervals}. */
    protected double fittingIntervalLength = 3.0;

    /** @return the threshold used by {@link #separable} */
    public double getSeparatingThreshold() {
        return this.separatingThreshold;
    }

    /** @param t the new threshold used by {@link #separable} */
    public void setSeparatingThreshold(double t) {
        this.separatingThreshold = t;
    }

    /** @return the magnitude at or below which {@link #trim} zeroes an estimate */
    public double getTrimingThreshold() {
        return this.trimingThreshold;
    }

    /** @param t the magnitude at or below which {@link #trim} zeroes an estimate */
    public void setTrimingThreshold(double t) {
        this.trimingThreshold = t;
    }

    /** @return the width of each interval produced by {@link #fittingIntervals} */
    public double getFittingIntervalLength() {
        return this.fittingIntervalLength;
    }

    /** @param length the width of each interval produced by {@link #fittingIntervals} */
    public void setFittingIntervalLength(double length) {
        this.fittingIntervalLength = length;
    }

    /**
     * Decides whether the point {@code x} is separable from the observations
     * {@code data[i0..i1]} (inclusive).
     *
     * @return {@code true} when the sum over that range of the standard normal
     *         tail probability {@code pnorm(-|x - data[i]|)} is below
     *         {@link #getSeparatingThreshold()}
     */
    @Override
    public boolean separable(DoubleVector data, int i0, int i1, double x) {
        double p = 0.0;
        for (int i = i0; i <= i1; ++i) {
            p += Maths.pnorm(-Math.abs(x - data.get(i)));
        }
        return p < this.separatingThreshold;
    }

    /**
     * Returns the candidate support points: a copy of the observations
     * themselves. {@code ne} is not used by this implementation.
     *
     * @throws IllegalArgumentException if {@code data} has fewer than two elements
     */
    @Override
    public DoubleVector supportPoints(DoubleVector data, int ne) {
        if (data.size() < 2) {
            throw new IllegalArgumentException("data size < 2");
        }
        return data.copy();
    }

    /**
     * Builds the fitting intervals for {@code n = data.size()} observations as a
     * {@code 2n x 2} matrix of {@code [left, right]} bounds. Rows {@code 0..n-1}
     * are {@code [x_i, x_i + L]} and rows {@code n..2n-1} are
     * {@code [x_i - L, x_i]}, where {@code L} is the fitting interval length.
     */
    @Override
    public PaceMatrix fittingIntervals(DoubleVector data) {
        DoubleVector left = data.cat(data.minus(this.fittingIntervalLength));
        DoubleVector right = data.plus(this.fittingIntervalLength).cat(data);
        PaceMatrix a = new PaceMatrix(left.size(), 2);
        a.setMatrix(0, left.size() - 1, 0, left);
        a.setMatrix(0, right.size() - 1, 1, right);
        return a;
    }

    /**
     * Computes the probability each support point assigns to each interval.
     *
     * @param s         the support points (normal means)
     * @param intervals one interval per row, as {@code [left, right]}
     * @return a matrix whose entry {@code (i, j)} is the probability that
     *         {@code N(s_j, 1)} assigns to interval {@code i}
     */
    @Override
    public PaceMatrix probabilityMatrix(DoubleVector s, PaceMatrix intervals) {
        int ns = s.size();
        int nr = intervals.getRowDimension();
        PaceMatrix p = new PaceMatrix(nr, ns);
        for (int i = 0; i < nr; ++i) {
            for (int j = 0; j < ns; ++j) {
                p.set(i, j, Maths.pnorm(intervals.get(i, 1), s.get(j), 1.0) - Maths.pnorm(intervals.get(i, 0), s.get(j), 1.0));
            }
        }
        return p;
    }

    /**
     * Computes the unnormalized posterior weights of the fitted support points
     * given one observation {@code x}. For each point {@code mu_j} the weight is
     * {@code w_j * exp(log dnorm(x, mu_j, 1) - max_k log dnorm(x, mu_k, 1))}.
     * Subtracting the maximum log-density before exponentiating avoids
     * underflow when {@code x} is far from every support point.
     */
    private DoubleVector posteriorWeights(double x, DoubleVector points, DoubleVector values) {
        DoubleVector d = Maths.dnormLog(x, points, 1.0);
        d.minusEquals(d.max());
        d = d.map("java.lang.Math", "exp");
        d.timesEquals(values);
        return d;
    }

    /**
     * Computes the empirical Bayes (posterior mean) estimate of the normal
     * mean behind a single observation under the fitted mixing distribution.
     * Observations with {@code |x| > 10} are returned unchanged.
     *
     * @param x the observation
     * @return the posterior mean of the support points given {@code x}
     */
    public double empiricalBayesEstimate(double x) {
        if (Math.abs(x) > 10.0) {
            return x;
        }
        DoubleVector points = this.mixingDistribution.getPointValues();
        DoubleVector values = this.mixingDistribution.getFunctionValues();
        DoubleVector d = this.posteriorWeights(x, points, values);
        return points.innerProduct(d) / d.sum();
    }

    /**
     * Applies {@link #empiricalBayesEstimate(double)} element-wise, then
     * {@link #trim}s the result.
     *
     * @param x the observations
     * @return a new vector of trimmed estimates
     */
    public DoubleVector empiricalBayesEstimate(DoubleVector x) {
        DoubleVector pred = new DoubleVector(x.size());
        for (int i = 0; i < x.size(); ++i) {
            pred.set(i, this.empiricalBayesEstimate(x.get(i)));
        }
        this.trim(pred);
        return pred;
    }

    /**
     * Nested model selector. Cumulates {@link #hf(double)} over the
     * observations in order, keeps the prefix that ends at the maximum of that
     * running sum, zeroes every later entry, then {@link #trim}s the result.
     * The order of {@code x} therefore matters (the demo in {@code main} passes
     * a reversed vector and reverses the result).
     *
     * @param x the observations, ordered by intended inclusion
     * @return a new vector of estimates; {@code x} is not modified
     */
    public DoubleVector nestedEstimate(DoubleVector x) {
        DoubleVector chf = new DoubleVector(x.size());
        for (int i = 0; i < x.size(); ++i) {
            chf.set(i, this.hf(x.get(i)));
        }
        chf.cumulateInPlace();
        int index = chf.indexOfMax();
        DoubleVector copy = x.copy();
        if (index < x.size() - 1) {
            copy.set(index + 1, x.size() - 1, 0.0);
        }
        this.trim(copy);
        return copy;
    }

    /**
     * Subset selector. Zeroes every observation whose {@link #h(double)} value
     * is non-positive, then {@link #trim}s the result.
     *
     * @param x the observations
     * @return a new vector of estimates; {@code x} is not modified
     */
    public DoubleVector subsetEstimate(DoubleVector x) {
        DoubleVector h = this.h(x);
        DoubleVector copy = x.copy();
        for (int i = 0; i < x.size(); ++i) {
            // NaN comparisons are false, so NaN h-values leave the entry untouched.
            if (h.get(i) <= 0.0) {
                copy.set(i, 0.0);
            }
        }
        this.trim(copy);
        return copy;
    }

    /**
     * Sets, in place, every element whose absolute value is at most
     * {@link #getTrimingThreshold()} to zero. NaN elements are left unchanged.
     *
     * @param x the vector to trim (modified)
     */
    public void trim(DoubleVector x) {
        for (int i = 0; i < x.size(); ++i) {
            if (Math.abs(x.get(i)) <= this.trimingThreshold) {
                x.set(i, 0.0);
            }
        }
    }

    /**
     * Computes the posterior expectation of {@code 2 x mu - x^2} over the
     * support points {@code mu} given observation {@code x}. This is the
     * normalized counterpart of {@link #h(double)}.
     */
    public double hf(double x) {
        DoubleVector points = this.mixingDistribution.getPointValues();
        DoubleVector values = this.mixingDistribution.getFunctionValues();
        DoubleVector d = this.posteriorWeights(x, points, values);
        return points.times(2.0 * x).minusEquals(x * x).innerProduct(d) / d.sum();
    }

    /**
     * Computes {@code sum_j w_j * dnorm(x, mu_j, 1) * (2 x mu_j - x^2)} over
     * the fitted support points {@code mu_j} with weights {@code w_j}.
     * The result is not normalized.
     */
    public double h(double x) {
        DoubleVector points = this.mixingDistribution.getPointValues();
        DoubleVector values = this.mixingDistribution.getFunctionValues();
        DoubleVector d = Maths.dnorm(x, points, 1.0).timesEquals(values);
        return points.times(2.0 * x).minusEquals(x * x).innerProduct(d);
    }

    /** Applies {@link #h(double)} element-wise and returns a new vector. */
    public DoubleVector h(DoubleVector x) {
        DoubleVector h = new DoubleVector(x.size());
        for (int i = 0; i < x.size(); ++i) {
            h.set(i, this.h(x.get(i)));
        }
        return h;
    }

    /**
     * Computes the fitted mixture density at {@code x}, namely
     * {@code sum_j w_j * dnorm(x, mu_j, 1)}. The components are unit-variance
     * normals, consistent with {@link #h(double)} and
     * {@link #probabilityMatrix}. The original code used a chi-square density
     * here, which was copied from the chi-square mixture.
     */
    public double f(double x) {
        DoubleVector points = this.mixingDistribution.getPointValues();
        DoubleVector values = this.mixingDistribution.getFunctionValues();
        return Maths.dnorm(x, points, 1.0).timesEquals(values).sum();
    }

    /**
     * Applies {@link #f(double)} element-wise and returns a new vector.
     * The original evaluated {@code h} at the zero-filled result vector.
     */
    public DoubleVector f(DoubleVector x) {
        DoubleVector result = new DoubleVector(x.size());
        for (int i = 0; i < x.size(); ++i) {
            result.set(i, this.f(x.get(i)));
        }
        return result;
    }

    /**
     * Returns the string form of the fitted mixing distribution, or
     * {@code "null"} if {@code fit} has not been called yet.
     */
    @Override
    public String toString() {
        return String.valueOf(this.mixingDistribution);
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.5 $");
    }

    /**
     * Demo: draws 50 observations around each of two means, fits the mixture,
     * and prints the quadratic loss of the raw data and of the three
     * estimators.
     */
    public static void main(String[] args) {
        int n1 = 50;
        int n2 = 50;
        double mu1 = 0.0;
        double mu2 = 5.0;
        Random random = new Random();
        DoubleVector a = Maths.rnorm(n1, mu1, 1.0, random);
        a = a.cat(Maths.rnorm(n2, mu2, 1.0, random));
        DoubleVector means = new DoubleVector(n1, mu1).cat(new DoubleVector(n2, mu2));
        System.out.println("==========================================================");
        System.out.println("This is to test the estimation of the mixing\ndistribution of the mixture of unit variance normal\ndistributions. The example mixture used is of the form: \n\n   0.5 * N(mu1, 1) + 0.5 * N(mu2, 1)\n");
        System.out.println("It also tests three estimators: the subset\nselector, the nested model selector, and the empirical Bayes\nestimator. Quadratic losses of the estimators are given, \nand are taken as the measure of their performance.");
        System.out.println("==========================================================");
        System.out.println("mu1 = " + mu1 + " mu2 = " + mu2 + "\n");
        System.out.println(a.size() + " observations are: \n\n" + a);
        System.out.println("\nQuadratic loss of the raw data (i.e., the MLE) = " + a.sum2(means));
        System.out.println("==========================================================");
        NormalMixture d = new NormalMixture();
        d.fit(a, 1);
        System.out.println("The estimated mixing distribution is:\n" + d);
        DoubleVector pred = d.nestedEstimate(a.rev()).rev();
        System.out.println("\nThe Nested Estimate = \n" + pred);
        System.out.println("Quadratic loss = " + pred.sum2(means));
        pred = d.subsetEstimate(a);
        System.out.println("\nThe Subset Estimate = \n" + pred);
        System.out.println("Quadratic loss = " + pred.sum2(means));
        pred = d.empiricalBayesEstimate(a);
        System.out.println("\nThe Empirical Bayes Estimate = \n" + pred);
        System.out.println("Quadratic loss = " + pred.sum2(means));
    }
}

