Removed `with mesh` context manager from Flax jit guide, since it doesn't do anything with jax.jit

Open chiamp opened this issue 2 years ago • 5 comments

Removed `with mesh` context manager from Flax jit guide, since it doesn't do anything with `jax.jit`.

chiamp avatar Aug 31 '23 21:08 chiamp

Codecov Report

Merging #3303 (e331c38) into main (7f2426c) will not change coverage. Report is 1 commit behind head on main. The diff coverage is 100.00%.

@@           Coverage Diff           @@
##             main    #3303   +/-   ##
=======================================
  Coverage   82.67%   82.67%           
=======================================
  Files          55       55           
  Lines        6342     6342           
=======================================
  Hits         5243     5243           
  Misses       1099     1099           

| Files Changed | Coverage Δ |
|---|---|
| flax/serialization.py | 94.56% <100.00%> (ø) |

codecov-commenter avatar Aug 31 '23 21:08 codecov-commenter

Sorry for the late reply - I am actually debating whether to use `with mesh` or the `mesh_sharding()` util in the guide. It's quite common for a larger model library to know nothing about the mesh, so it may actually be cleaner to annotate with `PartitionSpec` and add mesh only on the top-level with `with mesh`.

Originally I wrote the guide with the sharding passed everywhere because there was a plan to deprecate `jit`/`pjit` API support for `PartitionSpec`, but this no longer seems to be the case. Please let me know if you have any thoughts.
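For illustration only, here is a rough sketch of the "library knows nothing about the mesh" layout described above. The helper `param_specs`, the axis name `"model"`, and the parameter shapes are hypothetical and not taken from the guide; the mesh is bound at the top level by wrapping each `PartitionSpec` in a `NamedSharding` rather than through a context manager.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


# "Library" side (hypothetical helper): only axis names are mentioned here,
# the function never sees a Mesh object.
def param_specs():
    return {
        "kernel": PartitionSpec(None, "model"),
        "bias": PartitionSpec("model"),
    }


# Top level: build the mesh once (a single CPU device also works) ...
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

# ... and bind it to every spec. `is_leaf` keeps tree_map from descending
# into a PartitionSpec on JAX versions where it behaves like a tuple.
shardings = jax.tree_util.tree_map(
    lambda spec: NamedSharding(mesh, spec),
    param_specs(),
    is_leaf=lambda x: isinstance(x, PartitionSpec),
)

# Sharded dimensions are multiples of the device count so they divide evenly.
n = mesh.size
params = {"kernel": jnp.ones((4, 2 * n)), "bias": jnp.zeros((2 * n,))}
params = jax.device_put(params, shardings)
```

The trade-off is the one raised in the comment: the model code stays mesh-agnostic, and only the entry point decides which devices the axis names map to.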

IvyZX avatar Sep 12 '23 23:09 IvyZX

Yeah, I think removing the `mesh_sharding` function and just passing the mesh directly, as in `NamedSharding(mesh, PartitionSpec(...))`, would be clearer.
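A minimal sketch of what that could look like, assuming a one-dimensional mesh with a hypothetical axis name `"data"` and a toy function `scale` (neither comes from the guide). The mesh is carried by the `NamedSharding` itself, so the jitted call is not wrapped in `with mesh:`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1D mesh over whatever devices are available (a single CPU device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# The mesh is part of the sharding object, so no context manager is involved.
data_sharding = NamedSharding(mesh, PartitionSpec("data"))

# Leading dimension is a multiple of the device count so it divides evenly.
x = jnp.ones((8 * devices.size, 4))
x = jax.device_put(x, data_sharding)


@jax.jit
def scale(x):
    # jit sees the sharding of its committed input array.
    return x * 2.0


y = scale(x)
```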

> add mesh only on the top-level with `with mesh`.

I spoke with @yashk2810 and he said that `with mesh` doesn't do anything with `jax.jit`.

chiamp avatar Sep 26 '23 00:09 chiamp

Is the `with mesh:` required with `nn.with_logical_constraint()`?

(Under the hood this is calling `lax.with_sharding_constraint()`.)
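As a point of reference, here is a small sketch of calling `jax.lax.with_sharding_constraint` inside `jax.jit` with a concrete `NamedSharding`. The axis name `"data"` and the function `center` are illustrative assumptions. Because the sharding already carries its mesh, this sketch sidesteps the question rather than answering it: it does not show what happens when only a bare `PartitionSpec` or logical axis names are passed, which is the case the question is about.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1D mesh over the available devices (a single CPU device also works).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))


@jax.jit
def center(x):
    x = x - x.mean(axis=0)
    # Constrain the intermediate's layout with a sharding that carries its mesh.
    return jax.lax.with_sharding_constraint(x, sharding)


# Leading dimension is a multiple of the device count so it divides evenly.
x = jnp.ones((8 * mesh.size, 2))
out = center(x)
```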

andsteing avatar Sep 26 '23 06:09 andsteing