
Problem with using `predict` with vector valued random variables

Open SamuelBrand1 opened this issue 1 year ago • 18 comments

Hi everyone,

Problem

There seems to be a problem with using predict in conjunction with models that use vectorisation.

Consider this fairly simple example:

$$
\begin{split}
\sigma &\sim \text{HalfNormal}(0.1) \\
\mu_i &\sim \mathcal{N}(0, 1), \qquad i = 1,\dots,n \\
\epsilon_i &\sim \mathcal{N}(0, \sigma^2), \qquad i = 1,\dots,n \\
x_i &= \mu_i + \epsilon_i, \qquad i = 1,\dots,n
\end{split}
$$

We can generate a dataset by sampling from this model for (say) $n = 10$. The forecasting problem is then to sample $x_i$ for $i = 11,\dots,20$ (the only information propagated forward is the noise variance $\sigma^2$).

However, this fails as per below:

```julia
using Turing, StatsPlots, DynamicPPL, Random, LinearAlgebra

Random.seed!(1234)

@model function mv_normal(n)
	σ ~ truncated(Normal(0.0, 0.1); lower = 0.0)
	μ ~ MvNormal(zeros(n), I)  # means, μ_i ~ N(0, 1) (non-deprecated constructor)
	x ~ MvNormal(μ, σ^2 * I)   # observation noise with std σ
	return x
end
```

```julia
mdl_10 = mv_normal(10)

# Sample data from the prior
x_data = mdl_10()

# Infer means and observation noise
chn = sample(mdl_10 | (x = x_data,), NUTS(), 2_000)

# Forecast with the extended model
forecast_mdl = mv_normal(20)

forecast_chn = predict(forecast_mdl, chn; include_all = true)
```

```julia
let
	obs = generated_quantities(forecast_mdl, forecast_chn) |> X -> reduce(hcat, X)
	plt = plot(obs, c = :grey, alpha = 0.05, lab = "")
	scatter!(plt, x_data, c = :red, lab = "observed", title = "BAD FORECAST", ms = 6)
end
```

The failure mode here seems to be that the underlying random variables for $i = 11,\dots,20$ are drawn once and then reused across all samples from `chn`, rather than being resampled for each draw.
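One way to see this (a hypothetical diagnostic, assuming the `forecast_chn` object from the snippet above and MCMCChains-style string indexing): if the out-of-sample components were drawn once and frozen, they should have zero variance across the chain.

```julia
# Hypothetical check: if μ[11] was drawn once and never resampled,
# every stored draw of it will be identical.
vals = vec(forecast_chn["μ[11]"])
println(all(==(first(vals)), vals))  # `true` would confirm the "frozen" behaviour
```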

Fix 1: mapreduce across forecast calls

So if you instead loop over samples and run forecast for each sample, this seems to work:

```julia
forecast_chn_mapreduce = mapreduce(vcat, 1:size(chn, 1)) do i
	c = predict(forecast_mdl, chn[i, :, 1]; include_all = true)
	# Take care to set the range sequentially
	setrange(c, i:i)
end
```

```julia
let
	obs = generated_quantities(forecast_mdl, forecast_chn_mapreduce) |> X -> reduce(hcat, X)
	plt = plot(obs, c = :grey, alpha = 0.05, lab = "")
	scatter!(plt, x_data, c = :black, lab = "observed", title = "OK FORECAST")
end
```

Fix 2: Non-vectorised sampling

Or you can modify the underlying model to avoid vectorised random variables (although IMO this is non-ideal).

```julia
@model function mv_normal_2(n)
	σ ~ truncated(Normal(0.0, 0.1); lower = 0.0)
	μ = Vector{eltype(σ)}(undef, n)
	for i in 1:n
		μ[i] ~ Normal()
	end
	x ~ MvNormal(μ, σ^2 * I)  # observation noise with std σ
	return x
end
```

```julia
mdl2_10 = mv_normal_2(10)
x_data2 = mdl2_10()
chn2 = sample(mdl2_10 | (x = x_data2,), NUTS(), 2_000)

forecast_mdl2 = mv_normal_2(20)
forecast_chn2 = predict(forecast_mdl2, chn2; include_all = true)
```

```julia
let
	obs = generated_quantities(forecast_mdl2, forecast_chn2) |> X -> reduce(hcat, X)
	plt = plot(obs, c = :grey, alpha = 0.05, lab = "")
	scatter!(plt, x_data2, c = :black, lab = "observed", title = "ALSO OK FORECAST?")
end
```

Ideal situation

Obviously, it would be ideal if `predict` "just worked" with vectorised random variables. Given the failure mode of naive usage of `predict`, I'm assuming this is a problem with how the random numbers are generated around here?

SamuelBrand1 avatar May 28 '24 14:05 SamuelBrand1

Okay, so the fact that any of this works is not great :sweat_smile:

A few immediate things:

  1. None of these scenarios are meant to be supported :confused: The fact that the fixed version works is just a happy accident for this particular model.
  2. One crucial aspect of Turing.jl is that variables are treated as they occur in the model. This means that if `x` is sampled from a multivariate distribution, then it is treated as a single multivariate random variable; attempts to treat it otherwise are generally not supported. In this sense, we rely on the user to tell us the correct "semantics" of a given variable, i.e. how to interpret it. It's also the case that, in general, it's not possible to marginalize out components (as we would technically have to do in this scenario, since the desired behavior is to fix `μ[1:10]` and only sample `μ[11:20]`).

The reason the 2nd scenario works at all is that the prior distribution for `μ[11]` is the same as its posterior predictive distribution. If `μ[11]` instead depended on some other variable, e.g. `σ`, the "fixed" version would result in `μ[11]` being sampled from an MvNormal with `σ` from the prior (not the posterior / chain).

In short, Turing.jl executes the model once before running the main part of the `predict` code, and uses the resulting "trace" / dictionary-like structure as a template for subsequent predictions. This initial run to construct the "trace" samples from the prior. This is why the first version results in `μ[11:20]` being "frozen": one "trace" was sampled from the prior at the beginning, and those values are never resampled! Similarly, the reason the other version happens to work is that every call to `predict` draws a new sample from the prior to produce the "trace" before running the "actual" predict; hence the values for `μ[11:20]` are sampled from the prior in every call to `predict`, but never resampled once we've set `μ[1:10]` :confused:
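A rough pseudocode sketch of the mechanism described above (all names here are illustrative; this is not the actual DynamicPPL code):

```julia
# Illustrative pseudocode only -- not the real DynamicPPL implementation.
function sketch_of_predict(model, chain)
	trace = sample_from_prior(model)   # e.g. μ[1:20] all drawn from the prior ONCE
	predictions = map(eachrow(chain)) do draw
		set!(trace, draw)              # overwrite variables present in `draw`;
		                               # μ is one variable, so a partial overlap
		                               # (μ[1:10] only) has no defined behaviour
		rerun(model, trace)            # only variables absent from `draw` are resampled
	end
	return predictions
end
```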

We have some checks in place to warn the user about these scenarios, but we clearly need more since this slipped through the cracks!

Effectively, if you want variables to be treated as i.i.d. rather than as a single multivariate, then you should either use `.~` or a `for` loop (as you did in the second scenario). For example, the following works:

```julia
# Nevermind; this also doesn't work...
@model function mv_normal_3(n)
	σ ~ truncated(Normal(0.0, 0.1); lower = 0.0)
	μ = Vector{eltype(σ)}(undef, n)
	μ .~ Normal()
	x ~ MvNormal(μ, σ^2 * I)  # observation noise with std σ
	return x
end
```

```julia
mdl3_10 = mv_normal_3(10)
x_data3 = mdl3_10()
chn3 = sample(mdl3_10 | (x = x_data3,), NUTS(), 2_000)

forecast_mdl3 = mv_normal_3(20)
forecast_chn3 = predict(forecast_mdl3, chn3; include_all = true)
```

EDIT: Nvm, `.~` also doesn't work, I just realized, which is annoying because semantically speaking it should :confused: Hmm, might want to do something about that.

torfjelde avatar Jun 03 '24 08:06 torfjelde

Thanks for the detailed explanation!

So it turns out that this only works for models which have a representation where the prior for the forecast components is the same as their posterior predictive distribution.

TBF, this is actually a pretty large class of forecast models: discrete-time numerical solutions to SDEs and the finite dimensional distributions of a GP can be written this way (we were motivated by having a standard parameterisation of latent white noise).
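For concreteness, a minimal sketch of the kind of "latent white noise" parameterisation meant here (model name and structure are illustrative, not from the issue): every latent increment has a fixed N(0, 1) prior regardless of the data, so prior and posterior predictive coincide for indices beyond the observed window.

```julia
using Turing

# Illustrative sketch: a scaled random walk driven by standard-normal increments.
@model function white_noise_walk(n)
	σ ~ truncated(Normal(0.0, 0.1); lower = 0.0)
	z = Vector{eltype(σ)}(undef, n)
	for i in 1:n
		z[i] ~ Normal()      # latent white noise; prior is always N(0, 1)
	end
	x = σ .* cumsum(z)       # deterministic transform; forecasting just extends z
	return x
end
```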

SamuelBrand1 avatar Jun 03 '24 09:06 SamuelBrand1

So to be clear, in the example where you fix part of an array here using the Dict form of fixing... that wouldn't work if the array had been declared by calling ~ MvNormal(...)?

SamuelBrand1 avatar Jun 03 '24 09:06 SamuelBrand1

> TBF, this is actually a pretty large class of forecast models: discrete-time numerical solutions to SDEs and the finite dimensional distributions of a GP can be written this way (we were motivated by having a standard parameterisation of latent white noise).

Yeah, I definitely see the use for this! But it's somewhat non-trivial to support, so it really comes down to whether we want the maintenance burden of the functionality vs. having the user do some manual labour, i.e. use a for loop.

> So to be clear, in the example where you fix part of an array here using the Dict form of fixing... that wouldn't work if the array had been declared by calling ~ MvNormal(...)?

Exactly.

The most annoying part of all this (IMO) is the performance implications of vectorized vs. a for loop. It's technically possible to do something like

```julia
if @performing_inference
    x ~ MvNormal(...)
else
    x = Vector(undef, 10)
    for i in eachindex(x)
        x[i] ~ Normal(...)
    end
end
```

but we don't have that implemented (related: https://github.com/TuringLang/DynamicPPL.jl/issues/510).

torfjelde avatar Jun 03 '24 09:06 torfjelde

One simple approach that could also work is for the aforementioned code to be generated automatically through an iid macro or something, e.g.

```julia
@iid x ~ MvNormal(...)
```

and then this just converts to the above code block under the hood.

torfjelde avatar Jun 03 '24 09:06 torfjelde

Right. And this would avoid issues with (say) `adtype = AutoReverseDiff(true)`, because grad calls only occur in "inference mode", so the existence of the slower branch wouldn't be relevant to sampling?

SamuelBrand1 avatar Jun 03 '24 10:06 SamuelBrand1

> Right. And this would avoid issues with (say) adtype = AutoReverseDiff(true) because grad calls only occur in "inference mode" so the existence of the slower performance branch wouldn't be relevant to sampling?

Exactly:)

torfjelde avatar Jun 03 '24 12:06 torfjelde

> > TBF, this is actually a pretty large class of forecast models: discrete-time numerical solutions to SDEs and the finite dimensional distributions of a GP can be written this way (we were motivated by having a standard parameterisation of latent white noise).
>
> Yeah, I definitely see the use for this! But it's somewhat non-trivial to support, so it really comes down to whether we want the maintenance burden of the functionality vs. having the user do some manual labour, i.e. use a for loop.
>
> > So to be clear, in the example where you fix part of an array here using the Dict form of fixing... that wouldn't work if the array had been declared by calling ~ MvNormal(...)?
>
> Exactly.
>
> The most annoying part of all this (IMO) is the performance implications of vectorized vs. a for loop. It's technically possible to do something like
>
> ```julia
> if @performing_inference
>     x ~ MvNormal(...)
> else
>     x = Vector(undef, 10)
>     for i in eachindex(x)
>         x[i] ~ Normal(...)
>     end
> end
> ```
>
> but we don't have that implemented (related: TuringLang/DynamicPPL.jl#510).

Did anything happen on this? @seabbs and myself are trying to make something where you can fully compose a fairly large set of models defining the different probabilistic components one might want in a fully featured epi model with `@submodel`.

Since prediction is quite important here, it would be a handy feature to have, but otoh I'm not inclined to go through every single model definition and put in a "forecast mode" boolean switch. Given the known behaviour of MvNormal under conditioning on some indices, couldn't you support that as a special case?
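(For reference, the "known behaviour" alluded to is the standard Gaussian conditioning identity: partition $x = (x_1, x_2) \sim \mathcal{N}(\mu, \Sigma)$ and condition on $x_1$; then

$$
x_2 \mid x_1 \sim \mathcal{N}\!\left(\mu_2 + \Sigma_{21}\Sigma_{11}^{-1}(x_1 - \mu_1),\; \Sigma_{22} - \Sigma_{21}\Sigma_{11}^{-1}\Sigma_{12}\right).
$$

In particular, for a diagonal covariance, as in the models above, the unconditioned components are simply unchanged.)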

SamuelBrand1 avatar Jul 12 '24 20:07 SamuelBrand1

Noting this also appears to be an issue with `filldist(Normal(), n)`, which is documented (I think) as iid, so I assumed it would work initially.

seabbs avatar Jul 12 '24 20:07 seabbs

I've started adding (https://github.com/CDCgov/Rt-without-renewal/pull/369#issuecomment-2227513371) a `PredictContext` (plus hacking on `Turing.predict`), and I think I have this switch-based approach on the way to working (some type-conversion issues remain). I did the inverse of the suggestion, as I thought it would be easier for someone with no real understanding of the contexts system. Something that is very clear is that it is pretty clunky (even with passing the if/else block around in a submodel), and the `@iid` macro suggestion would be much preferred imo.

> Given the known behaviour of MvNormal under conditioning some indices, couldn't you support that as a special case?

This seems even more ideal than the macro approach, as the macro would be easy for a non-expert user to miss.

seabbs avatar Jul 14 '24 23:07 seabbs

Tagging @yebai to get some thoughts. Specifically on something like:

> One simple approach that could also work is for the aforementioned code to be generated automatically through an iid macro or something, e.g.
>
> ```julia
> @iid x ~ MvNormal(...)
> ```
>
> and then this just converts to the above code block under the hood.

torfjelde avatar Jul 16 '24 07:07 torfjelde

I am not sure about the additional macro, but I don’t have a good alternative yet. I am happy to brainstorm more here.

Cc @mhauru @sunxd3

yebai avatar Jul 16 '24 20:07 yebai

@seabbs @SamuelBrand1, we would be happy to hear more of your thoughts on syntax design. I think it has to be robust and intuitive for long-term maintenance.

EDIT:

> Noting this also appears to be an issue with `filldist(Normal(), n)`, which is documented (I think) as iid, so I assumed it would work initially.

If we fix filldist (and maybe the broadcasting syntax, e.g. `μ .~ Normal()`), would that be sufficient for this use case?

Of course, we should provide a warning message to the user in all other unsupported cases discussed here.

yebai avatar Jul 16 '24 20:07 yebai

> If we fix filldist (and maybe the broadcasting syntax, e.g. `μ .~ Normal()`), would that be sufficient for this use case?
>
> Of course, we should provide a warning message to the user in all other unsupported cases discussed here.

Yes, both of these would be great and cover our use case. It would also be nice to make it easier to switch modes (i.e. specify a mode for inference and everything else manually) in case users find other edge cases they want to fix, but we wouldn't need that for what we are doing (at least at the moment).

seabbs avatar Jul 17 '24 21:07 seabbs

> @seabbs @SamuelBrand1, we would be happy to hear more of your thoughts on syntax design. I think it has to be robust and intuitive for long-term maintenance.
>
> EDIT:
>
> > Noting this also appears to be an issue with `filldist(Normal(), n)`, which is documented (I think) as iid, so I assumed it would work initially.
>
> If we fix filldist (and maybe the broadcasting syntax, e.g. `μ .~ Normal()`), would that be sufficient for this use case?
>
> Of course, we should provide a warning message to the user in all other unsupported cases discussed here.

This sounds good, but would there be a performance implication? I'm wondering about the upsides/downsides here.

SamuelBrand1 avatar Jul 18 '24 08:07 SamuelBrand1

> Noting this also appears to be an issue with `filldist(Normal(), n)`, which is documented (I think) as iid, so I assumed it would work initially.

IMO this is more of a doc issue, as there are definitely scenarios where it makes sense to use filldist rather than arraydist, since the former is more efficient (when applicable).
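A minimal sketch of the distinction being drawn here (assuming Turing.jl is loaded, which re-exports `filldist` and `arraydist` from DistributionsAD):

```julia
using Turing  # re-exports filldist and arraydist

n = 5
μ = randn(n)

d_fill  = filldist(Normal(), n)                       # n iid copies of ONE distribution
d_array = arraydist([Normal(μ[i], 1.0) for i in 1:n]) # n independent, possibly different

# Both are n-dimensional product distributions; filldist can exploit the shared
# parameters for a cheaper logpdf, while arraydist is the general fallback.
length(d_fill) == length(d_array) == n
```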

> If we fix filldist (and maybe the broadcasting syntax, e.g. `μ .~ Normal()`), would that be sufficient for this use case?

It's somewhat unclear to me how this would work, but maybe this is something we should discuss in DynamicPPL.jl :)

torfjelde avatar Jul 18 '24 08:07 torfjelde

> IMO this is more of a doc issue, as there are definitely scenarios where it makes sense to use filldist rather than arraydist, since the former is more efficient (when applicable).

My point was that filldist also doesn't work with predict at the moment.

> It's somewhat unclear to me how this would work, but maybe this is something we should discuss in DynamicPPL.jl :)

What are the steps forward from here? In principle I'm happy to help out with a proposed fix, but I'm likely to need guidance / it's unclear to me how y'all manage community contributions.

It would be really great for us to have this fixed centrally vs implementing modifications in our tooling to get this working.

seabbs avatar Jul 25 '24 15:07 seabbs

Ah, I see the discussion over on https://github.com/TuringLang/DynamicPPL.jl/issues/510, so I guess I'll track that!

> If we fix filldist (and maybe the broadcasting syntax, e.g. `μ .~ Normal()`), would that be sufficient for this use case?

This resolution does seem out of the scope of that issue, however.

seabbs avatar Jul 25 '24 15:07 seabbs