Optimize as_strided_copy fast path to support offset
This PR extends the `as_strided_copy` fast path to support a nonzero storage offset.
Negative case (this pattern occurs in the default path of the PyTorch LSTM):
```python
import torch
import torch_xla


def torch_xla_chunk_example(input_tensor):
    # Split dim 1 (size 20) into 4 chunks of size 5 each.
    chunks = torch.unsafe_chunk(input_tensor, 4, dim=1)
    result0 = chunks[0]  # view with storage offset 0
    result1 = chunks[1]  # view with a nonzero storage offset (starts at index 5 of dim 1)
    return torch.sigmoid(result0), torch.tanh(result1)


def create_example():
    device = torch_xla.device()
    torch.manual_seed(12)
    input_tensor = torch.randn(10, 20, 30, dtype=torch.float32)
    result0, result1 = torch_xla_chunk_example(input_tensor.to(device))
    # Materialize the pending lazy graph so the HLO is compiled and executed.
    torch_xla.sync()
    print(f"Chunk 0:\n{result0}\n")
    print(f"Chunk 1:\n{result1}\n")
    return result0


if __name__ == "__main__":
    create_example()
```
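To see why the two chunks end up on different paths, it helps to look at the view geometry each one has. The sketch below is a minimal pure-Python model that assumes a contiguous row-major input and a chunk count that evenly divides the chunked dimension; `contiguous_strides` and `chunk_view` are hypothetical helpers for illustration, not torch or torch_xla APIs.

```python
from typing import List, Tuple


def contiguous_strides(shape: List[int]) -> List[int]:
    # Row-major (C-contiguous) strides, measured in elements.
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return strides


def chunk_view(
    shape: List[int], chunks: int, dim: int, index: int
) -> Tuple[List[int], List[int], int]:
    # Hypothetical model of the (size, stride, storage_offset) triple of chunk
    # `index` of a contiguous tensor. Assumes shape[dim] % chunks == 0.
    strides = contiguous_strides(shape)
    step = shape[dim] // chunks
    size = list(shape)
    size[dim] = step
    offset = index * step * strides[dim]
    return size, strides, offset


if __name__ == "__main__":
    for i in range(2):
        print(i, chunk_view([10, 20, 30], 4, 1, i))
```

Under these assumptions, chunk 0 has size `[10, 5, 30]`, strides `[600, 30, 1]` and offset 0, while chunk 1 has the same size and strides but offset 150 (5 steps of stride 30 along dim 1). The offset is the only difference between the two views.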
HLO before the optimization:
```
ENTRY %IrToHlo.19 (p0.1: f32[10,20,30], p1.4: s64[10,5,30]) -> (f32[10,5,30], f32[10,5,30]) {
...
%slice.2 = f32[10,5,30]{2,1,0} slice(f32[10,20,30]{2,1,0} %p0.1), slice={[0:10], [0:5], [0:30]}, metadata={op_type="xla__select" op_name="xla__select"}
...
%constant.10 = s64[] constant(0), metadata={op_type="aten__take" op_name="aten__take"}
%broadcast.11 = s64[1500]{0} broadcast(s64[] %constant.10), dimensions={}, metadata={op_type="aten__take"}
%compare.12 = pred[1500]{0} compare(s64[1500]{0} %reshape.6, s64[1500]{0} %broadcast.11), direction=GE, metadata={op_type="aten__take" op_name="aten__take"}
%constant.7 = s64[] constant(6000), metadata={op_type="aten__take" op_name="aten__take"}
%broadcast.8 = s64[1500]{0} broadcast(s64[] %constant.7), dimensions={}, metadata={op_type="aten__take"}
%add.9 = s64[1500]{0} add(s64[1500]{0} %reshape.6, s64[1500]{0} %broadcast.8), metadata={op_type="aten__take"}
%select.13 = s64[1500]{0} select(pred[1500]{0} %compare.12, s64[1500]{0} %reshape.6, s64[1500]{0} %add.9), metadata={op_type="aten__take" op_name="aten__take"}
%convert.14 = u32[1500]{0} convert(s64[1500]{0} %select.13), metadata={op_type="aten__take" op_name="aten__take"}
%gather.15 = f32[1500]{0} gather(f32[6000]{0} %reshape.5, u32[1500]{0} %convert.14), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_type="aten__take" op_name="aten__take"}
...
}
```
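The `aten__take` ops above implement a generic gather over the flattened input. The following is a rough pure-Python model of that fallback, assuming a contiguous input flattened to 6000 elements: it builds one flat index per output element from the offset and strides, applies the same negative-index wrap that the `compare`/`add`/`select` sequence encodes (indices `>= 0` are kept, others get `+ 6000`), and then gathers. `strided_flat_indices` and `take` are hypothetical names used only for illustration.

```python
import itertools
from typing import List


def strided_flat_indices(size: List[int], stride: List[int], offset: int) -> List[int]:
    # One flat storage index per output element, in row-major output order:
    # offset + sum(i_k * stride_k).
    return [
        offset + sum(i * s for i, s in zip(idx, stride))
        for idx in itertools.product(*(range(n) for n in size))
    ]


def take(flat: List[float], indices: List[int]) -> List[float]:
    # Mirrors the HLO: keep non-negative indices, wrap negative ones by the
    # input length (compare GE 0 / add 6000 / select), then gather.
    n = len(flat)
    return [flat[i if i >= 0 else i + n] for i in indices]


if __name__ == "__main__":
    flat = [float(v) for v in range(10 * 20 * 30)]  # stand-in for the 6000-element reshape
    idx = strided_flat_indices([10, 5, 30], [600, 30, 1], 150)
    out = take(flat, idx)
    print(len(idx), idx[:3], out[:3])
```

For the chunk-1 geometry this produces 1500 indices (matching the `s64[1500]` shapes in the HLO), and the gathered values are exactly the elements at `[0:10, 5:10, 0:30]` of the input, which is why a single `slice` can replace the whole sequence.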
- `result0` (without offset) uses `slice`.
- `result1` (with offset) uses `take`, which expands into many ops.
The two operations are nearly identical except for the offset. In fact, the second result can be produced by a plain `slice` once the offset is added to the slice's start index along the affected dimension (here, start 5 instead of 0 on dim 1). This PR enables that optimization in `as_strided_copy` when the offset falls in the same dimension as the slice.
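The condition above can be sketched as follows. This is a simplified pure-Python model under stated assumptions, not the actual torch_xla implementation: the view is assumed to keep the input's contiguous strides (it only shrinks dimensions), and the offset must be an exact multiple of the stride of a dimension that is being sliced, with the shifted slice still fitting inside that dimension. `offset_to_slice` is a hypothetical name.

```python
from typing import List, Optional, Tuple


def offset_to_slice(
    input_shape: List[int],
    size: List[int],
    stride: List[int],
    offset: int,
) -> Optional[Tuple[List[int], List[int]]]:
    """Return (start, limit) per dimension if the strided view is a plain slice.

    Assumes `stride` equals the contiguous strides of `input_shape`. Returns
    None when the offset cannot be folded into one sliced dimension's start.
    """
    if len(size) != len(input_shape) or len(stride) != len(input_shape):
        return None
    start = [0] * len(input_shape)
    if offset != 0:
        # Fold the offset into the first sliced dimension whose stride divides it.
        for d, s in enumerate(stride):
            if size[d] < input_shape[d] and offset % s == 0:
                start[d] = offset // s
                break
        else:
            return None  # offset does not land on a sliced dimension
    limit = [b + n for b, n in zip(start, size)]
    # The shifted slice must stay within the input bounds.
    if any(lim > extent for lim, extent in zip(limit, input_shape)):
        return None
    return start, limit


if __name__ == "__main__":
    print(offset_to_slice([10, 20, 30], [10, 5, 30], [600, 30, 1], 150))
```

With the chunk-1 geometry (offset 150, stride 30 on dim 1), this model yields start `[0, 5, 0]` and limit `[10, 10, 30]`, which corresponds to the `slice={[0:10], [5:10], [0:30]}` in the optimized HLO. An offset such as 160, which is not a multiple of any sliced dimension's stride, is rejected and would still need the general `take` path.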
The optimized HLO:
```
ENTRY %IrToHlo.7 (p0.1: f32[10,20,30]) -> (f32[10,5,30], f32[10,5,30]) {
...
%slice.2 = f32[10,5,30]{2,1,0} slice(f32[10,20,30]{2,1,0} %p0.1), slice={[0:10], [0:5], [0:30]}, metadata={op_type="xla__select" op_name="xla__select"}
...
%slice.4 = f32[10,5,30]{2,1,0} slice(f32[10,20,30]{2,1,0} %p0.1), slice={[0:10], [5:10], [0:30]}, metadata={op_type="xla__select" op_name="xla__select"}
...
```