Kernel DSL error with if statement
I get `UndefVarError` ("not defined") errors when I use the kernel DSL to choose between kernels based on a Bernoulli sample. For example:
```julia
julia> @kern function switcher(tr)
           let do_x ~ bernoulli(0.5)
               if do_x
                   tr ~ mh(tr, select(:x))
               else
                   tr ~ mh(tr, select(:y))
               end
           end
       end

julia> switcher(tr)
ERROR: UndefVarError: do_x not defined
```
I also think the documentation could make clearer that this should be a valid MH move. Currently it seems to suggest that random variables cannot be used in `if` conditions.
Also, is there a reason support for this has not been implemented? Would allowing this syntax make it possible to write buggy kernels? If so, should we introduce new syntax for the valid case of switching kernels based on a random sample?
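One way to see why picking a kernel with an independent coin flip should be a valid MH move: if each component kernel leaves the target distribution invariant, then a mixture of them with state-independent weights does too. The following plain-Julia sketch (not Gen code) checks this exactly on a toy discrete target. The 2-by-2 target table, the "flip one coordinate" proposals that stand in for `mh(tr, select(:x))` and `mh(tr, select(:y))`, and the names `mh_flip_matrix` and `is_stationary` are all hypothetical and chosen only for illustration.

```julia
# Toy target p(x, y) over x, y ∈ {1, 2}; rows index x, columns index y.
# The numbers are arbitrary (they only need to be positive and sum to 1).
target = [0.4 0.1;
          0.2 0.3]

# Column-major state index, matching the order of vec(target).
state_index(x, y) = x + 2 * (y - 1)

# Exact transition matrix of an MH move that proposes flipping one coordinate
# (`which` is :x or :y) and leaves the other unchanged. The flip proposal is
# symmetric, so the acceptance probability is min(1, p(new) / p(old)).
function mh_flip_matrix(target, which)
    P = zeros(4, 4)
    for x in 1:2, y in 1:2
        xn, yn = which === :x ? (3 - x, y) : (x, 3 - y)
        accept = min(1.0, target[xn, yn] / target[x, y])
        i = state_index(x, y)
        P[i, state_index(xn, yn)] += accept
        P[i, i] += 1 - accept
    end
    return P
end

# True if one step of P leaves the target probabilities unchanged (p * P == p).
function is_stationary(target, P)
    p = vec(target)
    after = vec(sum(p .* P; dims = 1))
    return all(isapprox.(after, p))
end

P_x = mh_flip_matrix(target, :x)   # stand-in for an MH move on :x
P_y = mh_flip_matrix(target, :y)   # stand-in for an MH move on :y
# Choose one of the two kernels with a fair coin that ignores the current
# state, as `let do_x ~ bernoulli(0.5)` does.
P_mix = 0.5 .* P_x .+ 0.5 .* P_y
```

All three matrices preserve the target. If the mixture weight depended on the current state instead of being a fixed 0.5, this argument would no longer apply in general.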
(Copied from what Marco posted on Slack:) Actually, `if`-`else`-`end` is not implemented yet; I only implemented `if`-`end`. This code runs successfully for me:

```julia
using Gen

@gen function foo()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
end

@kern function my_kernel(tr)
    let use_kern_1 ~ bernoulli(0.5)
        if use_kern_1
            tr ~ mh(tr, select(:x))
        end
        if !use_kern_1
            tr ~ mh(tr, select(:y))
        end
    end
end

trace = simulate(foo, ())
display(get_choices(trace))
trace, metadata = my_kernel(trace)
display(get_choices(trace))
```
Of course, using the not-yet-implemented syntax should at minimum raise a syntax error.
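A check along these lines could turn the run-time `UndefVarError` into an explicit error at macro-expansion time. This is a hypothetical sketch, not Gen's implementation: `find_if_else` and `check_kernel_body` are invented names, and a real `@kern` macro would need to call something like this on the function body before transforming it. It relies on how Julia parses conditionals: `if c; a; end` becomes `Expr(:if, c, a)`, while `if c; a; else; b; end` (and an `elseif` chain) becomes `Expr(:if, c, a, b)` with a third argument.

```julia
# Hypothetical helper: return the first `if` expression inside `ex` that has an
# `else` or `elseif` branch (i.e. a third argument), or `nothing` if none.
# Note: the ternary `c ? a : b` also parses to a three-argument Expr(:if, ...),
# so it is flagged as well; a real check might need to be more selective.
function find_if_else(ex)
    ex isa Expr || return nothing   # symbols, literals, LineNumberNodes
    if ex.head === :if && length(ex.args) == 3
        return ex
    end
    for arg in ex.args
        found = find_if_else(arg)
        found === nothing || return found
    end
    return nothing
end

# Hypothetical check a kernel macro could run on the function body so that
# unsupported syntax fails loudly instead of producing an UndefVarError later.
function check_kernel_body(body)
    if find_if_else(body) !== nothing
        throw(ArgumentError(
            "@kern does not support `else`/`elseif` yet; use separate " *
            "`if ... end` blocks with complementary conditions instead"))
    end
    return nothing
end
```

For example, `check_kernel_body(:(if do_x; tr ~ mh(tr, select(:x)); else; tr ~ mh(tr, select(:y)); end))` throws an `ArgumentError`, while the same expression without the `else` branch passes.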
Also, the docs should state more clearly that `else` is not yet supported. The current text does not make this very clear:

> **If-end expressions:** The predicate condition may be a deterministic function of the trace, but it also must be invariant (i.e. remain true) under all possible executions of the body.
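The invariance requirement matters because the body can change the very values the predicate reads. The following plain-Julia sketch (not Gen code) builds exact transition matrices for `if pred(x, y) then <MH move that proposes flipping x> end` on a toy discrete target and checks whether the target is preserved. The target table and the names `conditional_kernel` and `is_stationary` are hypothetical. With a predicate on `y`, which the body never changes, the target is preserved; with a predicate on `x`, which the body can flip, it is not, because probability mass can leave the states where the predicate holds but never return.

```julia
# Toy target p(x, y) over x, y ∈ {1, 2}; rows index x, columns index y.
target = [0.4 0.1;
          0.2 0.3]

# Column-major state index, matching the order of vec(target).
state_index(x, y) = x + 2 * (y - 1)

# Exact transition matrix of "if pred(x, y) then MH move that proposes
# flipping x end". When pred is false the state is left unchanged.
function conditional_kernel(target, pred)
    P = zeros(4, 4)
    for x in 1:2, y in 1:2
        i = state_index(x, y)
        if pred(x, y)
            xn = 3 - x
            accept = min(1.0, target[xn, y] / target[x, y])
            P[i, state_index(xn, y)] += accept
            P[i, i] += 1 - accept
        else
            P[i, i] = 1.0
        end
    end
    return P
end

# True if one step of P leaves the target probabilities unchanged (p * P == p).
function is_stationary(target, P)
    p = vec(target)
    after = vec(sum(p .* P; dims = 1))
    return all(isapprox.(after, p))
end

invariant_kernel = conditional_kernel(target, (x, y) -> y == 1)  # body never changes y
broken_kernel    = conditional_kernel(target, (x, y) -> x == 1)  # body can flip x
```

Here `is_stationary(target, invariant_kernel)` is `true` and `is_stationary(target, broken_kernel)` is `false`: starting from the target, one step of the broken kernel moves half of the mass at `(1, 1)` to `(2, 1)` and none back.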
A similar issue arises with supported language constructs when kernel arguments are used.
For instance:
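```julia
using Gen

@gen function model()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
    z ~ normal(0, 1)
end

# Assumed definition (not shown in the original report): a Gaussian random
# walk on the choice at `addr`.
@gen function drift_proposal(trace, addr, width)
    {addr} ~ normal(trace[addr], width)
end

@kern function drift_kern(trace, addr, width, num_steps)
    for _ in 1:num_steps
        trace ~ mh(trace, drift_proposal, (addr, width))
    end
end

@kern function repeated_gibbs_kern(trace)
    trace ~ drift_kern(trace, :x, 0.01, 100)
    trace ~ drift_kern(trace, :y, 0.01, 100)
    trace ~ drift_kern(trace, :z, 0.01, 100)
end

trace, _ = Gen.generate(model, ())
# Like `my_kernel` above, the kernel returns a (trace, metadata) tuple.
trace, _ = repeated_gibbs_kern(trace)
```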
This results in an undefined-variable error for `num_steps`, `addr`, or `width`, depending on which one the code reaches first.
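To make concrete what `drift_kern` and `repeated_gibbs_kern` are meant to compute, here is a plain-Julia sketch of the same Markov chain without Gen or the kernel DSL. It assumes the state is stored as a vector `[x, y, z]` and that `drift_proposal` is a symmetric Gaussian random walk with standard deviation `width` on the chosen coordinate; `drift_steps!` and `repeated_gibbs!` are hypothetical names. Because that proposal is symmetric, the MH acceptance ratio reduces to the ratio of target densities.

```julia
using Random

# Log density of the example model: x, y, z independent standard normals
# (up to an additive constant, which cancels in the MH ratio).
log_target(state) = -0.5 * sum(abs2, state)

# Hypothetical analogue of `drift_kern`: `num_steps` random-walk MH moves on
# coordinate `addr`, proposing state[addr] + width * randn(). Returns the
# number of accepted moves.
function drift_steps!(rng::AbstractRNG, state::Vector{Float64}, addr::Int,
                      width::Real, num_steps::Int)
    accepted = 0
    for _ in 1:num_steps
        old_value = state[addr]
        old_logp = log_target(state)
        state[addr] = old_value + width * randn(rng)
        if log(rand(rng)) < log_target(state) - old_logp
            accepted += 1
        else
            state[addr] = old_value   # reject: restore the previous value
        end
    end
    return accepted
end

# Hypothetical analogue of `repeated_gibbs_kern`: one sweep of drift moves
# over x, y, z (indices 1, 2, 3) with the same settings as the report.
function repeated_gibbs!(rng::AbstractRNG, state::Vector{Float64})
    return [drift_steps!(rng, state, addr, 0.01, 100) for addr in 1:3]
end

rng = MersenneTwister(1)
state = randn(rng, 3)                 # initial [x, y, z]
accepts = repeated_gibbs!(rng, state) # accepted moves per coordinate
```

Each `drift_steps!` call only touches `state[addr]`, which mirrors the intent of passing `addr` as a kernel argument.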