HHälvä
I find this one of the biggest practical issues with JAX: `vmap` + `jit` are great, but in a lot of code they also force you to use `cond` alongside them, which...
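A minimal sketch of the behaviour in question, assuming a recent JAX version (`piecewise` is a hypothetical helper). Inspecting the jaxprs suggests that a scalar `cond` stays a real `cond`. Under `vmap` with a batched predicate it is rewritten into a `select_n`, so both branches get computed for every element.

```python
import jax
import jax.numpy as jnp
from jax import lax


def piecewise(x):
    # Hypothetical example: sin for positive inputs, cos otherwise.
    return lax.cond(x > 0, jnp.sin, jnp.cos, x)


xs = jnp.array([-1.0, 0.5, 2.0])

# Scalar predicate: the jaxpr contains a `cond`, so only one branch runs.
scalar_jaxpr = jax.make_jaxpr(piecewise)(1.0)

# Batched predicate: vmap replaces the `cond` with a `select_n`,
# so jnp.sin and jnp.cos are both evaluated on every element.
batched_jaxpr = jax.make_jaxpr(jax.vmap(piecewise))(xs)

print(scalar_jaxpr)
print(batched_jaxpr)

# The results are still correct; only the cost changes.
ys = jax.jit(jax.vmap(piecewise))(xs)
```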
Btw, does `switch` suffer from the same problem when used with `vmap`?
That's what I thought. It might be helpful to document that for `switch`, similar to what the `cond` docs say, e.g. "However, when transformed with `vmap` to operate over a batch..."
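The same check for `switch`, again a sketch with hypothetical names (`pick`, `BRANCHES`). `switch` is built on the same `cond` primitive, so as far as I can tell it is batched the same way. A batched index turns it into a `select_n` over the outputs of all branches. An index shared across the batch (`in_axes=None`) keeps a real `cond`.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical branches, purely for illustration.
BRANCHES = (jnp.sin, jnp.cos, jnp.tanh)


def pick(index, x):
    return lax.switch(index, BRANCHES, x)


idx = jnp.array([0, 2, 1])
xs = jnp.array([0.1, 0.2, 0.3])

# Batched index: every branch is computed, then select_n picks per element.
batched_index = jax.make_jaxpr(jax.vmap(pick))(idx, xs)

# Index shared across the batch: the switch stays a `cond` primitive,
# so only the chosen branch runs (on the whole batch).
shared_index = jax.make_jaxpr(jax.vmap(pick, in_axes=(None, 0)))(1, xs)

ys = jax.vmap(pick)(idx, xs)
```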
This would be great to have!
That sounds tricky. Shame, as it makes it difficult to implement any of the popular GP inference engines that use the eigendecomposition of tridiagonal matrices to compute log-determinants cheaply.
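For reference, an eigenvalue-only log-determinant of a tridiagonal matrix might look like this. It is a sketch assuming `jax.scipy.linalg.eigh_tridiagonal`, which, in the JAX versions I know of, only supports `eigvals_only=True` and raises `NotImplementedError` if eigenvectors are requested. `tridiag_logdet_eigvals` is a hypothetical helper. The dense `slogdet` is only there as a cross-check.

```python
import jax.numpy as jnp
from jax.scipy.linalg import eigh_tridiagonal


def tridiag_logdet_eigvals(d, e):
    """Hypothetical helper: log det(T) for a symmetric positive-definite
    tridiagonal T with main diagonal d (n,) and off-diagonal e (n - 1,).
    Only eigenvalues are computed; log det(T) = sum_i log(lambda_i)."""
    evals = eigh_tridiagonal(d, e, eigvals_only=True)
    return jnp.sum(jnp.log(evals))


# Small SPD test matrix: 4 on the diagonal, 1 on the off-diagonals.
d = jnp.full((5,), 4.0)
e = jnp.ones((4,))
T = jnp.diag(d) + jnp.diag(e, 1) + jnp.diag(e, -1)

logdet_eig = tridiag_logdet_eigvals(d, e)
logdet_dense = jnp.linalg.slogdet(T)[1]  # dense reference
print(logdet_eig, logdet_dense)
```

For this particular matrix det(T) = 780 (from the three-term determinant recurrence), so both values should be close to log(780).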
> > a way to compute log-determinants cheaply.
>
> Perhaps I'm misunderstanding. Isn't `log det(M) = sum_i log(eigh[i])`, so eigenvalues are sufficient?

`eigh` is still expensive, no?...
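If the matrix really is tridiagonal and only its exact log-determinant is needed, the roughly cubic cost of a dense `eigh` can be avoided. The technique below is swapped in purely as an illustration and is not the one discussed above: the pivot recurrence of T = L D L^T, which runs in O(n). Being a plain `lax.scan` with no `cond`, it composes with `vmap` and `jit`. `tridiag_logdet_ldl` is a hypothetical helper that assumes T is symmetric positive definite (so every pivot is positive). It is a sketch, not a numerically hardened implementation.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def tridiag_logdet_ldl(d, e):
    """Hypothetical helper: log det(T) for SPD tridiagonal T in O(n).

    Pivots of T = L D L^T: p_0 = d_0, p_i = d_i - e_{i-1}^2 / p_{i-1}.
    det(T) is the product of the pivots, so log det(T) = sum_i log(p_i).
    """
    def step(prev_pivot, de):
        d_i, e_prev = de
        pivot = d_i - e_prev**2 / prev_pivot
        return pivot, jnp.log(pivot)

    first = d[0]
    _, logs = lax.scan(step, first, (d[1:], e))
    return jnp.log(first) + jnp.sum(logs)


d = jnp.full((5,), 4.0)
e = jnp.ones((4,))
logdet_ldl = tridiag_logdet_ldl(d, e)

# Batched over several matrices without any cond/select involved.
batched = jax.vmap(tridiag_logdet_ldl)(jnp.stack([d, 2.0 * d]), jnp.stack([e, e]))
```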