
Support `torch.Generator` on PyTorch/XLA

Open miladm opened this issue 11 months ago • 16 comments

🚀 Feature

Support torch.Generator on PyTorch/XLA backend.

miladm avatar May 13 '25 23:05 miladm

@yaochengji is there any context you'd like to include for this effort?

miladm avatar May 13 '25 23:05 miladm

@miladm thanks for asking, the context is that per-request sampling needs torch.Generator. https://github.com/vllm-project/vllm/blob/964472b9667508b1d4a7ed92068ff81740ae0036/vllm/v1/worker/gpu_model_runner.py#L364
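For illustration, the per-request pattern can be sketched with Python's stdlib `random.Random` standing in for `torch.Generator` (the `sample_token` and `run_batch` names below are toy stand-ins, not vLLM's actual code): each request owns its own seeded generator, so its samples are reproducible no matter how requests are interleaved in a batch.

```python
import random

def sample_token(weights, rng):
    """Toy categorical sampling: pick an index weighted by `weights`."""
    r = rng.random() * sum(weights)
    acc = 0.0
    for i, w in enumerate(weights):
        acc += w
        if r <= acc:
            return i
    return len(weights) - 1

def run_batch(seeds, weights):
    # One generator per request: each request's sampling is reproducible
    # in isolation, even when requests are batched together.
    return [sample_token(weights, random.Random(seed)) for seed in seeds]

weights = [0.1, 0.7, 0.2]
# Re-running the same batch with the same per-request seeds gives the
# same tokens back.
assert run_batch([42, 7], weights) == run_batch([42, 7], weights)
```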

yaochengji avatar May 14 '25 17:05 yaochengji

I am porting some code to TPU and am interested in this feature too.

wzhang313 avatar May 19 '25 18:05 wzhang313

If nobody else is looking at this issue, I can work on it.

iwknow avatar Jun 19 '25 22:06 iwknow

After an investigation, this is my plan for adding a native torch.Generator to PyTorch/XLA, basically following the example of CUDAGeneratorImpl.cpp:

  1. Create a custom XLA generator implementation: create a new C++ class that inherits from c10::GeneratorImpl and is responsible for managing the state of the XLA RNG. It must implement the virtual methods, including:
  • clone(): clone the generator state.
  • current_seed(): get the current seed.
  • manual_seed(uint64_t seed): set the seed manually.
  • get_state() and set_state(const c10::TensorImpl& new_state): get and set the generator's state.
  2. Register the XLA generator: register the custom generator with the PyTorch dispatcher, which associates it with the XLA device type.
  3. Integrate the generator into XLA's random operations: modify random functions such as randn, bernoulli, and normal so that they use the generator to manage their random state, and have them take an optional generator parameter for a custom generator.
  4. Expose to Python: add the necessary Python bindings so users can create and use torch.Generator(device='xla').
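As a rough sketch of the interface in step 1, here is a Python mock of the planned C++ class (the real implementation would subclass c10::GeneratorImpl in C++; the state here is simplified to just a seed, matching the current XLA RNG model):

```python
import copy

class XLAGeneratorImpl:
    """Python sketch of the planned C++ XLAGeneratorImpl (hypothetical).

    Mirrors the c10::GeneratorImpl virtual methods listed in the plan.
    """

    def __init__(self, seed=0):
        self._seed = seed

    def clone(self):
        # clone(): return an independent copy carrying the same RNG state.
        return copy.deepcopy(self)

    def current_seed(self):
        return self._seed

    def manual_seed(self, seed):
        self._seed = seed
        return self

    def get_state(self):
        # In C++ this would serialize the state into a tensor.
        return {"seed": self._seed}

    def set_state(self, state):
        self._seed = state["seed"]

g = XLAGeneratorImpl()
g.manual_seed(2718)
h = g.clone()
assert h.current_seed() == 2718
h.manual_seed(1)
assert g.current_seed() == 2718  # the clone is independent
```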

@tengyifei can you nominate a person to review the plan above to see if it is feasible and potentially be the code reviewer for this issue? thanks!

iwknow avatar Jun 23 '25 05:06 iwknow

@zhanyong-wan , @qihqi

tengyifei avatar Jun 23 '25 05:06 tengyifei

The plan sounds reasonable to me, but I'm not familiar with pytorch enough to judge it. I'll defer to @qihqi on this.

zhanyong-wan avatar Jun 23 '25 16:06 zhanyong-wan

@qihqi kindly ping

iwknow avatar Jun 27 '25 04:06 iwknow

@qihqi kindly ping

iwknow avatar Jul 01 '25 22:07 iwknow

@qihqi kindly ping

iwknow avatar Jul 07 '25 18:07 iwknow

This plan: https://github.com/pytorch/xla/issues/9159#issuecomment-2994942717 looks good to me. Feel free to start anytime. What support would you like from us (other than code reviews when ready)?

qihqi avatar Jul 08 '25 04:07 qihqi

Thanks @qihqi for approving the plan. I can't think of anything specific I need from you at the moment; I'll ask for support if I need it. Thank you very much!

iwknow avatar Jul 08 '25 16:07 iwknow

@qihqi I need some help with building and testing the C++ files added for the XLA generator. I added xla_generator.h and xla_generator.cpp under torch_xla/csrc and added both files to the target //torch_xla/csrc:tensor. However, I got the following error message when I ran `bazel build //torch_xla/csrc:tensor`:

ERROR: An error occurred during the fetch of repository 'python3_10_x86_64-unknown-linux-gnu':
   Traceback (most recent call last):
        File "/root/.cache/bazel/_bazel_root/0e9c69cead53c0f636ee8049b46e49c1/external/rules_python/python/repositories.bzl", line 196, column 25, in _python_repository_impl
                fail("The current user is root, please run as non-root when using the hermetic Python interpreter. See https://github.com/bazelbuild/rules_python/pull/713.")
Error in fail: The current user is root, please run as non-root when using the hermetic Python interpreter. See https://github.com/bazelbuild/rules_python/pull/713.
ERROR: /workspaces/xla_workspace/pytorch/WORKSPACE:227:27: fetching python_repository rule //external:python3_10_x86_64-unknown-linux-gnu: Traceback (most recent call last):
        File "/root/.cache/bazel/_bazel_root/0e9c69cead53c0f636ee8049b46e49c1/external/rules_python/python/repositories.bzl", line 196, column 25, in _python_repository_impl
                fail("The current user is root, please run as non-root when using the hermetic Python interpreter. See https://github.com/bazelbuild/rules_python/pull/713.")
Error in fail: The current user is root, please run as non-root when using the hermetic Python interpreter. See https://github.com/bazelbuild/rules_python/pull/713.
ERROR: Error computing the main repository mapping: Encountered error while reading extension file 'requirements.bzl': no such package '@pip_deps//': no such package '@python3_10_x86_64-unknown-linux-gnu//': The current user is root, please run as non-root when using the hermetic Python interpreter. See https://github.com/bazelbuild/rules_python/pull/713.

Did I miss anything while setting up my dev environment? I followed the instructions at https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md to set up a dev container. The user in my container is "root@frankzliu".

iwknow avatar Jul 17 '25 05:07 iwknow

Never mind; it turns out I ran the build in the wrong directory...

iwknow avatar Jul 17 '25 21:07 iwknow

Just an update that I am still on it; the first PR will be out soon.

iwknow avatar Jul 30 '25 05:07 iwknow

I found a fundamental problem with implementing the XLA generator.

The generator base class https://github.com/pytorch/pytorch/blob/main/c10/core/GeneratorImpl.h assumes the generator's state consists of a seed and an offset, with getter and setter methods for both values. However, the current implementations of all the XLA RNG functions set the offset to zero and never advance it (a constant Philox offset of 0): https://github.com/pytorch/xla/blob/241cd47f1499df53a82d6a94c2dcf39dd261eaa4/torch_xla/csrc/random.cpp#L121. The seed is generated by a Linear Congruential Generator (LCG) based on the initial seed (see https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L143). In this model, Philox is not used as a stateful generator; it is used as a one-shot, hash-like function that turns a key into a single pseudo-random output. The seed is the only stateful part of the system.
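The difference between the two models can be sketched as follows (illustrative Python only; `philox_like` and the LCG constants below are made up for the sketch, not the actual Philox function or XLA's constants):

```python
MASK = 2**64 - 1

def philox_like(seed, offset):
    # Stand-in for Philox: a pure function of (seed, offset). The real
    # Philox is a counter-based block cipher; this is just a toy mixer.
    x = (seed * 0x9E3779B97F4A7C15 + offset * 0xBF58476D1CE4E5B9) & MASK
    return x ^ (x >> 31)

# Model A (CUDA-style, stateful generator): the seed stays fixed and the
# offset advances with every random op.
def stateful_draws(seed, n):
    return [philox_like(seed, offset) for offset in range(n)]

# Model B (current XLA): the offset is pinned to 0 and the seed itself is
# advanced by an LCG between ops, so Philox acts as a one-shot hash of
# the seed rather than a stateful generator.
def lcg_next(seed):
    # Illustrative LCG constants, not the ones in xla_graph_executor.cpp.
    return (6364136223846793005 * seed + 1442695040888963407) & MASK

def xla_style_draws(seed, n):
    out = []
    for _ in range(n):
        out.append(philox_like(seed, 0))  # offset is always 0
        seed = lcg_next(seed)             # only the seed is stateful
    return out
```

Both models are deterministic given the initial seed, but only Model A has an offset that a generator's get/set methods could meaningfully round-trip.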

I am not sure of the reason for not using the offset in XLA, but it makes the implementation of the generator tricky. What should I do about the offset? With the current implementation, the getter can always return 0, but what should the setter do?

My suggestion is to let the offset setter throw an exception saying that the offset is not used and the setter does not work. We keep the current logic and implement the generator; once we have a working generator, we can go back and see whether we want to refactor and add support for the offset.
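Sketched in Python (hypothetical names; the real code would live in the C++ generator class), the proposed behavior would be:

```python
class XLAGeneratorState:
    """Sketch of the proposed offset handling (hypothetical class name).

    The getter reports the constant offset 0 that XLA's RNG currently
    uses; the setter raises, since no code path consumes an offset yet.
    """

    def get_offset(self):
        return 0  # XLA ops always run Philox with offset 0 today

    def set_offset(self, offset):
        raise NotImplementedError(
            "XLA RNG does not use a Philox offset; set_offset is "
            "unsupported until offset-based state is implemented.")

s = XLAGeneratorState()
assert s.get_offset() == 0
try:
    s.set_offset(10)
except NotImplementedError:
    pass  # expected: setting the offset is rejected loudly
```

Raising loudly seems safer than silently ignoring the value: a caller that round-trips `get_state`/`set_state` expecting CUDA-like offset semantics would find out immediately instead of getting wrong randomness.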

What do you think? @qihqi

iwknow avatar Aug 03 '25 06:08 iwknow