TransformVariables.jl

Feature Idea: flat transform

Open scheidan opened this issue 3 years ago • 5 comments

It would be useful if transform had an option to return the result as a flat vector:

transform(t, x, keep_flat=true)

A good use case is converting MCMC samples in MCMCChains.Chains objects:

import MCMCChains
using TransformVariables

# t: a TransformVariables transformation with dimension(t) == 5

samp = rand(1000, 5)            # we would get these from an MCMC algorithm

# we get an array of named tuples, which is great for defining the model but difficult to convert to a `Chains` object.
samp_trans1 = mapslices(s -> transform(t, s), samp, dims=2)
MCMCChains.Chains(samp_trans1)              # fails

# with the new argument we would get an array
samp_trans2 = mapslices(s -> transform(t, s, keep_flat=true), samp, dims=2)
MCMCChains.Chains(samp_trans2)              # that would work

This seems related to #13.

scheidan avatar Aug 03 '22 11:08 scheidan

What is the format of samp_trans2 here that you would expect? I am not familiar with MCMCChains.Chains.

tpapp avatar Aug 04 '22 10:08 tpapp

MCMCChains.Chains expects an Array of size iterations × n_parameters (or iterations × n_parameters × n_chains).

Having a flat transform would make the construction of such an Array quite easy. We would need to be careful with the length, because the flat result can be longer than the input vector:

t = as((a = asℝ,
        b = as(Vector, as(Real, 0, 1), 2),
        c = UnitVector(3)))

x = randn(dimension(t))  # length(x) == 5
transform(t, x)  # -> named tuple
transform(t, x, keep_flat=true)  # -> vector of length 6 != dimension(t)

scheidan avatar Aug 04 '22 14:08 scheidan

Thanks, I get it. It should be relatively easy to flatten transformed values:

flatten(x::Real) = [x]
flatten(x::AbstractArray) = vec(x)
flatten(x::Tuple) = mapreduce(flatten, vcat, x)
flatten(x::NamedTuple) = mapreduce(flatten, vcat, values(x))

z = (a = 1.0, b = [2.0, 3.0], c = (d = 4.0, e = 5.0))

flatten(z)

This can deal with everything TransformVariables can dish out at the moment. (The code above necessarily allocates and is quite suboptimal. In the ideal case this would be done with views, as in https://github.com/JuliaArrays/StackViews.jl.)
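As a rough illustration of how the repeated vcat allocations could be avoided without views, the sketch below computes the flat length first and then fills a single preallocated vector. The names flat_length, flatten_into! and flatten_prealloc are hypothetical and not part of TransformVariables.jl; the sketch assumes the value contains only real scalars, arrays of reals, tuples and named tuples, and that a Float64 output is acceptable.

```julia
# Hypothetical helpers, not part of TransformVariables.jl.

# Number of scalars contained in a nested transformed value.
flat_length(x::Real) = 1
flat_length(x::AbstractArray{<:Real}) = length(x)
flat_length(x::Union{Tuple,NamedTuple}) = sum(flat_length, values(x); init = 0)

# Write the scalars of `x` into `out` starting at index `i`;
# return the next free index.
function flatten_into!(out::AbstractVector, x::Real, i::Int)
    out[i] = x
    return i + 1
end

function flatten_into!(out::AbstractVector, x::AbstractArray{<:Real}, i::Int)
    n = length(x)
    copyto!(out, i, x, 1, n)   # column-major order, same as vec(x)
    return i + n
end

function flatten_into!(out::AbstractVector, x::Union{Tuple,NamedTuple}, i::Int)
    for v in values(x)
        i = flatten_into!(out, v, i)
    end
    return i
end

# Allocate the output once, then fill it (Float64 is an assumption).
function flatten_prealloc(x)
    out = Vector{Float64}(undef, flat_length(x))
    flatten_into!(out, x, 1)
    return out
end

z = (a = 1.0, b = [2.0, 3.0], c = (d = 4.0, e = 5.0))
flatten_prealloc(z)
```

This still allocates the output vector, so it is only a middle ground between the mapreduce version and a view-based design.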

Or would you prefer transforming directly to a flat vector for efficiency? I will keep this in mind for the next refactoring (which is coming up soon).
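Until something like that exists, flattening after the fact is enough to assemble the iterations × n_parameters array discussed earlier in the thread. The sketch below is self-contained: the flatten methods are repeated so it runs on its own, and the samples vector is made-up stand-in data for what mapping transform over the rows of an MCMC sample would return.

```julia
# Repeated here so the snippet runs on its own.
flatten(x::Real) = [x]
flatten(x::AbstractArray) = vec(x)
flatten(x::Tuple) = mapreduce(flatten, vcat, x)
flatten(x::NamedTuple) = mapreduce(flatten, vcat, values(x))

# Stand-in data (an assumption): one named tuple per iteration,
# shaped like the (a, b, c) example above with a unit vector in `c`.
samples = [(a = Float64(i), b = [0.1 * i, 0.2 * i], c = [1.0, 0.0, 0.0]) for i in 1:4]

# One flat vector per iteration ...
rows = map(flatten, samples)

# ... stacked as columns, then transposed to iterations × n_parameters.
mat = permutedims(reduce(hcat, rows))

size(mat)
```

Each row has 6 entries although the untransformed vector had 5, which is the length mismatch pointed out above.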

tpapp avatar Aug 06 '22 13:08 tpapp

Also, an ideal API would give column names, such as [:a, :b_1, :b_2, :c_d, :c_e] or similar.
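A minimal sketch of how such names could be derived from a transformed value, walking it in the same order as flatten. The function flat_names is hypothetical, not an existing TransformVariables.jl API, and it numbers array elements by linear index.

```julia
# Hypothetical helper, not part of TransformVariables.jl.
# Returns one Symbol per flattened scalar, joining levels with "_".
flat_names(prefix::Symbol, x::Real) = [prefix]
flat_names(prefix::Symbol, x::AbstractArray) =
    [Symbol(prefix, :_, i) for i in 1:length(x)]
flat_names(prefix::Symbol, x::NamedTuple) =
    mapreduce(k -> flat_names(Symbol(prefix, :_, k), x[k]), vcat, keys(x))
flat_names(prefix::Symbol, x::Tuple) =
    mapreduce(i -> flat_names(Symbol(prefix, :_, i), x[i]), vcat, 1:length(x))

# Top level: the field names themselves are the prefixes.
flat_names(x::NamedTuple) = mapreduce(k -> flat_names(k, x[k]), vcat, keys(x))

z = (a = 1.0, b = [2.0, 3.0], c = (d = 4.0, e = 5.0))
flat_names(z)   # [:a, :b_1, :b_2, :c_d, :c_e]
```

Deriving names from the transformed value is a shortcut; an API built into the package would more likely derive them from the transformation t itself.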

tpapp avatar Aug 06 '22 13:08 tpapp

Getting meaningful names would be very helpful!

MCMCChains.jl has some support for names with brackets for variables from arrays, e.g. "x[1,1]", "x[1,2]". See https://beta.turing.ml/MCMCChains.jl/stable/getting-started/#Groups-of-parameters
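For illustration, bracket-style names for array elements could be generated as sketched below. The helper bracket_names is hypothetical; it lists elements in column-major order so that the names line up with vec(x).

```julia
# Hypothetical helper: "name[i,j]"-style strings for each array element.
bracket_names(name, x::Real) = [string(name)]
bracket_names(name, x::AbstractArray) =
    vec([string(name, "[", join(Tuple(I), ","), "]") for I in CartesianIndices(x)])

bracket_names(:a, 1.0)           # ["a"]
bracket_names(:b, [2.0, 3.0])    # ["b[1]", "b[2]"]
bracket_names(:x, zeros(2, 2))   # ["x[1,1]", "x[2,1]", "x[1,2]", "x[2,2]"]
```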

scheidan avatar Aug 08 '22 07:08 scheidan