Support for multiple nested `produce` calls
As discussed extensively in https://github.com/TuringLang/Turing.jl/pull/2001, in particular https://github.com/TuringLang/Turing.jl/pull/2001#issuecomment-1585694587, Libtask.jl makes one crucial assumption: every `Instruction` contains at most one `produce` statement.
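This is because of https://github.com/TuringLang/Libtask.jl/blob/95e32aa525be3649d0671ce3e47efb6e38382421/src/tapedfunction.jl#L73-L74, where the function's IR is obtained via https://github.com/TuringLang/Libtask.jl/blob/95e32aa525be3649d0671ce3e47efb6e38382421/src/tapedfunction.jl#L44-L48 and is then traversed to construct the tape. Consequently, a call to a function that itself calls `produce` several times ends up as a single instruction on the tape.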
There are many cases in Turing.jl where this assumption does not hold, e.g. when using `@submodel`.
Moreover, it is unclear to me how this can be addressed without doing something quite involved that lets us recurse into the type inference being performed.
EDIT: Here's an example of what I mean:
julia> using Libtask
julia> f(x) = (produce(x); produce(2x); produce(3x); return nothing)
f (generic function with 1 method)
julia> g(x) = f(x)
g (generic function with 1 method)
julia> task = Libtask.TapedTask(f, 1);
julia> consume(task), consume(task), consume(task)
(1, 2, 3)
julia> task = Libtask.TapedTask(g, 1); # tracing of nested call
julia> consume(task) # goes through all the `produce` calls before even calling the `callback` (which is `Libtask.producer`)
counter=1
tf=TapedFunction:
* .func => g
* .ir =>
------------------
CodeInfo(
1 ─ %1 = Main.f(x)::Core.Const(nothing)
└── return %1
)
------------------
ErrorException("There is a produced value which is not consumed.")Union{Ptr{Nothing}, Base.InterpreterIP}[Ptr{Nothing} @0x00007fa8d200eeeb, Ptr{Nothing} @0x00007fa8a0a30f29, Ptr{Nothing} @0x00007fa8a0a36844, Ptr{Nothing} @0x00007fa8a0a36865, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8a0a366e3, Ptr{Nothing} @0x00007fa8a0a36802, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8a0a35f25, Ptr{Nothing} @0x00007fa8a0a361dd, Ptr{Nothing} @0x00007fa8a0a36512, Ptr{Nothing} @0x00007fa8a0a3652f, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8e6b6656f]
ERROR: There is a produced value which is not consumed.
Stacktrace:
[1] consume(ttask::TapedTask{typeof(g), Tuple{Int64}})
@ Libtask ~/.julia/packages/Libtask/h7Kal/src/tapedtask.jl:153
[2] top-level scope
@ REPL[9]:1
One candidate solution would be to replace `Libtask._infer` with `last(Umlaut.trace(...))` and then build the instruction tape from the resulting trace.
For example:
julia> using DynamicPPL, Distributions, Umlaut
julia> struct DynamicPPLModelCtx{F}
model::Model{F}
end
julia> function isprimitive(ctx::DynamicPPLModelCtx, f, args...)
f === ctx.model.f && return false
if Base.parentmodule(f) == DynamicPPL
# Trace into `DynamicPPL._evaluate!!`.
f === DynamicPPL._evaluate!! && return false
end
return true
end
isprimitive (generic function with 5 methods)
julia> @model function demo(x)
z ~ Normal()
x ~ Normal(z, 1)
end
demo (generic function with 4 methods)
julia> model = demo(1);
julia> ctx = SamplingContext();
julia> varinfo = VarInfo(model);
julia> t = last(Umlaut.trace(DynamicPPL._evaluate!!, model, varinfo, ctx; ctx=DynamicPPLModelCtx(model)))
Tape{DynamicPPLModelCtx{typeof(demo)}}
inp %1::typeof(DynamicPPL._evaluate!!)
inp %2::Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}
inp %3::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
inp %4::SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}
%5 = make_evaluate_args_and_kwargs(%2, %3, %4)::Tuple{Tuple{Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}, Int64}, NamedTuple{(), Tuple{}}}
%6 = indexed_iterate(%5, 1)::Tuple{Tuple{Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}, Int64}, Int64}
%7 = getfield(%6, 1)::Tuple{Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}, Int64}
%8 = getfield(%6, 2)::Int64
%9 = indexed_iterate(%5, 2, %8)::Tuple{NamedTuple{(), Tuple{}}, Int64}
%10 = getfield(%9, 1)::NamedTuple{(), Tuple{}}
%11 = NamedTuple()::NamedTuple{(), Tuple{}}
%12 = merge(%11, %10)::NamedTuple{(), Tuple{}}
%13 = isempty(%12)::Bool
%14 = getproperty(%2, :f)::typeof(demo)
%15 = check_variable_length(%7, 4, 7)::Nothing
%16 = getindex(%7, 1)::Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}
%17 = getindex(%7, 2)::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
%18 = getindex(%7, 3)::SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}
%19 = getindex(%7, 4)::Int64
const %20 = nothing::Nothing
const %21 = nothing::Nothing
const %22 = nothing::Nothing
const %23 = nothing::Nothing
const %24 = nothing::Nothing
const %25 = nothing::Nothing
const %26 = nothing::Nothing
const %27 = nothing::Nothing
const %28 = nothing::Nothing
const %29 = nothing::Nothing
const %30 = nothing::Nothing
const %31 = nothing::Nothing
const %32 = nothing::Nothing
const %33 = nothing::Nothing
%34 = Normal()::Normal{Float64}
%35 = apply_type(VarName, :z)::UnionAll
%36 = %35()::VarName{:z, Setfield.IdentityLens}
%37 = resolve_varnames(%36, %34)::VarName{:z, Setfield.IdentityLens}
%38 = contextual_isassumption(%18, %37)::Bool
%39 = inargnames(%37, %16)::Bool
%40 = !(%39)::Bool
const %41 = true::Bool
const %42 = nothing::Nothing
%43 = tuple(%18)::Tuple{SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}}
%44 = check_tilde_rhs(%34)::Normal{Float64}
%45 = unwrap_right_vn(%44, %37)::Tuple{Normal{Float64}, VarName{:z, Setfield.IdentityLens}}
%46 = tuple(%17)::Tuple{TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
%47 = check_variable_length(%43, 1, 43)::Nothing
%48 = check_variable_length(%45, 2, 45)::Nothing
%49 = getindex(%45, 1)::Normal{Float64}
%50 = getindex(%45, 2)::VarName{:z, Setfield.IdentityLens}
%51 = check_variable_length(%46, 1, 46)::Nothing
%52 = tilde_assume!!(%18, %49, %50, %17)::Tuple{Float64, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
%53 = indexed_iterate(%52, 1)::Tuple{Float64, Int64}
%54 = getfield(%53, 1)::Float64
%55 = getfield(%53, 2)::Int64
%56 = indexed_iterate(%52, 2, %55)::Tuple{TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, Int64}
%57 = getfield(%56, 1)::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
%58 = Normal(%54, 1)::Normal{Float64}
%59 = apply_type(VarName, :x)::UnionAll
%60 = %59()::VarName{:x, Setfield.IdentityLens}
%61 = resolve_varnames(%60, %58)::VarName{:x, Setfield.IdentityLens}
%62 = contextual_isassumption(%18, %61)::Bool
%63 = inargnames(%61, %16)::Bool
%64 = !(%63)::Bool
%65 = inmissings(%61, %16)::Bool
%66 = ===(%19, missing)::Bool
%67 = inargnames(%61, %16)::Bool
%68 = !(%67)::Bool
%69 = check_tilde_rhs(%58)::Normal{Float64}
%70 = tilde_observe!!(%18, %69, %19, %61, %57)::Tuple{Int64, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
%71 = indexed_iterate(%70, 1)::Tuple{Int64, Int64}
%72 = getfield(%71, 1)::Int64
%73 = getfield(%71, 2)::Int64
%74 = indexed_iterate(%70, 2, %73)::Tuple{TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, Int64}
%75 = getfield(%74, 1)::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
%76 = tuple(%72, %75)::Tuple{Int64, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}}
With some minor changes, this should even be usable within Libtask.jl: compile the tape `t` and then pass the resulting function to `Libtask._infer`.
Note that `isprimitive` would have to be fine-tuned to also support usage of `@submodel`, but this should be straightforward.
AFAIK the main drawback of this change is that we drop support for control flow, since the trace records only the operations executed for the inputs used during tracing (unless we are willing to re-trace on every call).