reduce_log_sum_exp function, aka reduce_sum for mixtures
I have a situation where I'd like to use something like reduce_sum, but it's for marginalizing out a discrete variable, so the reduction is log-sum-exp rather than just sum.
The general pattern for marginalizing out a discrete variable is this:
vector[K] lp;
for (z in 1:K)
  lp[z] = log p(z) + log p(y | z, x, theta);
target += log_sum_exp(lp);
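For readers less familiar with the pattern, here is a minimal C++ sketch of the same computation with a numerically stable log-sum-exp (shift by the maximum before exponentiating). The names `log_sum_exp` and `marginal_log_lik` and the vector inputs are hypothetical stand-ins for log p(z) and log p(y | z, x, theta); this illustrates the math and is not Stan's implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable log-sum-exp: shift by the max so exp() never overflows
// and the largest term contributes exp(0) = 1.
double log_sum_exp(const std::vector<double>& x) {
  assert(!x.empty());
  const double m = *std::max_element(x.begin(), x.end());
  if (std::isinf(m)) return m;  // all terms are -inf, or some term is +inf
  double s = 0.0;
  for (double xi : x) s += std::exp(xi - m);
  return m + std::log(s);
}

// Hypothetical helper: marginal log likelihood over K discrete states,
// log sum_z p(z) p(y | z, ...) = log_sum_exp(lp), lp[z] = log p(z) + log p(y | z, ...).
double marginal_log_lik(const std::vector<double>& log_p_z,
                        const std::vector<double>& log_lik_given_z) {
  assert(log_p_z.size() == log_lik_given_z.size());
  std::vector<double> lp(log_p_z.size());
  for (std::size_t z = 0; z < lp.size(); ++z) lp[z] = log_p_z[z] + log_lik_given_z[z];
  return log_sum_exp(lp);
}
```

For example, `marginal_log_lik({std::log(0.5), std::log(0.5)}, {std::log(0.2), std::log(0.6)})` gives log(0.5 * 0.2 + 0.5 * 0.6) = log(0.4).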
It'd be nice to have a reduce_log_sum_exp function that works just like reduce_sum, only with log_sum_exp rather than sum as the final reduction before adding to target.
The log-sum-exp gradient itself can be done analytically as part of the reduction, which will save a whole bunch of storage relative to just using map_rect and doing the log-sum-exp on the result.
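One way to see why the gradient can be handled inside the reduction: if chunk k has local log-sum-exp L_k and the overall result is T, then dT/dx_i = exp(L_k - T) * exp(x_i - L_k), which collapses to the global softmax exp(x_i - T). Each chunk can therefore compute its local softmax, and only the single number L_k has to pass through the join. The sketch below uses plain C++ doubles rather than Stan Math's autodiff types, and `chunked_lse_gradient` is a hypothetical helper; it shows the math, not how an actual reduce_log_sum_exp would store its partials.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Stable log-sum-exp; inputs are assumed finite in this sketch.
double log_sum_exp(const std::vector<double>& x) {
  assert(!x.empty());
  const double m = *std::max_element(x.begin(), x.end());
  double s = 0.0;
  for (double xi : x) s += std::exp(xi - m);
  return m + std::log(s);
}

// Gradient of log_sum_exp(x) w.r.t. each x[i], assembled chunk by chunk.
// Only one number per chunk (its local lse L_k) is combined across chunks.
std::vector<double> chunked_lse_gradient(const std::vector<double>& x,
                                         std::size_t grainsize) {
  assert(grainsize > 0);
  std::vector<double> chunk_lse;
  for (std::size_t s = 0; s < x.size(); s += grainsize) {
    const std::size_t e = std::min(s + grainsize, x.size());
    chunk_lse.push_back(log_sum_exp(std::vector<double>(x.begin() + s, x.begin() + e)));
  }
  const double total = log_sum_exp(chunk_lse);  // the "join" over chunks
  std::vector<double> grad(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double L_k = chunk_lse[i / grainsize];
    // Chain rule: (d total / d L_k) * (d L_k / d x[i]).
    grad[i] = std::exp(L_k - total) * std::exp(x[i] - L_k);
  }
  return grad;
}
```

With x = {0, log 3, log 2, log 2} and a grainsize of 2, the gradient is {1/8, 3/8, 2/8, 2/8}, the same as the softmax of x computed in one piece.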
Can't you use reduce_sum_static with large grainsizes? Then you do a few large log_sum_exp calls, but add on the natural scale (hoping that, having already summed a large number of terms together, you are good with the numerics).
... but I have thought about a reduce_log_sum_exp as well myself...
Could we add an enum template parameter to reduce_sum and reduce_sum_impl that the compiler uses to select the reduction, like
enum class parallel_reducers { sum, log_sum_exp /* , ... */ };

template <typename ReduceFunction, parallel_reducers Reduction, typename Vec,
          typename = require_vector_like_t<Vec>, typename... Args>
inline auto reduce_sum(Vec&& vmapped, int grainsize, std::ostream* msgs,
                       Args&&... args) {
  // ...
  return internal::reduce_sum_impl<ReduceFunction, Reduction, void, return_type,
                                   Vec, ref_type_t<Args&&>...>()(...);
  // ...
}
Then I think we could have a separate join() method for each enum value in the reducer. So when the compiler sees reduce_log_sum_exp, it can emit C++ like
auto x = reduce_sum<TheReduceFunc, parallel_reducers::log_sum_exp>(...);
I think with that it would be easy to add a whole bunch of different types of reductions.
Don't even really need multiple join() methods, just a function in join() that handles how to aggregate for whichever reducer is selected.
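A rough standalone sketch of the enum idea, in plain C++ with doubles and std::async threads rather than Stan Math's TBB-based, autodiff-aware reduce_sum: each reduction supplies an identity element and a binary join, and the chunking code is shared. `parallel_reducers` follows the proposal above; `reducer`, `parallel_reduce`, and the `term` callback are hypothetical names for illustration only.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <future>
#include <limits>
#include <vector>

enum class parallel_reducers { sum, log_sum_exp };

// Identity element and binary join for each reduction (hypothetical design).
template <parallel_reducers R>
struct reducer;

template <>
struct reducer<parallel_reducers::sum> {
  static double identity() { return 0.0; }
  static double join(double a, double b) { return a + b; }
};

template <>
struct reducer<parallel_reducers::log_sum_exp> {
  static double identity() { return -std::numeric_limits<double>::infinity(); }
  // Stable two-argument log-sum-exp; assumes inputs are finite or -inf.
  static double join(double a, double b) {
    if (std::isinf(a) && a < 0) return b;
    if (std::isinf(b) && b < 0) return a;
    const double m = std::max(a, b);
    return m + std::log1p(std::exp(-std::abs(a - b)));
  }
};

// Apply `term` to every element, fold each chunk with R's join on its own
// thread, then join the chunk results. (Older toolchains may need -pthread.)
template <parallel_reducers R>
double parallel_reduce(const std::vector<double>& xs, std::size_t grainsize,
                       const std::function<double(double)>& term) {
  assert(grainsize > 0);
  std::vector<std::future<double>> partials;
  for (std::size_t start = 0; start < xs.size(); start += grainsize) {
    const std::size_t stop = std::min(start + grainsize, xs.size());
    partials.push_back(std::async(std::launch::async, [&xs, &term, start, stop] {
      double acc = reducer<R>::identity();
      for (std::size_t i = start; i < stop; ++i) acc = reducer<R>::join(acc, term(xs[i]));
      return acc;
    }));
  }
  double total = reducer<R>::identity();
  for (auto& p : partials) total = reducer<R>::join(total, p.get());
  return total;
}
```

With this shape, adding another reduction means adding one more `reducer` specialization; the chunking and threading code does not change. Because every join here is associative, the result does not depend on where the chunk boundaries fall (up to floating-point rounding).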
Hi, I have exactly the same use case as @bob-carpenter's example, so add a +1 to this idea.
In the meantime, as a Stan novice, I'd find it really helpful to have some template code showing how to hack around this with what's currently available, if that's easy to provide.
My guess is that it would be something like:
- define a function that calculates the sum of non-log likelihoods for a chunk
- chunk 1:K
- use reduce_sum to get the sum of non-log likelihoods across everything
- take the log
A couple of Qs:
- By "good with the numerics" I guess @wds15 means that when you're taking the sum across a large number of likelihoods, this will be dominated by the largest, and so you don't have to worry about losing numerical precision with the smaller values?
- Why were you recommending reduce_sum_static over reduce_sum? Are there circumstances where that's more important?
Thanks!
You use the static version so that the chunk sizes are well controlled. Then you run things in large chunks and use log_sum_exp within each chunk. The partial-sum function passed to reduce_sum_static should return a non-log (natural-scale) value, which reduce_sum_static then accumulates. This should be OK, since log_sum_exp is applied over largish chunks.
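A serial C++ simulation of that suggestion, with plain loops standing in for the threads: each chunk applies log_sum_exp and exponentiates the result back to the natural scale, the chunk results are simply added (as reduce_sum_static would add them), and a single log is taken at the end. `chunk_partial` and `chunked_natural_scale` are hypothetical names; this is a sketch of the arithmetic, not Stan code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Stable log-sum-exp.
double log_sum_exp(const std::vector<double>& x) {
  assert(!x.empty());
  const double m = *std::max_element(x.begin(), x.end());
  if (std::isinf(m)) return m;
  double s = 0.0;
  for (double xi : x) s += std::exp(xi - m);
  return m + std::log(s);
}

// Hypothetical partial-sum function: log_sum_exp over one chunk, then exp()
// back to the natural scale, because the outer reducer can only add.
double chunk_partial(const std::vector<double>& lp, std::size_t start, std::size_t stop) {
  return std::exp(log_sum_exp(std::vector<double>(lp.begin() + start, lp.begin() + stop)));
}

// Serial stand-in for reduce_sum_static: fixed chunk boundaries, plain
// summation of the chunk results, and one log() at the very end.
double chunked_natural_scale(const std::vector<double>& lp, std::size_t grainsize) {
  assert(grainsize > 0);
  double total = 0.0;
  for (std::size_t start = 0; start < lp.size(); start += grainsize)
    total += chunk_partial(lp, start, std::min(start + grainsize, lp.size()));
  return std::log(total);
}
```

This matches the full log_sum_exp as long as each chunk's result is representable after exp(); if every chunk's log-sum-exp is below roughly -745, each exp() underflows to 0.0 and the final log returns -inf.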
@wds15 I'm gonna take a shot at this next week; it should be pretty simple for users to supply their own reducer function or use some pre-made ones.
If it's easy to implement, then that's great… but maybe we should start with an example Stan model demonstrating the Stan-language solution I suggested? My take on current priorities is that getting varmat running in production or getting rstan off the ground is more important…
Is there a simple toy example to work on?
OK, so the function I define has log_sum_exp inside, so it returns a value on the log scale, and then reduce_sum_static just does a straight sum over that. Fine.
So I guess the most efficient way is to split into as many chunks as I have threads?
If it would be helpful, I can have a go at a stan language version of the example at the top of the thread, and you can edit/correct as appropriate.
I am in the same boat as @bob-carpenter and @wmacnair.
A reduce_log_sum_exp function would be great. But in the meantime, it would be great if you could provide working code for a simple example. @wmacnair, were you able to write custom code for your application?
If I knew how to do that easily, I would've included an example. I think @wds15 is suggesting not using reduce_sum but one of our other parallelization tools, like map_rect. Then, if you would otherwise return a bunch of answers, you can reduce them with log_sum_exp before returning them from the map_rect function call. That still leaves you having to do a final log_sum_exp on the outside.
Thanks @bob-carpenter for the response. I realize that the reduce_sum function won't work due to a loss of precision. I could not figure out how to implement parallel marginalization using map_rect either. I think @SteveBronder's idea to implement a new reducer function, like reduce_log_sum_exp, might be the way to go. Hopefully such a function gets implemented.
So the sum you form becomes unstable even if you only split it into - say - 4 parts?
Yes @wds15. On the log scale, I am getting lp[n] = -21638. If I take exp(-21638), I get zero.
Using @bob-carpenter's example, I was computing sum_{n=1}^{n0} exp(lp[n]), but when I take exp(lp[n]) I get zero.
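The underflow described here is easy to reproduce in double precision. A small C++ sketch (function names hypothetical) comparing a naive exp-sum-log with a stable log-sum-exp on log densities of this size:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Naive version: exponentiate, sum, take the log. In double precision exp(x)
// underflows to 0.0 once x is below roughly -745, so the result becomes -inf.
double naive_log_of_sum(const std::vector<double>& lp) {
  double s = 0.0;
  for (double x : lp) s += std::exp(x);
  return std::log(s);
}

// Stable version: factor out the max so the largest term becomes exp(0) = 1.
double log_sum_exp(const std::vector<double>& lp) {
  assert(!lp.empty());
  const double m = *std::max_element(lp.begin(), lp.end());
  if (std::isinf(m)) return m;
  double s = 0.0;
  for (double x : lp) s += std::exp(x - m);
  return m + std::log(s);
}
```

For lp = {-21638, -21640}, the naive version returns -inf, while log_sum_exp returns the finite value -21638 + log(1 + exp(-2)).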
@nikunj410: the returns should be on the log scale; then you can use log-sum-exp to reduce both on the inside and again on the outside. That works algebraically because log-sum-exp is just addition with log-scale arguments, so it's associative. That is,
log_sum_exp({a, b, c, d}) = log_sum_exp(log_sum_exp({a, b}), log_sum_exp({c, d}))
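A quick numerical check of that identity, using log densities on the scale reported above. `nested_log_sum_exp` is a hypothetical helper that applies log_sum_exp inside each chunk and then again over the chunk results, the way a map_rect-style two-level reduction would; nothing is ever exponentiated outside a max-shifted log_sum_exp, so nothing underflows.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Stable log-sum-exp.
double log_sum_exp(const std::vector<double>& x) {
  assert(!x.empty());
  const double m = *std::max_element(x.begin(), x.end());
  if (std::isinf(m)) return m;
  double s = 0.0;
  for (double xi : x) s += std::exp(xi - m);
  return m + std::log(s);
}

// Two-level reduction: log_sum_exp within each chunk, then log_sum_exp
// again over the per-chunk results (all on the log scale).
double nested_log_sum_exp(const std::vector<double>& lp, std::size_t grainsize) {
  assert(grainsize > 0);
  std::vector<double> chunk_results;
  for (std::size_t s = 0; s < lp.size(); s += grainsize) {
    const std::size_t e = std::min(s + grainsize, lp.size());
    chunk_results.push_back(log_sum_exp(std::vector<double>(lp.begin() + s, lp.begin() + e)));
  }
  return log_sum_exp(chunk_results);
}
```

With lp = {-21638, -21640, -21637.5, -21641}, `nested_log_sum_exp(lp, 2)` agrees with `log_sum_exp(lp)` to within rounding, for any choice of grainsize.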
OK, if the large partial sums underflow, then the only option is to shift things in advance, with some wisdom, to a scale where this does not happen.
To be clear, what I suggested is:
log_sum_exp({a, b, c, d}) = log(sum({exp(log_sum_exp({a, b})), exp(log_sum_exp({c, d}))}))
The terms forming the partials need to be big enough to avoid underflowing.
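The "shift in advance" idea can be sketched as follows: pick an offset C before running the reduction, have each chunk return exp(log_sum_exp(chunk) - C) on the natural scale, add those up, and add C back after the final log. This is mathematically exact for any C, but it only avoids underflow (or overflow) when C is close to the scale of the chunk results, which is the "wisdom in advance" part. The function name and the choice of C below are hypothetical illustrations.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Stable log-sum-exp.
double log_sum_exp(const std::vector<double>& x) {
  assert(!x.empty());
  const double m = *std::max_element(x.begin(), x.end());
  if (std::isinf(m)) return m;
  double s = 0.0;
  for (double xi : x) s += std::exp(xi - m);
  return m + std::log(s);
}

// Natural-scale combination of chunk results with a pre-chosen offset C:
// each chunk contributes exp(lse_k - C); C is added back after the log.
double shifted_chunked(const std::vector<double>& lp, std::size_t grainsize, double C) {
  assert(grainsize > 0);
  double total = 0.0;
  for (std::size_t s = 0; s < lp.size(); s += grainsize) {
    const std::size_t e = std::min(s + grainsize, lp.size());
    total += std::exp(log_sum_exp(std::vector<double>(lp.begin() + s, lp.begin() + e)) - C);
  }
  return C + std::log(total);
}
```

With lp values around -21638, choosing C = -21638 reproduces the full log_sum_exp, whereas C = 0 (no shift) underflows every chunk and returns -inf.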
On the outside, you can just apply log-sum-exp again and avoid the underflow problem, as I was suggesting in the previous message. Then it doesn't matter what scale the returns of the map are.
@bob-carpenter you can't apply log_sum_exp on the outside if you want to use reduce_sum. reduce_sum will simply sum on the natural scale... so we'd need a reduce_log_sum_exp for that to work. It's not hard to write (I think); it's just a lot of work to hammer out the tests and docs. I am happy to write the function itself, but I would call for help to copy over and adapt all the tests and docs.
Ah, now I see where our signals crossed. I was assuming they'd have to use map_rect to make this work, which is considerably clunkier because it doesn't bind arguments. I'd be in favor of adding a reduce_log_sum_exp function.
Exactly, a reduce_log_sum_exp-like function is better because we can pass any kind of argument very easily. There is one point of contention, though: unlike map_rect, which returns a vector, reduce_sum returns a scalar. This means that if we want to estimate the probabilities associated with each discrete outcome, we will have to repeat the calculations in the generated quantities block.
Ideally, we want a reduce function that makes it easy to supply arguments and that returns a vector of log probability densities, one for each discrete outcome (lq[n]). We can then use log_sum_exp on this output vector lq to update the target and to generate the probabilities associated with each discrete outcome in the generated quantities block.
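Given such a vector lq, both quantities come from the same two steps: the target increment is log_sum_exp(lq), and the per-outcome probabilities are exp(lq[n] - log_sum_exp(lq)), i.e. the softmax of lq (in Stan itself, the built-in log_sum_exp and softmax functions). A plain C++ sketch, with a hypothetical `marginalize` helper and result struct:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct Marginalized {
  double log_marginal;       // what would be added to target
  std::vector<double> prob;  // Pr[outcome n | data] for each n (sums to 1)
};

// Hypothetical helper: marginal log density plus normalized outcome
// probabilities from per-outcome log densities lq.
Marginalized marginalize(const std::vector<double>& lq) {
  assert(!lq.empty());
  const double m = *std::max_element(lq.begin(), lq.end());
  double s = 0.0;
  for (double x : lq) s += std::exp(x - m);
  const double lse = m + std::log(s);
  std::vector<double> prob(lq.size());
  for (std::size_t n = 0; n < lq.size(); ++n) prob[n] = std::exp(lq[n] - lse);
  return {lse, prob};
}
```

For lq = {log 0.1, log 0.3}, the marginal log density is log 0.4 and the outcome probabilities are {0.25, 0.75}. Since these are plain doubles, recomputing them after sampling involves no autodiff.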
You can try both ways, but we've usually found that repeating calculations in the generated quantities block is faster than adding a bunch of intermediate transformed parameters. That's because generated quantities execute with double arguments and no autodiff. The problem with what you suggest is that it makes a much larger autodiff expression graph, and autodiff is typically the bottleneck in a Stan program.
Yes, that makes sense. I would be happy with basically anything.
@bob-carpenter and @wds15, would it be reasonable to expect this feature to be released in the next release, Stan 2.31?
Not quite... work has not even started on it, and reduce_sum was a huge feature to implement.