Enable registering kernels for custom transforms
🚀 The feature
It would be good to be able to register kernels for custom transforms in `torchvision.transforms.v2.functional`.
Motivation, pitch
If I want to register a kernel for my own functional, one that is not part of the built-in torchvision transforms or functional API but that consumes the built-in `tv_tensor` classes, registration is blocked by the check that the functional is a built-in:

https://github.com/pytorch/vision/blob/6a9b5492d9590b19fe75300d95e3d9c4852a14ac/torchvision/transforms/v2/functional/_utils.py#L77-L84

https://github.com/pytorch/vision/blob/6a9b5492d9590b19fe75300d95e3d9c4852a14ac/torchvision/transforms/v2/functional/_utils.py#L92-L93
```python
from typing import Any

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2 as tvt_v2
from torchvision.transforms.v2 import functional as F

def custom_transform_kernel(inpt: torch.Tensor) -> torch.Tensor:
    ...

# Blocked: `register_kernel` rejects functionals that are not built-ins.
@F.register_kernel(custom_transform_kernel, tv_tensors.TVTensor)
def _custom_transform_kernel_dispatch(inpt: tv_tensors.TVTensor) -> tv_tensors.TVTensor:
    output = custom_transform_kernel(inpt.as_subclass(torch.Tensor))
    return tv_tensors.wrap(output, like=inpt)

class CustomTransform(tvt_v2.Transform):
    def _transform(self, inpt: Any, params: dict[str, Any]) -> Any:
        return self._call_kernel(custom_transform_kernel, inpt)
```
It would be more flexible if it were possible to register kernels for such custom functionals.
Alternatives
I tried using `F._utils._register_kernel_internal` instead, and it works. However, relying on a private helper does not seem like a safe approach.
```python
from typing import Any

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2 as tvt_v2
from torchvision.transforms.v2 import functional as F

def custom_transform_kernel(inpt: torch.Tensor) -> torch.Tensor:
    ...

# Works, but uses a private API that may change without notice.
@F._utils._register_kernel_internal(custom_transform_kernel, tv_tensors.TVTensor, tv_tensor_wrapper=False)
def _custom_transform_kernel_dispatch(inpt: tv_tensors.TVTensor) -> tv_tensors.TVTensor:
    output = custom_transform_kernel(inpt.as_subclass(torch.Tensor))
    return tv_tensors.wrap(output, like=inpt)

class CustomTransform(tvt_v2.Transform):
    def _transform(self, inpt: Any, params: dict[str, Any]) -> Any:
        return self._call_kernel(custom_transform_kernel, inpt)
```
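For illustration, the mechanism being requested boils down to a registry keyed by `(functional, input type)` with a fallback to the functional itself. Below is a minimal pure-Python sketch of that dispatch pattern; the names (`register_kernel`, `call_kernel`, `FakeTVTensor`) and the stand-in types are mine, not torchvision's actual internals:

```python
from collections import defaultdict
from typing import Any, Callable

# Registry mapping each functional to its per-input-type kernels.
_KERNEL_REGISTRY: dict[Callable, dict[type, Callable]] = defaultdict(dict)

def register_kernel(functional: Callable, input_cls: type):
    """Register the decorated kernel for (functional, input_cls),
    without requiring `functional` to be a built-in."""
    def decorator(kernel: Callable) -> Callable:
        _KERNEL_REGISTRY[functional][input_cls] = kernel
        return kernel
    return decorator

def call_kernel(functional: Callable, inpt: Any):
    """Dispatch on the exact input type; fall back to the functional itself."""
    kernel = _KERNEL_REGISTRY[functional].get(type(inpt), functional)
    return kernel(inpt)

# Stand-in for a tv_tensor subclass.
class FakeTVTensor(list):
    pass

def custom_functional(inpt):
    return [x * 2 for x in inpt]

@register_kernel(custom_functional, FakeTVTensor)
def _custom_kernel(inpt: FakeTVTensor) -> FakeTVTensor:
    # Kernel preserves the subclass, mirroring tv_tensors.wrap().
    return FakeTVTensor(x * 2 for x in inpt)
```

With this sketch, `call_kernel(custom_functional, FakeTVTensor([1, 2]))` returns a `FakeTVTensor`, while a plain `list` input falls through to `custom_functional`. The point is that nothing in this pattern requires `functional` to be a built-in; the restriction in `_utils.py` is a policy choice, not a technical necessity.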
Additional context
No response