[POC] Dynamic `expand` with `SymInt` implementation
This is a POC implementation of the `torch.Tensor.expand` op, based on the PyTorch `SymInt` POC implementation PR.
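To make the goal concrete, here is a minimal sketch (illustrative only, not this PR's code) of the situation dynamic `expand` targets: expanding a tensor to a size that depends on the data, such as the row count of a `torch.nonzero` result. In the POC, that size would flow through the graph as a `SymInt` rather than a concrete integer; the snippet below uses ordinary eager-mode PyTorch, where the size is still concrete.

```python
import torch

# A data-dependent shape: nonzero's output size is unknown until runtime.
x = torch.tensor([0.0, 1.0, 2.0, 0.0])
nz = torch.nonzero(x)       # here shape (2, 1): two nonzero elements

a = torch.tensor([[5.0]])   # shape (1, 1)
# Expand to nonzero's first dimension. With dynamic-shape support, this
# size would be carried as a SymInt instead of the concrete value 2.
b = a.expand(nz.size(0), 1)
print(tuple(b.shape))       # (2, 1)
```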
Action items to unblock:
- [x] Lower `expand` with a `SymInt` input parameter as a POC. This won't compile until `expand` with `SymInt` is available in PyTorch
- [x] Update upstream PyTorch when the `expand.SymInt` op (with the `c10::SymIntArrayRef` input signature) is ready
- [x] Integrate `dynamic_ir` as a subclass of `XLANode`
- [x] Implement `SizeNode` lowering using the `xla::GetDimensionSize` API
- [x] Verify/implement JIT SSA shape support for `expand.SymInt` in PyTorch
- [x] Verify/implement shape inference support for `expand.SymInt` in PyTorch
  - PR: https://github.com/pytorch/pytorch/pull/77830
  - Status: Blocked
- [x] Reimplement `DimensionNode::isDynamic()` after PyTorch API support becomes available
  - https://github.com/pytorch/pytorch/issues/77909
  - Support the same feature in PyTorch/XLA (this PR)
- [x] Implement `SymInt`-related helper functions to improve reuse and modularity
- [x] Move the `DimensionNode` class to LTC core and enable multiple inheritance for `Size` nodes
  - https://github.com/pytorch/pytorch/pull/78088
- [x] [Testing] Add a C++ unit test for `expand.SymInt` in PyTorch/XLA
  - https://github.com/pytorch/xla/issues/3589
  - Temporarily unblocked
- [x] [Testing] Implement support for `torch.nonzero` in PyTorch
  - PR: https://github.com/pytorch/pytorch/pull/77572
  - Status: Merged
- [x] [Testing] Implement support for `torch.nonzero` in PyTorch/XLA
- [x] `is_symbolic` API issue
  - https://github.com/pytorch/xla/issues/3680
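The `SizeNode` item above can be sketched abstractly: instead of recording a static integer, the node queries the producing tensor's dimension size when evaluated, which mirrors what a lowering via `xla::GetDimensionSize` does at runtime. The sketch below is hypothetical plain Python (the class and method names are illustrative, not the actual PyTorch/XLA or LTC classes).

```python
# Illustrative sketch of a dynamic-size IR node; names are hypothetical.
class Node:
    def evaluate(self, env):
        raise NotImplementedError

class TensorNode(Node):
    """A leaf node bound to a concrete nested-list 'tensor' at runtime."""
    def __init__(self, name):
        self.name = name
    def evaluate(self, env):
        return env[self.name]

class SizeNode(Node):
    """The size of one dimension of another node's output, resolved at
    evaluation time rather than baked in as a static constant."""
    def __init__(self, src, dim):
        self.src, self.dim = src, dim
    def evaluate(self, env):
        value = self.src.evaluate(env)
        for _ in range(self.dim):   # descend to the requested dimension
            value = value[0]
        return len(value)

t = TensorNode("x")
s = SizeNode(t, 0)
print(s.evaluate({"x": [[1, 2], [3, 4], [5, 6]]}))  # 3
```

The same `SizeNode` yields a different value for a different runtime input, which is the behavior a dynamic dimension needs.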
Does this PR build locally on your end? The CI build failed with conflicts.
This PR doesn't build at the moment because the upstream LTC layer doesn't yet have API support for `expand` with `SymInt`. @JackCaoG @Krovatkin
Update: The current unit test exercises the `expand.SymInt` code path. It does not check dynamic dimension propagation across a `SymInt` op, since the `DimensionNode::isDynamic` implementation is still WIP.
CC @JackCaoG @Krovatkin