
Optimize as_strided_copy fast path to support offset

Zantares opened this issue 2 months ago

This PR extends the as_strided_copy fast path to support a nonzero storage offset.

Negative case, i.e. one that currently misses the fast path (it occurs in the default code path of torch LSTM):

import torch
import torch_xla


def torch_xla_chunk_example(input_tensor):
    # unsafe_chunk splits dim 1 (size 20) into 4 views of shape [10, 5, 30].
    # chunks[0] starts at storage offset 0; chunks[1] starts at 5 * 30 = 150.
    chunks = torch.unsafe_chunk(input_tensor, 4, dim=1)
    result0 = chunks[0]
    result1 = chunks[1]
    return torch.sigmoid(result0), torch.tanh(result1)


def create_example():
    device = torch_xla.device()

    torch.manual_seed(12)
    input_tensor = torch.randn(10, 20, 30, dtype=torch.float32)

    result0, result1 = torch_xla_chunk_example(input_tensor.to(device))
    torch_xla.sync()  # Compile and execute the pending graph.
    print(f"Chunk 0:\n{result0}\n")
    print(f"Chunk 1:\n{result1}\n")

    return result0


if __name__ == "__main__":
    create_example()

HLO before the optimization:

ENTRY %IrToHlo.19 (p0.1: f32[10,20,30], p1.4: s64[10,5,30]) -> (f32[10,5,30], f32[10,5,30]) {
  ...
  %slice.2 = f32[10,5,30]{2,1,0} slice(f32[10,20,30]{2,1,0} %p0.1), slice={[0:10], [0:5], [0:30]}, metadata={op_type="xla__select" op_name="xla__select"}
  ...
  %constant.10 = s64[] constant(0), metadata={op_type="aten__take" op_name="aten__take"}
  %broadcast.11 = s64[1500]{0} broadcast(s64[] %constant.10), dimensions={}, metadata={op_type="aten__take"}
  %compare.12 = pred[1500]{0} compare(s64[1500]{0} %reshape.6, s64[1500]{0} %broadcast.11), direction=GE, metadata={op_type="aten__take" op_name="aten__take"}
  %constant.7 = s64[] constant(6000), metadata={op_type="aten__take" op_name="aten__take"}
  %broadcast.8 = s64[1500]{0} broadcast(s64[] %constant.7), dimensions={}, metadata={op_type="aten__take"}
  %add.9 = s64[1500]{0} add(s64[1500]{0} %reshape.6, s64[1500]{0} %broadcast.8), metadata={op_type="aten__take"}
  %select.13 = s64[1500]{0} select(pred[1500]{0} %compare.12, s64[1500]{0} %reshape.6, s64[1500]{0} %add.9), metadata={op_type="aten__take" op_name="aten__take"}
  %convert.14 = u32[1500]{0} convert(s64[1500]{0} %select.13), metadata={op_type="aten__take" op_name="aten__take"}
  %gather.15 = f32[1500]{0} gather(f32[6000]{0} %reshape.5, u32[1500]{0} %convert.14), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_type="aten__take" op_name="aten__take"}
  ...
}
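  • result0 (no offset) is lowered to a single slice.
  • result1 (storage offset 150) is lowered to take, which expands into many ops (compare, add, select, convert, gather) and needs an extra s64[10,5,30] index parameter (p1.4).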
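To see why the take path is doing redundant work here, the sketch below (plain Python, no torch_xla) enumerates the flat storage positions that the offset view reads from a contiguous [10, 20, 30] buffer, and compares them with the positions covered by slice={[0:10], [5:10], [0:30]}. It assumes the take index tensor holds exactly these flat positions (offset + i*600 + j*30 + k); the helper names are hypothetical and not part of torch_xla.

```python
from itertools import product
from typing import List, Sequence


def row_major_strides(sizes: Sequence[int]) -> List[int]:
    """Element strides of a contiguous (row-major) tensor with the given shape."""
    strides = [1] * len(sizes)
    for d in range(len(sizes) - 2, -1, -1):
        strides[d] = strides[d + 1] * sizes[d + 1]
    return strides


def as_strided_flat_indices(
    sizes: Sequence[int], strides: Sequence[int], offset: int
) -> List[int]:
    """Flat storage positions read by an as_strided view, in row-major order
    (assumed to match the index list the take-based lowering gathers with)."""
    return [
        offset + sum(i * s for i, s in zip(idx, strides))
        for idx in product(*(range(n) for n in sizes))
    ]


def slice_flat_indices(
    input_sizes: Sequence[int], starts: Sequence[int], limits: Sequence[int]
) -> List[int]:
    """Flat storage positions covered by slice([starts:limits]) of a contiguous input."""
    strides = row_major_strides(input_sizes)
    return [
        sum(i * s for i, s in zip(idx, strides))
        for idx in product(*(range(b, e) for b, e in zip(starts, limits)))
    ]


input_sizes = (10, 20, 30)
view_sizes = (10, 5, 30)
view_strides = row_major_strides(input_sizes)  # [600, 30, 1]

# chunks[1] of torch.unsafe_chunk(x, 4, dim=1) has storage offset 5 * 30 = 150.
take_indices = as_strided_flat_indices(view_sizes, view_strides, offset=150)
slice_indices = slice_flat_indices(input_sizes, (0, 5, 0), (10, 10, 30))
print(len(take_indices), take_indices == slice_indices)  # 1500 True
```

Both lists contain the same 1500 positions in the same order, so the gather produces exactly what a slice starting at index 5 of dim 1 would.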

The two operations are identical except for the offset. The second result can therefore be produced by a plain slice once the offset is folded into the slice's start index for the sliced dimension (150 / 30 = 5, giving [5:10]). This PR enables that optimization in as_strided_copy when the offset falls in the same dimension as the slice.
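A minimal Python sketch of that condition, assuming a contiguous input and a view that keeps the input's strides and differs from it in at most one dimension. `as_strided_to_slice` and `row_major_strides` are hypothetical helpers written for illustration; the actual change lives in torch_xla's as_strided_copy lowering and is not reproduced here. The zero-offset case (result0) falls out as a start index of 0.

```python
from typing import List, Optional, Sequence, Tuple


def row_major_strides(sizes: Sequence[int]) -> List[int]:
    """Element strides of a contiguous (row-major) tensor with the given shape."""
    strides = [1] * len(sizes)
    for d in range(len(sizes) - 2, -1, -1):
        strides[d] = strides[d + 1] * sizes[d + 1]
    return strides


def as_strided_to_slice(
    input_sizes: Sequence[int],
    sizes: Sequence[int],
    strides: Sequence[int],
    storage_offset: int,
) -> Optional[List[Tuple[int, int]]]:
    """Hypothetical helper: return per-dimension (start, limit) bounds if the
    as_strided view equals a slice of a contiguous input along at most one
    dimension, else None."""
    if len(sizes) != len(input_sizes) or list(strides) != row_major_strides(input_sizes):
        return None  # Rank or strides differ from the input: not a plain slice.
    sliced = [d for d in range(len(sizes)) if sizes[d] != input_sizes[d]]
    if len(sliced) > 1:
        return None  # This sketch only handles a single sliced dimension.
    bounds = [(0, n) for n in input_sizes]
    if not sliced:
        # Full-size view: only a plain slice when there is no offset.
        return bounds if storage_offset == 0 else None
    dim = sliced[0]
    # Fold the offset into the start of the sliced dimension. It must be a
    # whole number of steps along that dimension and keep the slice in bounds.
    start, rem = divmod(storage_offset, strides[dim])
    if rem != 0 or start < 0 or start + sizes[dim] > input_sizes[dim]:
        return None
    bounds[dim] = (start, start + sizes[dim])
    return bounds


x_sizes = (10, 20, 30)
view = dict(sizes=(10, 5, 30), strides=(600, 30, 1))
print(as_strided_to_slice(x_sizes, **view, storage_offset=0))    # [(0, 10), (0, 5), (0, 30)]
print(as_strided_to_slice(x_sizes, **view, storage_offset=150))  # [(0, 10), (5, 10), (0, 30)]
```

The second call reproduces the bounds of %slice.4 in the optimized HLO. Offsets that are not a multiple of the sliced dimension's stride, or that would push the slice past the end of that dimension, return None; in this sketch those cases are left to the general (take-based) path.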

HLO after the optimization (the s64 index parameter and the take expansion are gone; both results are plain slices):

ENTRY %IrToHlo.7 (p0.1: f32[10,20,30]) -> (f32[10,5,30], f32[10,5,30]) {
  ...
  %slice.2 = f32[10,5,30]{2,1,0} slice(f32[10,20,30]{2,1,0} %p0.1), slice={[0:10], [0:5], [0:30]}, metadata={op_type="xla__select" op_name="xla__select"}
  ...
  %slice.4 = f32[10,5,30]{2,1,0} slice(f32[10,20,30]{2,1,0} %p0.1), slice={[0:10], [5:10], [0:30]}, metadata={op_type="xla__select" op_name="xla__select"}
  ...
}

Zantares, Dec 10 '25 13:12