TorchScript compatible UNet
Is your feature request related to a problem? Please describe.
Currently, UNet can be exported with torch.jit.trace, but the traced model is device dependent.
Describe the solution you'd like
It would be much better if I could use torch.jit.script, e.g. to export the model on a CPU device and then load and use it on a CUDA or MPS device.
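For context, here is a minimal sketch of the requested workflow. It uses a small hypothetical module, `SinusoidalEmbedding`, which is a simplified stand-in for a timestep embedding and not the diffusers implementation, along with a hypothetical helper, `export_then_load`. Scripting compiles `device=timesteps.device` as an expression evaluated at run time. That means the module can be scripted and saved on CPU, then loaded with `map_location` onto CUDA when it is available (CPU otherwise). MPS is not exercised here.

```python
import io
import math

import torch
from torch import nn


class SinusoidalEmbedding(nn.Module):
    """Hypothetical, simplified timestep embedding that TorchScript can compile."""

    def __init__(self, dim: int, max_period: float = 10000.0):
        super().__init__()
        self.dim = dim
        # Cast explicitly so TorchScript infers a float attribute, not an int
        self.max_period = float(max_period)

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        # The device comes from the input at run time, so the compiled graph
        # is not tied to the device used when scripting
        exponent = torch.arange(half, dtype=torch.float32, device=timesteps.device)
        freqs = torch.exp(exponent * (-math.log(self.max_period) / half))
        args = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


def export_then_load(module: nn.Module, device: str) -> torch.jit.ScriptModule:
    """Script and save on CPU, then reload onto `device` (hypothetical helper)."""
    scripted = torch.jit.script(module.cpu())
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    buffer.seek(0)
    # map_location moves any saved parameters/buffers to the target device
    return torch.jit.load(buffer, map_location=device)


device = "cuda" if torch.cuda.is_available() else "cpu"
loaded = export_then_load(SinusoidalEmbedding(8), device)
emb = loaded(torch.tensor([0, 10, 500], device=device))
print(emb.shape, emb.device)
```

For a real UNet, the same save/load calls would apply only once the whole model scripts cleanly, which is exactly what this issue is about.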
Describe alternatives you've considered
I could help implement a TorchScript-friendly UNet purely for deployment, without breaking the original eager API.
Additional context
n/a
Thanks a lot for the issue @ChenchaoZhao,
Could you maybe add a reproducible code snippet that shows how one would use torch.jit.script and why it doesn't work, i.e. what error you get?
TorchScript is a statically typed subset of Python. If there is a type mismatch, compilation fails with an error.
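The sketch below reproduces that rule without diffusers. It assumes the same pattern that appears in the traceback further down: a function with a float-annotated parameter (here the hypothetical `shifted_arange`) is called with an unannotated module attribute (the hypothetical `Shift.shift`). TorchScript infers the attribute's type from its value. An int value therefore does not match the float parameter and compilation fails, while a float value compiles.

```python
import torch
from torch import nn


def shifted_arange(n: int, shift: float = 1.0) -> torch.Tensor:
    # Hypothetical helper: like get_timestep_embedding, it declares a float
    # parameter, so TorchScript only accepts a float for `shift`
    return torch.arange(n, dtype=torch.float32) - shift


class Shift(nn.Module):
    def __init__(self, shift):
        super().__init__()
        # No annotation: TorchScript infers the attribute type from the value,
        # so Shift(1) gets an int attribute and Shift(1.0) a float one
        self.shift = shift

    def forward(self, n: int) -> torch.Tensor:
        return shifted_arange(n, self.shift)


def script_error(module: nn.Module) -> str:
    """Return '' if the module scripts cleanly, else the compiler message."""
    try:
        torch.jit.script(module)
    except RuntimeError as err:
        return str(err)
    return ""


ok_msg = script_error(Shift(1.0))  # float attribute matches the float parameter
bad_msg = script_error(Shift(1))   # int attribute triggers the type mismatch
print(repr(ok_msg))
print(bad_msg)
```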
import diffusers
import torch
print(diffusers.__version__, torch.__version__)
# 0.11.1 1.13.1+cu116
# UNet from the tutorial
from diffusers import UNet2DModel
model = UNet2DModel(
sample_size=32, # the target image resolution
in_channels=3, # the number of input channels, 3 for RGB images
out_channels=3, # the number of output channels
layers_per_block=2, # how many ResNet layers to use per UNet block
block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channels for each UNet block
down_block_types=(
"DownBlock2D", # a regular ResNet downsampling block
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
"DownBlock2D",
),
up_block_types=(
"UpBlock2D", # a regular ResNet upsampling block
"AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D"
),
)
ts_model = torch.jit.script(model)
# traceback
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-8-28e20fd90ee4> in <module>
----> 1 torch.jit.script(model)
/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb, example_inputs)
1284 if isinstance(obj, torch.nn.Module):
1285 obj = call_prepare_scriptable_func(obj)
-> 1286 return torch.jit._recursive.create_script_module(
1287 obj, torch.jit._recursive.infer_methods_to_compile
1288 )
/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
474 if not is_tracing:
475 AttributeTypeIsSupportedChecker().check(nn_module)
--> 476 return create_script_module_impl(nn_module, concrete_type, stubs_fn)
477
478 def create_script_module_impl(nn_module, concrete_type, stubs_fn):
/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
536
537 # Actually create the ScriptModule, initializing it with the function we just defined
--> 538 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
539
540 # Compile methods if necessary
/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py in _construct(cpp_module, init_fn)
613 """
614 script_module = RecursiveScriptModule(cpp_module)
--> 615 init_fn(script_module)
616
617 # Finalize the ScriptModule: replace the nn.Module state with our
/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py in init_fn(script_module)
514 else:
515 # always reuse the provided stubs_fn to infer the methods to compile
--> 516 scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
517
518 cpp_module.setattr(name, scripted)
/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
540 # Compile methods if necessary
541 if concrete_type not in concrete_type_store.methods_compiled:
--> 542 create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
543 # Create hooks after methods to ensure no name collisions between hooks and methods.
544 # If done before, hooks can overshadow methods that aren't exported.
/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
391 property_rcbs = [p.resolution_callback for p in property_stubs]
392
--> 393 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
394
395 def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs):
RuntimeError:
get_timestep_embedding(Tensor timesteps, int embedding_dim, bool flip_sin_to_cos=False, float downscale_freq_shift=1., float scale=1., int max_period=10000) -> Tensor:
Expected a value of type 'float' for argument 'downscale_freq_shift' but instead found type 'int'.
:
File "/usr/local/lib/python3.8/dist-packages/diffusers/models/embeddings.py", line 99
def forward(self, timesteps):
t_emb = get_timestep_embedding(
~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
timesteps,
self.num_channels,
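The first error comes from an int-valued `downscale_freq_shift` reaching the float parameter of `get_timestep_embedding`. One possible workaround, offered as an assumption rather than a verified fix, is to walk the model before scripting and cast named int attributes to float. `cast_int_attrs_to_float` and the toy modules below are hypothetical. It is also unverified whether the diffusers embedding module stores the value under exactly that attribute name. Even if it does, this would clear only this first error, not the other incompatibilities mentioned later in the thread.

```python
from typing import Iterable

import torch
from torch import nn


def cast_int_attrs_to_float(model: nn.Module, names: Iterable[str]) -> int:
    """Hypothetical pre-scripting patch: convert the named int attributes of
    every submodule to float. Returns how many attributes were changed."""
    names = list(names)
    changed = 0
    for module in model.modules():
        for name in names:
            value = getattr(module, name, None)
            # bool is a subclass of int in Python; leave flags untouched
            if isinstance(value, int) and not isinstance(value, bool):
                setattr(module, name, float(value))
                changed += 1
    return changed


def shifted(x: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
    # Toy stand-in for a function with a float-annotated parameter
    return x - shift


class ToyTimesteps(nn.Module):
    def __init__(self, downscale_freq_shift):
        super().__init__()
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return shifted(x, self.downscale_freq_shift)


class ToyUNet(nn.Module):
    def __init__(self):
        super().__init__()
        # An int here would make torch.jit.script fail, as in the traceback
        self.time_proj = ToyTimesteps(downscale_freq_shift=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.time_proj(x)


model = ToyUNet()
n_changed = cast_int_attrs_to_float(model, ["downscale_freq_shift"])
scripted = torch.jit.script(model)
out = scripted(torch.ones(2))
print(n_changed, out)
```

On the real model, the untested equivalent would be calling `cast_int_attrs_to_float(model, ["downscale_freq_shift"])` before `torch.jit.script(model)`.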
Thanks a lot for the repro @ChenchaoZhao :-) I'll try to have a look at this soon. Also cc @patil-suraj as it's related to torch.compile
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@ChenchaoZhao Have you figured out how to export the UNet with torch.jit.script? I need this to try https://github.com/alibaba/BladeDISC, which only works on scripted models.
Nope, I couldn't TorchScript it; there are too many things to change. Maybe torch.compile would work.
On Fri, May 26, 2023 at 12:41, bonlime wrote:
> @ChenchaoZhao Have you figured out how to export the unet to jit.script? I need this to try the https://github.com/alibaba/BladeDISC which could only work on scripted models.