
Kernel DSL error with if statement

Open · georgematheos opened this issue 5 years ago · 2 comments

I get "not defined" errors when I use the kernel DSL to choose which kernel to apply based on a Bernoulli sample. For example:

julia> @kern function switcher(tr)
           let do_x ~ bernoulli(0.5)
               if do_x
                   tr ~ mh(tr, select(:x))
               else
                   tr ~ mh(tr, select(:y))
               end
           end
       end

julia> switcher(tr)
ERROR: UndefVarError: do_x not defined

I also think the documentation could make it clearer that this should be a valid MH move. At the moment it seems to suggest that random variables can't be used in if statements.

Also, is there some reason this hasn't been implemented? Would allowing this syntax make it possible to write buggy kernels? If so, should we introduce new syntax for the valid cases of switching kernels based on a random sample?
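For comparison, here is a minimal sketch of the same random switch written as an ordinary Julia function instead of with @kern. It assumes only Gen's exported `mh`, `select`, `simulate` and `bernoulli` (Gen distributions can be called like functions to draw a sample); the names `foo` and `switcher` are illustrative. Choosing between two valid MH kernels with a fixed probability is itself a stationary kernel, so the if-else is fine here. What is presumably lost compared with @kern is the DSL's extra machinery, such as the automatically derived reversal kernel.

```julia
using Gen

@gen function foo()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
end

# Mixture of two MH kernels as a plain Julia function (not @kern),
# so an ordinary if-else is allowed.
function switcher(tr)
    do_x = bernoulli(0.5)            # draw the switch directly
    if do_x
        tr, _ = mh(tr, select(:x))   # mh returns (new_trace, accepted)
    else
        tr, _ = mh(tr, select(:y))
    end
    return tr
end

tr = simulate(foo, ())
tr = switcher(tr)
```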

georgematheos, Aug 13 '20 18:08

(Copied from what Marco posted in the Slack:) Actually, if-else-end is not implemented yet; I only implemented if-end. This code runs successfully for me:

using Gen

@gen function foo()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
end
@kern function my_kernel(tr)
  let use_kern_1 ~ bernoulli(0.5)
    if use_kern_1
      tr ~ mh(tr, select(:x))
    end
    if !use_kern_1
      tr ~ mh(tr, select(:y))
    end
  end
end
trace = simulate(foo, ())
display(get_choices(trace))
trace, metadata = my_kernel(trace)
display(get_choices(trace))

Of course, at a minimum the not-yet-implemented syntax should raise a syntax error.

The docs should also be clearer that else is not yet supported. The current text does not make that clear:

If-end expressions: "The predicate condition may be a deterministic function of the trace, but it also must be invariant (i.e. remain true) under all possible executions of the body."
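A minimal sketch of what that invariance requirement means, written as a plain Julia function rather than in the DSL (the names `xy_model` and `update_x_if_y_positive` are hypothetical). The predicate reads the trace, but the body can never change the value it reads, so the same branch would be taken before and after the body runs.

```julia
using Gen

@gen function xy_model()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
end

# Update :x only when y > 0. The predicate reads only :y, and
# mh on select(:x) never changes :y, so the predicate is invariant.
# By contrast, guarding the same body with `tr[:x] > 0` would not be
# invariant, because the body can move x across 0.
function update_x_if_y_positive(tr)
    if tr[:y] > 0
        tr, _ = mh(tr, select(:x))
    end
    return tr
end

tr = simulate(xy_model, ())
new_tr = update_x_if_y_positive(tr)
```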

georgematheos, Aug 13 '20 18:08

A similar issue arises with supported language constructs when a kernel's own arguments are used inside them.

For instance:

using Gen

@gen function model()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
    z ~ normal(0, 1)
end

# drift_proposal is not defined in the report; this Gaussian random walk is an assumed stand-in.
@gen function drift_proposal(trace, addr, width)
    {addr} ~ normal(trace[addr], width)
end

@kern function drift_kern(trace, addr, width, num_steps)
    for _ in 1:num_steps
        trace ~ mh(trace, drift_proposal, (addr, width))
    end
end

@kern function repeated_gibbs_kern(trace)
    trace ~ drift_kern(trace, :x, 0.01, 100)
    trace ~ drift_kern(trace, :y, 0.01, 100)
    trace ~ drift_kern(trace, :z, 0.01, 100)
end

trace, _ = generate(model, ())
trace, _ = repeated_gibbs_kern(trace)

This results in an UndefVarError for num_steps, addr, or width (whichever the code reaches first).
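Until the DSL handles this, one possible workaround (a sketch, not verified against the Gen.jl version in this report) is to write both kernels as ordinary Julia functions that call `mh` directly, so that num_steps, addr, and width are just normal function arguments. The `drift_proposal` below is an assumed Gaussian random-walk proposal, since the report does not show its definition.

```julia
using Gen

@gen function model()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
    z ~ normal(0, 1)
end

# Assumed proposal: Gaussian random walk centred on the current value.
@gen function drift_proposal(trace, addr, width)
    {addr} ~ normal(trace[addr], width)
end

# Plain Julia kernel: the arguments are ordinary locals, so no UndefVarError.
function drift_kern(trace, addr, width, num_steps)
    for _ in 1:num_steps
        trace, _ = mh(trace, drift_proposal, (addr, width))
    end
    return trace
end

# Apply the drift kernel to each address in turn.
function repeated_gibbs_kern(trace)
    for addr in (:x, :y, :z)
        trace = drift_kern(trace, addr, 0.01, 100)
    end
    return trace
end

trace, _ = generate(model, ())
trace = repeated_gibbs_kern(trace)
```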

agarret7, Jun 24 '21 22:06