[ROCm] ROCm7 Plugin Updates
In preparation for the imminent release of ROCm 7.x, this PR contains a set of changes, some of which are required for jax/jaxlib to load the new plugin namespace.
This seemed to be the simplest fix for loading multiple plugin names/versions, but I think it might be worth exploring something more dynamic, such as a callback or a single entry point that jaxlib could call to discover the plugin namespace from the jax_rocmN_plugin wheel itself. We could perhaps even use Python setuptools entry-point registrations for that discovery, as we already do with PJRT.
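For discussion, here is a minimal sketch contrasting the two approaches, assuming a plain-Python plugin loader. `import_first_available` tries a fixed list of candidate module names in order (the "simplest fix" style). `discover_plugins` instead asks installed wheels to advertise themselves through entry points. The module names `jax_rocm7_plugin` / `jax_rocm60_plugin`, the helper names, and the `jax_plugins` group name are assumptions for illustration only, not jaxlib's actual code.

```python
import importlib
import importlib.metadata
from types import ModuleType
from typing import Optional, Sequence

# Hypothetical candidate module names, newest first (assumption, not from the PR).
ROCM_PLUGIN_CANDIDATES = ("jax_rocm7_plugin", "jax_rocm60_plugin")


def import_first_available(candidates: Sequence[str]) -> Optional[ModuleType]:
    """Return the first module in `candidates` that imports cleanly, else None."""
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError:
            # Not installed (or failed to import); try the next candidate.
            continue
    return None


def discover_plugins(group: str = "jax_plugins") -> dict:
    """Load every entry point registered under `group`, keyed by entry-point name.

    The default group name is an assumption about how JAX discovers plugins.
    """
    eps = importlib.metadata.entry_points()
    # Python 3.10+ exposes .select(); Python 3.8/3.9 return a dict keyed by group.
    selected = eps.select(group=group) if hasattr(eps, "select") else eps.get(group, [])
    found = {}
    for ep in selected:
        try:
            found[ep.name] = ep.load()
        except Exception:
            # Skip a broken plugin instead of aborting discovery for all of them.
            continue
    return found
```

With the fixed-list approach, jaxlib must be updated every time a new `jax_rocmN_plugin` name appears, e.g. `plugin = import_first_available(ROCM_PLUGIN_CANDIDATES)`. With entry points, each wheel registers its own module in its packaging metadata, so jaxlib only needs to know the group name.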
I am pushing this up now so we can get these changes into the next JAX release (presumably 0.6.2) and have ROCm 7 support ready when it is published.
The Bazel build changes fix an issue with missing symbols at load time. There is also a separate problem: the rocm_config:hip target requires amd_comgr, which isn't correct either, but I have a PR in XLA to fix that. See https://github.com/openxla/xla/pull/27498
@hawkinsp Please let me know whether this is an acceptable way to handle the namespace loading for now, and whether we want to work out something more formal.
@mrodden It seems like some CPU tests are failing; can you please check?
The error doesn't look related to my PR. I know Google Cloud was having problems last week, so maybe it's related to that?
Could the failing CI job be caused by the external Triton dependency?
@hawkinsp or @skye This looks ready to merge and just needs internal Copybara approval.
Can you please squash your commits? I think that might be causing the problem.
I manually merged and squashed all the changes, along with the conflict resolution, into a single rollup commit.
@hawkinsp This should be ready again.