Python bindings for scatter operations
Proposed changes
Expose scatter operations to the python API.
Any help on completing this is welcome.
It's up for discussion whether we should expose only a scatter operation with different modes, or also the individual scatter_add, scatter_prod, etc.
Checklist
Put an x in the boxes that apply.
- [x] I have read the CONTRIBUTING document
- [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have updated the necessary documentation (if needed)
Ideally, the bindings need examples in the docs as well as tests. I will add some tests.
Hi, thanks for starting this, that's great! However, if we are to expose scatter, it should mirror the C++ API via scatter_add, scatter_prod, etc., not a general scatter that accepts a mode parameter.
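To make the per-op style concrete, here is a toy, pure-Python sketch of what scatter, scatter_add and scatter_prod compute on 1-D lists. It only illustrates the semantics under simple assumptions (1-D data, integer indices, one update per index) and is not the MLX binding or its C++ implementation. In particular, the "last update wins" rule for plain scatter with duplicate indices is a choice made in this toy, not a guarantee of any library.

```python
from typing import List, Sequence

# Toy 1-D scatter functions over plain lists (illustration only, not MLX).


def scatter(a: Sequence[float], indices: Sequence[int], updates: Sequence[float]) -> List[float]:
    """Overwrite a[i] with each update; in this toy, a repeated index keeps the last update."""
    out = list(a)
    for i, u in zip(indices, updates):
        out[i] = u
    return out


def scatter_add(a: Sequence[float], indices: Sequence[int], updates: Sequence[float]) -> List[float]:
    """Accumulate: every update aimed at index i is added to a[i]."""
    out = list(a)
    for i, u in zip(indices, updates):
        out[i] += u
    return out


def scatter_prod(a: Sequence[float], indices: Sequence[int], updates: Sequence[float]) -> List[float]:
    """Accumulate multiplicatively: every update aimed at index i multiplies a[i]."""
    out = list(a)
    for i, u in zip(indices, updates):
        out[i] *= u
    return out


print(scatter([1, 1, 1], [0, 0, 2], [2, 3, 4]))       # [3, 1, 4]
print(scatter_add([1, 1, 1], [0, 0, 2], [2, 3, 4]))   # [6, 1, 5]
print(scatter_prod([1, 1, 1], [0, 0, 2], [2, 3, 4]))  # [6, 1, 4]
```

The only difference between the three is how an update is combined with the current value, which is why duplicate indices are where their results diverge.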
Additionally, indexing is generally a friendlier interface, so it might be worth mapping those to x[indices] += y rather than scatter_add, for instance. However, as I mentioned in #394, that would require several shenanigans, so it might not be worth it.
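Independently of what #394 discusses, one general reason the x[indices] += y mapping is tricky is how Python evaluates augmented assignment on a subscript: it calls __getitem__ once, applies += to that temporary, then calls __setitem__ with the result. The binding therefore sees a gather followed by a plain assignment, not a single "add at these indices" request, so repeated indices do not accumulate (NumPy behaves this way for fancy-indexed +=, which is why it offers np.add.at). The sketch below uses hypothetical toy classes, Recorder and Gathered, to show the call sequence and the resulting values.

```python
class Gathered(list):
    """Toy result of a gather; supports elementwise += to keep the demo readable."""

    def __iadd__(self, other):
        for k, v in enumerate(other):
            self[k] += v
        return self


class Recorder:
    """Toy 1-D container that records which special methods Python calls."""

    def __init__(self, data):
        self.data = list(data)
        self.calls = []

    def __getitem__(self, indices):
        self.calls.append("__getitem__")
        # Gather the selected values into a temporary.
        return Gathered(self.data[i] for i in indices)

    def __setitem__(self, indices, values):
        self.calls.append("__setitem__")
        # Plain assignment: a repeated index is overwritten by the later value.
        for i, v in zip(indices, values):
            self.data[i] = v


x = Recorder([0, 0, 0])
x[[0, 0, 2]] += [1, 2, 3]
print(x.calls)  # ['__getitem__', '__setitem__']
print(x.data)   # [2, 0, 3], whereas a true scatter-add would give [3, 0, 3]
```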
The most probable future path, imho, is to expose scatter_{op}, so I think the work you are doing is not in vain, but I am writing so that we can coordinate.
Hi @angeloskath, great! We're iterating on exposing each scatter_{op}. Do you think also having the generic scatter with the mode (or reduce) parameter (as @gboduljak implemented in https://github.com/francescofarina/mlx/pull/1) would be handy? That's often available in other libraries.
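As an illustration of the "generic scatter with a mode" question, here is a toy, pure-Python sketch showing that a mode-based scatter can be a thin dispatcher over per-op reducers, with a per-op function reduced to a one-line wrapper. The mode names ("add", "prod", "max", "min") and the mode=None overwrite behaviour are assumptions for this sketch, not MLX's API; PyTorch's Tensor.scatter_reduce, with its reduce argument, is one real example of the reduce-parameter style in another library.

```python
import operator
from typing import Callable, Dict, List, Optional, Sequence

# Hypothetical mode names mapped to binary reducers (not an MLX API).
_REDUCERS: Dict[str, Callable[[float, float], float]] = {
    "add": operator.add,
    "prod": operator.mul,
    "max": max,
    "min": min,
}


def scatter(
    a: Sequence[float],
    indices: Sequence[int],
    updates: Sequence[float],
    mode: Optional[str] = None,
) -> List[float]:
    """Toy generic scatter: mode=None overwrites, otherwise combines via the chosen reducer."""
    if mode is not None and mode not in _REDUCERS:
        raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(_REDUCERS)}")
    out = list(a)
    for i, u in zip(indices, updates):
        out[i] = u if mode is None else _REDUCERS[mode](out[i], u)
    return out


def scatter_add(a: Sequence[float], indices: Sequence[int], updates: Sequence[float]) -> List[float]:
    """A per-op function expressed as a one-line wrapper over the generic one."""
    return scatter(a, indices, updates, mode="add")


print(scatter([1, 5, 1], [0, 0, 1], [2, 3, 4]))              # [3, 4, 1]
print(scatter([1, 5, 1], [0, 0, 1], [2, 3, 4], mode="max"))  # [3, 5, 1]
print(scatter_add([1, 5, 1], [0, 0, 1], [2, 3, 4]))          # [6, 9, 1]
```

Since either layer can be built on top of the other, offering both is mostly a question of API surface and documentation rather than implementation effort.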
As per @angeloskath's suggestion, I will work on exposing scatter_{op} instead of generic scatter.
This is now done. @francescofarina, please see https://github.com/francescofarina/mlx/pull/2/.
I'm slightly confused: there are two ongoing PRs for scatter ops in MLX 🤔 (#394 has bindings as well). They seem to be going for different APIs, but are otherwise mostly the same?
Looks like this one can be safely closed and we can move forward with https://github.com/ml-explore/mlx/pull/394, which has a different API. @angeloskath @gboduljak?
Yeah, sorry for stepping on your toes guys.
It kind of evolved from writing the gradients for scatter and scatter_add and then realizing that scatter_{op} would be hard for people to understand (or rather, hard for me to document properly), so I implemented the array.at interface, which is much simpler to understand.
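For readers new to the array.at style, here is a toy, pure-Python sketch of how an .at[...] property can route an indexed update to an explicit method, so that duplicate indices accumulate instead of being overwritten. As far as we understand, usage in MLX looks like a = a.at[idx].add(1), and JAX uses a similar x.at[idx].add(y) shape. The classes below (Array, _AtProperty, _AtIndexer) and the add/multiply methods are hypothetical stand-ins written for illustration, not MLX's implementation.

```python
from typing import List, Sequence


class _AtIndexer:
    """Holds the target array and the indices chosen via .at[...] (toy helper)."""

    def __init__(self, array: "Array", indices: Sequence[int]):
        self._array = array
        self._indices = list(indices)

    def _apply(self, updates, combine) -> "Array":
        # Broadcast a scalar update to every selected index.
        if not isinstance(updates, (list, tuple)):
            updates = [updates] * len(self._indices)
        out: List[float] = list(self._array.data)  # functional: the original stays untouched
        for i, u in zip(self._indices, updates):
            out[i] = combine(out[i], u)  # duplicates accumulate, one update at a time
        return Array(out)

    def add(self, updates) -> "Array":
        return self._apply(updates, lambda old, new: old + new)

    def multiply(self, updates) -> "Array":
        return self._apply(updates, lambda old, new: old * new)


class _AtProperty:
    """Object returned by Array.at; indexing it selects the update positions."""

    def __init__(self, array: "Array"):
        self._array = array

    def __getitem__(self, indices: Sequence[int]) -> _AtIndexer:
        return _AtIndexer(self._array, indices)


class Array:
    """Minimal hypothetical 1-D array stand-in."""

    def __init__(self, data):
        self.data = list(data)

    @property
    def at(self) -> _AtProperty:
        return _AtProperty(self)


a = Array([0, 0])
b = a.at[[0, 1, 0, 1]].add(1)
print(b.data)  # [2, 2]
print(a.data)  # [0, 0]
```

The point of the design is that the operation name (add, multiply, ...) is explicit at the call site, so the binding never has to infer an accumulation from a separate __getitem__/__setitem__ pair.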
Closing this in favor of #394. @francescofarina @gboduljak your feedback on #394 is appreciated.