Missing content in the docs for "jax.tree_util.tree_reduce"
Hi jax team. Thank you for your great framework.
I noticed that there is no explanation of how to use jax.tree_util.tree_reduce() or of its argument types. Could you please explain its usage in detail? I ask because it appears in the training code of JaxNeRF.
Also, will the docs be completed in the near future? I'm not sure whether a big update is coming that will remove these functions.
Thanks.
Thanks for the report - we definitely should have a docstring for that function. Fortunately, it's pretty straightforward: it basically calls functools.reduce() over tree_leaves(tree): https://github.com/google/jax/blob/94aade035a5fdeb2d3ed6f1744fcf1fa16240b8c/jax/_src/tree_util.py#L248-L254
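For illustration, here is a minimal sketch of that equivalence. It assumes a recent JAX where `tree_reduce(function, tree, initializer)` accepts an optional initializer as its third argument; the `params` pytree is a made-up example, not code taken from JaxNeRF.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical pytree: nested containers (dict, list) whose leaves are arrays.
params = {
    "dense": {"w": jnp.ones((2, 3)), "b": jnp.full((3,), 2.0)},
    "out": [jnp.array([3.0])],
}

def add_sum_of_squares(acc, leaf):
    # Called once per leaf: acc is the running total, leaf is one array.
    return acc + jnp.sum(leaf ** 2)

# Fold over every leaf, starting from the initializer 0.0.
total_sq = jax.tree_util.tree_reduce(add_sum_of_squares, params, 0.0)

# The same result computed by hand: functools.reduce over the flattened leaves.
via_functools = functools.reduce(
    add_sum_of_squares, jax.tree_util.tree_leaves(params), 0.0
)

# Without an initializer, the function is applied pairwise to the leaves
# themselves, so here it compares the plain Python ints 1, 5 and 3.
largest = jax.tree_util.tree_reduce(max, {"a": 1, "b": [5, 3]})
```

A reduction like `total_sq` (a sum of squared parameters) is a common way to build an L2 weight penalty. When the tree may have no leaves, pass an initializer. Without one, the underlying `functools.reduce` call fails on an empty sequence.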
Would you be interested in contributing this docstring? If not, someone on the team can take care of it.
Right, I didn't think to check the code. Thanks a lot. I'll open a pull request then. :-)
@yitongx do you want to send a pull request with your change to the main branch? :)
Hi @yitongx
It looks like a docstring for jax.tree_util.tree_reduce was added in PR #19588, and it now appears in the documentation here: https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_reduce.html.
Could you please verify and confirm?
Thank you.