Jacob Hackett
I figured I'd post this, since I was running into the same issue. For reference, this is what my example.cc (the same as the tutorial's) and BUILD files look like:...
The goal is to make the solver compatible with reverse-mode autodiff when `m.opt.iterations > 1`. In this specific context, setting `m.opt.tolerance = 0` to achieve this behavior feels a bit...
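A minimal sketch of why the tolerance matters for autodiff. It assumes a hypothetical toy solver (Newton iterations for √a, with made-up helpers `newton_step` and `solve_while`), not the MJX solver. A tolerance-based early exit gives a data-dependent trip count, which needs `jax.lax.while_loop`. That loop supports forward mode (`jax.jvp`) but raises a `ValueError` under reverse mode (`jax.grad`):

```python
import jax
import jax.numpy as jnp


def newton_step(x, a):
    # Hypothetical toy update (not MJX): one Newton iteration for x**2 = a.
    return 0.5 * (x + a / x)


def solve_while(a, iterations=10, tolerance=1e-6):
    # Run at most `iterations` steps, stopping early once the update is
    # smaller than `tolerance`, so the trip count depends on the data.
    a = jnp.asarray(a, dtype=jnp.float32)

    def cond(state):
        i, _, dx = state
        return (i < iterations) & (dx > tolerance)

    def body(state):
        i, x, _ = state
        x_new = newton_step(x, a)
        return i + 1, x_new, jnp.abs(x_new - x)

    init = (jnp.int32(0), a, jnp.asarray(jnp.inf, dtype=jnp.float32))
    _, x, _ = jax.lax.while_loop(cond, body, init)
    return x


print(solve_while(2.0))                          # approx. sqrt(2)
print(jax.jvp(solve_while, (2.0,), (1.0,))[1])   # forward mode works
try:
    jax.grad(solve_while)(2.0)                   # reverse mode does not
except ValueError as err:
    print("jax.grad failed:", err)
```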
You bring up a good point; I don't have a good intuition for the performance differences between `jax.lax.while_loop` and `jax.lax.scan`. I would have to benchmark them. I would still find it...
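For comparison, here is a rough sketch of the `jax.lax.scan` alternative using the same hypothetical toy solver (`newton_step`, `solve_scan` are made-up names). The trip count is fixed, so `jax.grad` works. The `jnp.where` masking that freezes the iterate once converged is just my assumption of how a tolerance could be emulated inside a fixed-length loop, not how MJX does it. With `tolerance=0.0`, every iteration updates the state, much like the `m.opt.tolerance = 0` workaround:

```python
import jax
import jax.numpy as jnp


def newton_step(x, a):
    # Same hypothetical toy update: one Newton iteration for x**2 = a.
    return 0.5 * (x + a / x)


def solve_scan(a, iterations=10, tolerance=0.0):
    # Fixed trip count, so reverse mode is supported. With tolerance > 0 the
    # iterate is frozen once converged, but all iterations still execute.
    a = jnp.asarray(a, dtype=jnp.float32)

    def body(carry, _):
        x, done = carry
        x_new = newton_step(x, a)
        converged = jnp.abs(x_new - x) < tolerance
        # Keep the previous iterate once done; otherwise take the new one.
        x_next = jnp.where(done, x, x_new)
        return (x_next, done | converged), None

    init = (a, jnp.asarray(False))
    (x, _), _ = jax.lax.scan(body, init, xs=None, length=iterations)
    return x


print(solve_scan(2.0))                                         # approx. sqrt(2)
print(jax.grad(solve_scan)(2.0))                               # approx. 1 / (2 * sqrt(2))
print(jax.grad(lambda a: solve_scan(a, tolerance=1e-6))(2.0))  # same, with masking
```

On performance (not benchmarked here): `scan` always pays for all `iterations` and, under `jax.grad`, keeps per-iteration residuals for the backward pass, whereas `while_loop` can exit early. Per the JAX docs, `jax.lax.fori_loop` with static (Python int) bounds is also lowered to a `scan`, so it supports reverse mode as well.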
@erikfrey That makes sense, thank you for the clarification! I think @varshneydevansh has a PR with that implementation. Looking forward to the fix, thanks, guys!