
Enable registering custom transform kernels

Open · sungchul1 opened this issue 2 years ago · 0 comments

🚀 The feature

It would be good to be able to register kernels for custom transforms in `v2.functional`.

Motivation, pitch

Suppose I want to register a kernel for a custom transform that is not one of the built-in torchvision transforms or functionals, and that kernel should handle the built-in tv_tensor classes. Registration is then blocked by the checks in `register_kernel`, which require the functional to come from `torchvision.transforms.v2.functional` and reject the built-in tv_tensor classes:
https://github.com/pytorch/vision/blob/6a9b5492d9590b19fe75300d95e3d9c4852a14ac/torchvision/transforms/v2/functional/_utils.py#L77-L84
https://github.com/pytorch/vision/blob/6a9b5492d9590b19fe75300d95e3d9c4852a14ac/torchvision/transforms/v2/functional/_utils.py#L92-L93

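```python
from typing import Any
import torch
import torchvision.transforms.v2 as tvt_v2
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors

def custom_transform_kernel(inpt: torch.Tensor) -> torch.Tensor:
    ...

# Blocked: register_kernel raises a ValueError for this registration.
@F.register_kernel(custom_transform_kernel, tv_tensors.TVTensor)
def _custom_transform_kernel_dispatch(inpt: tv_tensors.TVTensor) -> tv_tensors.TVTensor:
    # Unwrap to a plain tensor, run the kernel, then re-wrap with the input's metadata.
    output = custom_transform_kernel(inpt.as_subclass(torch.Tensor))
    return tv_tensors.wrap(output, like=inpt)

class CustomTransform(tvt_v2.Transform):
    def _transform(self, inpt: Any, params: dict[str, Any]) -> Any:
        # Dispatch to the kernel registered for type(inpt).
        return self._call_kernel(custom_transform_kernel, inpt)
```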

It would be more flexible if kernels for such custom transforms could be registered through the public API.

Alternatives

I tried using the private `@F._utils._register_kernel_internal` decorator instead, and it works. However, since it is a private helper, I don't think relying on it is safe.

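```python
from typing import Any
import torch
import torchvision.transforms.v2 as tvt_v2
import torchvision.transforms.v2.functional as F
from torchvision import tv_tensors

def custom_transform_kernel(inpt: torch.Tensor) -> torch.Tensor:
    ...

# Works, but relies on a private helper. tv_tensor_wrapper=False because the
# dispatch function below unwraps and re-wraps the tv_tensor itself.
@F._utils._register_kernel_internal(custom_transform_kernel, tv_tensors.TVTensor, tv_tensor_wrapper=False)
def _custom_transform_kernel_dispatch(inpt: tv_tensors.TVTensor) -> tv_tensors.TVTensor:
    output = custom_transform_kernel(inpt.as_subclass(torch.Tensor))
    return tv_tensors.wrap(output, like=inpt)

class CustomTransform(tvt_v2.Transform):
    def _transform(self, inpt: Any, params: dict[str, Any]) -> Any:
        return self._call_kernel(custom_transform_kernel, inpt)
```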
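Registering a single kernel on `tv_tensors.TVTensor` appears to be enough to cover every built-in tv_tensor because, as far as I can tell from `_get_kernel` in the same `_utils.py`, dispatch walks the input type's MRO and returns the first registered kernel, and `Transform._call_kernel` passes unregistered input types through unchanged. The sketch below models that lookup in plain Python under those assumptions; `register`, `get_kernel`, `TVTensor`, `Image`, and `Mask` are hypothetical stand-ins, not torchvision's API.

```python
from typing import Callable, Dict


class TVTensor:  # hypothetical stand-in for tv_tensors.TVTensor
    pass


class Image(TVTensor):  # hypothetical stand-in for a built-in subclass
    pass


class Mask(TVTensor):  # another hypothetical built-in subclass
    pass


# functional -> {input type -> kernel}
_KERNEL_REGISTRY: Dict[Callable, Dict[type, Callable]] = {}


def register(functional: Callable, input_type: type):
    def decorator(kernel: Callable) -> Callable:
        _KERNEL_REGISTRY.setdefault(functional, {})[input_type] = kernel
        return kernel

    return decorator


def get_kernel(functional: Callable, input_type: type, allow_passthrough: bool = False) -> Callable:
    registry = _KERNEL_REGISTRY.get(functional, {})
    # Walk the MRO so a kernel registered on a base class also serves its subclasses.
    for cls in input_type.__mro__:
        if cls in registry:
            return registry[cls]
    if allow_passthrough:
        # Assumed passthrough behavior: return the input unchanged.
        return lambda inpt, *args, **kwargs: inpt
    raise TypeError(f"no kernel registered for {input_type.__name__}")


def custom_transform_kernel(inpt):
    ...  # placeholder; only used as a registry key here


@register(custom_transform_kernel, TVTensor)
def _custom_transform_kernel_dispatch(inpt):
    # Report which kernel ran and on which type, to make dispatch visible.
    return ("custom", type(inpt).__name__)


image_result = get_kernel(custom_transform_kernel, Image)(Image())
mask_result = get_kernel(custom_transform_kernel, Mask)(Mask())
passthrough_result = get_kernel(custom_transform_kernel, int, allow_passthrough=True)(5)
print(image_result)  # ('custom', 'Image')
print(passthrough_result)  # 5
```

Because `Image` and `Mask` both inherit from `TVTensor`, the one registration on the base class handles both, which is why the issue's snippets only register for `tv_tensors.TVTensor`.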

Additional context

No response

sungchul1 · Jan 19, 2024, 08:01