Support for multiple nested `produce`

Open torfjelde opened this issue 2 years ago • 1 comments

As discussed extensively in https://github.com/TuringLang/Turing.jl/pull/2001, and in particular in https://github.com/TuringLang/Turing.jl/pull/2001#issuecomment-1585694587, Libtask.jl makes one crucial assumption: every `Instruction` contains at most one `produce` statement.

This is because of https://github.com/TuringLang/Libtask.jl/blob/95e32aa525be3649d0671ce3e47efb6e38382421/src/tapedfunction.jl#L73-L74, which relies on https://github.com/TuringLang/Libtask.jl/blob/95e32aa525be3649d0671ce3e47efb6e38382421/src/tapedfunction.jl#L44-L48; the result is then traversed to construct the tape.

There are many cases in Turing.jl where this assumption does not hold, e.g. when we use `@submodel`.

Moreover, it's very unclear to me how this can be addressed without doing something fancy that lets us recurse into the type inference that is performed.

EDIT: Here's an example of what I mean:

julia> using Libtask

julia> f(x) = (produce(x); produce(2x); produce(3x); return nothing)
f (generic function with 1 method)

julia> g(x) = f(x)
g (generic function with 1 method)

julia> task = Libtask.TapedTask(f, 1);

julia> consume(task), consume(task), consume(task)
(1, 2, 3)

julia> task = Libtask.TapedTask(g, 1);  # tracing of nested call

julia> consume(task)   # goes through all the `produce` calls before even calling the `callback` (which is `Libtask.producer`)
counter=1
tf=TapedFunction:
* .func => g
* .ir   =>
------------------
CodeInfo(
1 ─ %1 = Main.f(x)::Core.Const(nothing)
└──      return %1
)
------------------

ErrorException("There is a produced value which is not consumed.")Union{Ptr{Nothing}, Base.InterpreterIP}[Ptr{Nothing} @0x00007fa8d200eeeb, Ptr{Nothing} @0x00007fa8a0a30f29, Ptr{Nothing} @0x00007fa8a0a36844, Ptr{Nothing} @0x00007fa8a0a36865, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8a0a366e3, Ptr{Nothing} @0x00007fa8a0a36802, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8a0a35f25, Ptr{Nothing} @0x00007fa8a0a361dd, Ptr{Nothing} @0x00007fa8a0a36512, Ptr{Nothing} @0x00007fa8a0a3652f, Ptr{Nothing} @0x00007fa8e6b44f1d, Ptr{Nothing} @0x00007fa8e6b6656f]
ERROR: There is a produced value which is not consumed.
Stacktrace:
 [1] consume(ttask::TapedTask{typeof(g), Tuple{Int64}})
   @ Libtask ~/.julia/packages/Libtask/h7Kal/src/tapedtask.jl:153
 [2] top-level scope
   @ REPL[9]:1
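
To make the failure mode above concrete, here is a toy model of a tape runner in plain Julia. It is only a sketch of the "at most one `produce` per instruction" assumption, not Libtask's implementation: `run_until_produce` and the closure-based tapes `flat` and `nested` are hypothetical names made up for illustration.

```julia
# Toy model of a tape (NOT Libtask's implementation; all names are hypothetical).
# A tape is a vector of instructions; each instruction receives a `produce`
# callback, and the runner assumes it is called at most once per instruction.
function run_until_produce(tape, pc)
    while pc <= length(tape)
        produced = Any[]
        tape[pc](v -> push!(produced, v))   # execute one whole instruction
        pc += 1
        if length(produced) > 1
            error("There is a produced value which is not consumed.")
        elseif length(produced) == 1
            return produced[1], pc          # hand the value to the consumer
        end
    end
    return nothing, pc                      # tape finished without producing
end

f(produce, x) = (produce(x); produce(2x); produce(3x); nothing)

# Tracing `f` directly: every `produce` ends up as its own instruction.
flat = Function[p -> p(1), p -> p(2), p -> p(3)]

# Tracing `g(x) = f(x)`: the call to `f` is one opaque instruction.
nested = Function[p -> f(p, 1)]

value, pc = run_until_produce(flat, 1)      # value == 1, pc == 2
nested_failed = try
    run_until_produce(nested, 1)
    false
catch err
    err isa ErrorException
end
```

With the flat tape the runner can stop after each produced value, whereas the nested tape produces three values inside a single instruction and trips the check, which mirrors the error reported for `g`.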

torfjelde avatar Jun 11 '23 14:06 torfjelde

One candidate to address this would be to replace `Libtask._infer` with `last(Umlaut.trace)` and then build the instruction tape from the resulting trace.

For example:


julia> using DynamicPPL, Distributions, Umlaut

julia> struct DynamicPPLModelCtx{F}
           model::Model{F}
       end

julia> function Umlaut.isprimitive(ctx::DynamicPPLModelCtx, f, args...)
           f === ctx.model.f && return false
           if Base.parentmodule(f) == DynamicPPL
               # Trace into `DynamicPPL._evaluate!!`.
               f === DynamicPPL._evaluate!! && return false
           end
           return true
       end
isprimitive (generic function with 5 methods)

julia> @model function demo(x)
           z ~ Normal()
           x ~ Normal(z, 1)
       end
demo (generic function with 4 methods)

julia> model = demo(1);

julia> ctx = SamplingContext();

julia> varinfo = VarInfo(model);

julia> t = last(Umlaut.trace(DynamicPPL._evaluate!!, model, varinfo, ctx; ctx=DynamicPPLModelCtx(model)))
Tape{DynamicPPLModelCtx{typeof(demo)}}
  inp %1::typeof(DynamicPPL._evaluate!!)
  inp %2::Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}
  inp %3::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}
  inp %4::SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}
  %5 = make_evaluate_args_and_kwargs(%2, %3, %4)::Tuple{Tuple{Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}, Int64}, NamedTuple{(), Tuple{}}} 
  %6 = indexed_iterate(%5, 1)::Tuple{Tuple{Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}, Int64}, Int64} 
  %7 = getfield(%6, 1)::Tuple{Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext}, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}, Int64} 
  %8 = getfield(%6, 2)::Int64 
  %9 = indexed_iterate(%5, 2, %8)::Tuple{NamedTuple{(), Tuple{}}, Int64} 
  %10 = getfield(%9, 1)::NamedTuple{(), Tuple{}} 
  %11 = NamedTuple()::NamedTuple{(), Tuple{}} 
  %12 = merge(%11, %10)::NamedTuple{(), Tuple{}} 
  %13 = isempty(%12)::Bool 
  %14 = getproperty(%2, :f)::typeof(demo) 
  %15 = check_variable_length(%7, 4, 7)::Nothing 
  %16 = getindex(%7, 1)::Model{typeof(demo), (:x,), (), (), Tuple{Int64}, Tuple{}, DefaultContext} 
  %17 = getindex(%7, 2)::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64} 
  %18 = getindex(%7, 3)::SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG} 
  %19 = getindex(%7, 4)::Int64 
  const %20 = nothing::Nothing
  const %21 = nothing::Nothing
  const %22 = nothing::Nothing
  const %23 = nothing::Nothing
  const %24 = nothing::Nothing
  const %25 = nothing::Nothing
  const %26 = nothing::Nothing
  const %27 = nothing::Nothing
  const %28 = nothing::Nothing
  const %29 = nothing::Nothing
  const %30 = nothing::Nothing
  const %31 = nothing::Nothing
  const %32 = nothing::Nothing
  const %33 = nothing::Nothing
  %34 = Normal()::Normal{Float64} 
  %35 = apply_type(VarName, :z)::UnionAll 
  %36 = %35()::VarName{:z, Setfield.IdentityLens} 
  %37 = resolve_varnames(%36, %34)::VarName{:z, Setfield.IdentityLens} 
  %38 = contextual_isassumption(%18, %37)::Bool 
  %39 = inargnames(%37, %16)::Bool 
  %40 = !(%39)::Bool 
  const %41 = true::Bool
  const %42 = nothing::Nothing
  %43 = tuple(%18)::Tuple{SamplingContext{SampleFromPrior, DefaultContext, Random.TaskLocalRNG}} 
  %44 = check_tilde_rhs(%34)::Normal{Float64} 
  %45 = unwrap_right_vn(%44, %37)::Tuple{Normal{Float64}, VarName{:z, Setfield.IdentityLens}} 
  %46 = tuple(%17)::Tuple{TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}} 
  %47 = check_variable_length(%43, 1, 43)::Nothing 
  %48 = check_variable_length(%45, 2, 45)::Nothing 
  %49 = getindex(%45, 1)::Normal{Float64} 
  %50 = getindex(%45, 2)::VarName{:z, Setfield.IdentityLens} 
  %51 = check_variable_length(%46, 1, 46)::Nothing 
  %52 = tilde_assume!!(%18, %49, %50, %17)::Tuple{Float64, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}} 
  %53 = indexed_iterate(%52, 1)::Tuple{Float64, Int64} 
  %54 = getfield(%53, 1)::Float64 
  %55 = getfield(%53, 2)::Int64 
  %56 = indexed_iterate(%52, 2, %55)::Tuple{TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, Int64} 
  %57 = getfield(%56, 1)::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64} 
  %58 = Normal(%54, 1)::Normal{Float64} 
  %59 = apply_type(VarName, :x)::UnionAll 
  %60 = %59()::VarName{:x, Setfield.IdentityLens} 
  %61 = resolve_varnames(%60, %58)::VarName{:x, Setfield.IdentityLens} 
  %62 = contextual_isassumption(%18, %61)::Bool 
  %63 = inargnames(%61, %16)::Bool 
  %64 = !(%63)::Bool 
  %65 = inmissings(%61, %16)::Bool 
  %66 = ===(%19, missing)::Bool 
  %67 = inargnames(%61, %16)::Bool 
  %68 = !(%67)::Bool 
  %69 = check_tilde_rhs(%58)::Normal{Float64} 
  %70 = tilde_observe!!(%18, %69, %19, %61, %57)::Tuple{Int64, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}} 
  %71 = indexed_iterate(%70, 1)::Tuple{Int64, Int64} 
  %72 = getfield(%71, 1)::Int64 
  %73 = getfield(%71, 2)::Int64 
  %74 = indexed_iterate(%70, 2, %73)::Tuple{TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, Int64} 
  %75 = getfield(%74, 1)::TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64} 
  %76 = tuple(%72, %75)::Tuple{Int64, TypedVarInfo{NamedTuple{(:z,), Tuple{DynamicPPL.Metadata{Dict{VarName{:z, Setfield.IdentityLens}, Int64}, Vector{Normal{Float64}}, Vector{VarName{:z, Setfield.IdentityLens}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}} 

With some minor changes, this should even be usable with Libtask.jl: compile the tape `t` and then pass the result to Libtask._infer_.

Note that `isprimitive` would have to be fine-tuned to also support usage of `@submodel`, but that could easily be done.
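
As a rough illustration of how such a predicate decides what gets flattened into the tape, here is a toy tracer in plain Julia. It does not use Umlaut; `bodies`, `toy_isprimitive` and `flatten_trace` are hypothetical names, and the "functions" are just lists of steps.

```julia
# Toy call graph (hypothetical): each "function" is a list of steps, and a step
# is either (:produce, value) or (:call, name).
bodies = Dict(
    :f => Any[(:produce, 1), (:produce, 2), (:produce, 3)],
    :g => Any[(:call, :f)],
    :model => Any[(:call, :g), (:call, :logpdf)],
)

# Stand-in for an `isprimitive`-style predicate: returning `false` means
# "trace into this call"; anything else is recorded as one opaque entry.
toy_isprimitive(name) = !(name in (:f, :g))

function flatten_trace(bodies, name, tape=Any[])
    for (kind, arg) in bodies[name]
        if kind === :call && !toy_isprimitive(arg)
            flatten_trace(bodies, arg, tape)   # recurse rather than record the call
        else
            push!(tape, (kind, arg))
        end
    end
    return tape
end

tape = flatten_trace(bodies, :model)
```

Here the nested `produce` steps end up as three separate tape entries while `:logpdf` stays a single opaque call. Supporting a submodel would presumably amount to marking its evaluation function as non-primitive in the same way.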

AFAIK the main drawback of making this change is that we drop support for control flow (unless we want to re-run the trace on every call).
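
The control-flow limitation can be seen in a small sketch, again in plain Julia with hypothetical names (`h`, `toy_trace`, `replay`): a trace only records the operations that ran for one concrete input, so replaying it on an input that takes the other branch gives the wrong answer.

```julia
h(x) = x > 0 ? 2x : -x

# Hypothetical toy tracer: records only the operations executed for the given
# input. The branch condition itself is not part of the recorded tape.
function toy_trace(x)
    tape = Function[]
    if x > 0
        push!(tape, v -> 2v)
    else
        push!(tape, v -> -v)
    end
    return tape
end

# Replay a recorded tape by applying its operations in order.
replay(tape, x) = foldl((acc, op) -> op(acc), tape; init=x)

tape = toy_trace(3)                        # traced with a positive input
same_branch = replay(tape, 3) == h(3)      # true
other_branch = replay(tape, -3) == h(-3)   # false: the tape still doubles
```

Re-tracing for every call (`toy_trace(-3)` in this sketch) restores correctness at the cost of doing the tracing work each time.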

torfjelde avatar Jun 11 '23 14:06 torfjelde