Removed the `with mesh` context manager from the Flax `jax.jit` guide, since it doesn't do anything with `jax.jit`.
Codecov Report
Merging #3303 (e331c38) into main (7f2426c) will not change coverage. The report is 1 commit behind head on main. The diff coverage is 100.00%.
Coverage Diff

| | main | #3303 | +/- |
|---|---|---|---|
| Coverage | 82.67% | 82.67% | |
| Files | 55 | 55 | |
| Lines | 6342 | 6342 | |
| Hits | 5243 | 5243 | |
| Misses | 1099 | 1099 | |

| Files Changed | Coverage Δ | |
|---|---|---|
| flax/serialization.py | 94.56% <100.00%> (ø) |
Sorry for the late reply. I'm actually debating whether to use `with mesh` or the `mesh_sharding()` util in the guide. It's quite common for larger model libraries to know nothing about the mesh, so it may actually be cleaner to annotate with `PartitionSpec` and add the mesh only at the top level with `with mesh`.
Originally I wrote the guide with shardings passed everywhere because there was a plan to deprecate `PartitionSpec` support in the `jit`/`pjit` APIs, but that no longer seems to be the case. Please let me know if you have any thoughts.
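A minimal sketch of the "annotate with `PartitionSpec`, attach the mesh at the top level" split described above, assuming JAX is installed. The function `param_partition_specs`, the parameter tree, and the mesh axis names `'data'`/`'model'` are hypothetical; the point is only that the library code never touches a `Mesh`. Here the mesh is attached by building `NamedSharding`s rather than by a `with mesh:` block.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Library side (hypothetical): describe how each parameter is partitioned
# using mesh axis *names* only. No Mesh object is needed here.
def param_partition_specs():
  return {'dense': {'kernel': P(None, 'model'), 'bias': P('model')}}

# Top level: the training script owns the mesh. On a single CPU this is a
# (1, 1) mesh, which is enough to run the sketch.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=('data', 'model'))

# Attach the mesh once. Some JAX versions implement PartitionSpec as a tuple
# subclass, so is_leaf keeps tree_map from recursing into each spec.
param_shardings = jax.tree_util.tree_map(
    lambda spec: NamedSharding(mesh, spec),
    param_partition_specs(),
    is_leaf=lambda x: isinstance(x, P),
)

# Place (hypothetical) parameters according to the resolved shardings.
params = {'dense': {'kernel': jnp.zeros((4, 2)), 'bias': jnp.zeros((2,))}}
params = jax.device_put(params, param_shardings)
```

The same `param_partition_specs()` output could be reused with a different mesh without changing any library code.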
Yeah, I think removing the `mesh_sharding` function and passing the mesh directly via `NamedSharding(mesh, PartitionSpec(...))` would be clearer.
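To make the suggestion concrete, here is a sketch comparing a helper-based call with the inline form, assuming JAX is installed. The `mesh_sharding` below is a hypothetical stand-in written for illustration, not the guide's actual definition, and the axis names are assumed.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=('data', 'model'))

# Hypothetical helper in the spirit of mesh_sharding(): it silently closes
# over a module-level mesh.
def mesh_sharding(pspec: P) -> NamedSharding:
  return NamedSharding(mesh, pspec)

# Suggested alternative: spell out NamedSharding at the call site, so the
# dependency on `mesh` is visible where the sharding is created.
via_helper = mesh_sharding(P('data', None))
inline = NamedSharding(mesh, P('data', None))
```

Both produce the same sharding; the inline version just avoids a helper that hides which mesh is used.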
> add mesh only on the top-level with `with mesh`.
I spoke with @yashk2810, and he said that `with mesh` doesn't do anything with `jax.jit`.
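A sketch of `jax.jit` driven purely by `NamedSharding`s, assuming JAX is installed. The function, shapes, and axis names are illustrative assumptions. Because each sharding already carries its mesh, the call runs without any `with mesh:` block.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=('data', 'model'))

def matmul(x, w):
  return x @ w

# Input and output shardings are NamedShardings, so jax.jit gets the mesh
# from them directly; no `with mesh:` context is involved anywhere.
sharded_matmul = jax.jit(
    matmul,
    in_shardings=(NamedSharding(mesh, P('data', None)),
                  NamedSharding(mesh, P(None, 'model'))),
    out_shardings=NamedSharding(mesh, P('data', 'model')),
)

# Each output entry is a sum of four products of ones.
y = sharded_matmul(jnp.ones((8, 4)), jnp.ones((4, 2)))
```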
Is `with mesh:` required with `nn.with_logical_constraint()`? (Under the hood this calls `lax.with_sharding_constraint()`.)
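At the `jax.lax` level, one way to sidestep the question is to pass a `NamedSharding` to `with_sharding_constraint`, since it brings its own mesh. A bare `PartitionSpec` has no mesh and must get one from some surrounding context. This sketch assumes JAX is installed and uses illustrative shapes and axis names. It does not settle how `nn.with_logical_constraint` behaves in a given Flax version, which additionally maps logical axis names to mesh axes first.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=('data', 'model'))

@jax.jit
def layer(x):
  h = jnp.tanh(x)
  # The constraint is given a NamedSharding, so it knows its mesh without a
  # `with mesh:` block around the call.
  h = jax.lax.with_sharding_constraint(h, NamedSharding(mesh, P('data', None)))
  return h * 2.0

# tanh(0) * 2 is 0, so the output is all zeros.
out = layer(jnp.zeros((8, 4)))
```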