
Any example to implement Gaussian Mixture Model (GMM)?

Open zhixuedu opened this issue 5 years ago • 0 comments

It's been great to learn Gen and see its power and flexibility over the past month. For my research, I am hoping to construct a Gaussian Mixture Model using Gen, but there doesn't seem to be a worked example showing how to do it. I am really a beginner when it comes to writing my own inference engine; could anyone give me a hint? I started out with a simple GMM model (a generative function), but I really don't have a clue how to do the rest... I borrowed the "dirichlet distribution" from @Kenta426 in #225 (very nice work). Here is my code to start:

using Gen, Distributions
using PyPlot
pygui(true)

## generate a synthetic dataset y with 3 peaks centered at 0., 4., and 20.
no_obs = 1000
y = rand(MixtureModel(Normal[
   Normal(0., 1.),
   Normal(4., 1.),
   Normal(20., 1.)], [0.2, 0.5, 0.3]), no_obs)

plt.figure(1)
plt.hist(y, bins = 50)


K = 20  # number of mixture components to fit

@gen function model_GMM()
    means = @trace(broadcasted_normal(mean(y), 10 .* ones(K)), :means)  # prior on each component's mean
    sd = @trace(broadcasted_normal(20, 4 .* ones(K)), :sd)              # prior on each component's standard deviation
    p = @trace(dirichlet(ones(K)), :p)                                  # prior on the component weights (probabilities)
    for i in 1:no_obs
        z = @trace(categorical(p), (:z, i))
        @trace(normal(means[z], sd[z]), (:y, i))
    end
end

trace_GMM = Gen.simulate(model_GMM, ())

# How do I do inference from here?
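
One possible starting point, as a rough sketch rather than a recommended method: condition the model on the data with `generate`, then run Metropolis-Hastings sweeps that resimulate each block of latent variables from its prior using `mh` and `select`. The names `gmm` and `infer_gmm` are hypothetical, the model takes its data size and component count as arguments instead of reading globals, and `dirichlet` is assumed to be available (either built into the installed Gen version or defined as in #225).

```julia
using Gen

# Same structure as model_GMM above, with sizes passed in as arguments.
# Assumption: `dirichlet` is defined (built into Gen or taken from #225).
@gen function gmm(n::Int, k::Int, center::Float64)
    means = @trace(broadcasted_normal(center, 10 .* ones(k)), :means)
    sd = @trace(broadcasted_normal(20, 4 .* ones(k)), :sd)
    p = @trace(dirichlet(ones(k)), :p)
    for i in 1:n
        z = @trace(categorical(p), (:z, i))
        @trace(normal(means[z], sd[z]), (:y, i))
    end
end

# Hypothetical helper: prior-resimulation MH over all latent variables.
function infer_gmm(ys::Vector{Float64}, k::Int; sweeps::Int = 50)
    n = length(ys)

    # Constrain every (:y, i) address to the observed value.
    obs = choicemap()
    for i in 1:n
        obs[(:y, i)] = ys[i]
    end

    # Initial trace: observations fixed, latents drawn from the prior.
    trace, _ = generate(gmm, (n, k, sum(ys) / n), obs)

    for sweep in 1:sweeps
        # Each call proposes new values from the prior for the selected
        # addresses and accepts or rejects with the MH rule.
        trace, _ = mh(trace, select(:means))
        trace, _ = mh(trace, select(:sd))
        trace, _ = mh(trace, select(:p))
        for i in 1:n
            trace, _ = mh(trace, select((:z, i)))
        end
    end
    return trace
end

# Small usage example on two well-separated clusters.
ys = vcat(randn(30), 4 .+ randn(30))
trace = infer_gmm(ys, 3; sweeps = 10)
println(sort(trace[:means]))
```

Resimulating whole vectors from the prior will likely mix slowly, especially with K = 20 and 1000 observations, so this mainly shows the mechanics. Custom proposals passed to `mh`, or exact conditional updates for each `z`, would probably be needed for good results.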

Looking through the PyMC3 documentation, it seems the key is how to handle the discrete variable "z": either sample it with the step method "pm.ElemwiseCategorical()" or marginalize over z (https://docs.pymc.io/notebooks/gaussian_mixture_model.html). In Turing.jl, a particle Gibbs sampler was used to achieve good results (Ge et al., 2018). Any help setting up an inference method would be greatly appreciated!
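
For concreteness, here is what the two options for "z" amount to, written with the symbols of the model above (component means μ_k, standard deviations σ_k, weights p_k). This is the standard mixture-model algebra, not something specific to Gen or PyMC3.

```latex
% Marginalizing: z_i is summed out of the likelihood of each observation
p(y_i \mid \mu, \sigma, p) = \sum_{k=1}^{K} p_k \, \mathcal{N}(y_i \mid \mu_k, \sigma_k^2)

% Sampling: z_i is drawn from its conditional given the other variables
P(z_i = k \mid y_i, \mu, \sigma, p)
  = \frac{p_k \, \mathcal{N}(y_i \mid \mu_k, \sigma_k^2)}
         {\sum_{j=1}^{K} p_j \, \mathcal{N}(y_i \mid \mu_j, \sigma_j^2)}
```

With the first form the model has only continuous latent variables. With the second, each z_i gets an exact Gibbs-style update that alternates with updates of the means, standard deviations, and weights.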

zhixuedu, Jun 13 '20 11:06