
TorchScript-compatible UNet

Open ChenchaoZhao opened this issue 3 years ago • 4 comments

Is your feature request related to a problem? Please describe. Currently the UNet can be exported with torch.jit.trace, but the traced module is device-dependent.

Describe the solution you'd like It would be much better if I could use torch.jit.script, e.g. export the model on a CPU device and then load and run it on a CUDA or MPS device.

Describe alternatives you've considered I could help implement a TorchScript-friendly UNet purely for deployment, without breaking the original eager API.

Additional context n/a
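For context, the portability difference can be sketched with a toy module (`TinyNet` is a hypothetical stand-in, not a diffusers class): torch.jit.script compiles the Python source itself, so the saved artifact can be loaded onto any device, whereas a trace records the ops as executed on one particular device.

```python
import torch

class TinyNet(torch.nn.Module):
    # Toy stand-in for the UNet; not a diffusers class.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1.0

# Scripting compiles the source, so the exported module is device-agnostic:
scripted = torch.jit.script(TinyNet())
scripted.save("tiny_scripted.pt")

# Later (possibly in another process), load it and move it to any device:
loaded = torch.jit.load("tiny_scripted.pt")
out = loaded(torch.randn(2, 3))  # .to("cuda") / .to("mps") would also work
```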

ChenchaoZhao avatar Jan 17 '23 03:01 ChenchaoZhao

Thanks a lot for the issue @ChenchaoZhao,

Could you maybe add a reproducible code snippet that shows how one should use torch.jit.script and why it doesn't work, i.e. what error you get?

patrickvonplaten avatar Jan 20 '23 04:01 patrickvonplaten

> Thanks a lot for the issue @ChenchaoZhao,
>
> Could you maybe add a reproducible code snippet that shows how one should use torch.jit.script and why it doesn't work, i.e. what error you get?

TorchScript is a subset of Python with static typing. If there is a type mismatch, a compile error is raised.

```python
import diffusers
import torch

print(diffusers.__version__, torch.__version__)
# 0.11.1 1.13.1+cu116
```

UNet from the tutorial:


```python
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

ts_model = torch.jit.script(model)  # raises RuntimeError, traceback below
```
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-28e20fd90ee4> in <module>
----> 1 torch.jit.script(model)

6 frames
/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1284     if isinstance(obj, torch.nn.Module):
   1285         obj = call_prepare_scriptable_func(obj)
-> 1286         return torch.jit._recursive.create_script_module(
   1287             obj, torch.jit._recursive.infer_methods_to_compile
   1288         )

/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
    474     if not is_tracing:
    475         AttributeTypeIsSupportedChecker().check(nn_module)
--> 476     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    477 
    478 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    536 
    537     # Actually create the ScriptModule, initializing it with the function we just defined
--> 538     script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
    539 
    540     # Compile methods if necessary

/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py in _construct(cpp_module, init_fn)
    613             """
    614             script_module = RecursiveScriptModule(cpp_module)
--> 615             init_fn(script_module)
    616 
    617             # Finalize the ScriptModule: replace the nn.Module state with our

/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py in init_fn(script_module)
    514             else:
    515                 # always reuse the provided stubs_fn to infer the methods to compile
--> 516                 scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
    517 
    518             cpp_module.setattr(name, scripted)

/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    540     # Compile methods if necessary
    541     if concrete_type not in concrete_type_store.methods_compiled:
--> 542         create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    543         # Create hooks after methods to ensure no name collisions between hooks and methods.
    544         # If done before, hooks can overshadow methods that aren't exported.

/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    391     property_rcbs = [p.resolution_callback for p in property_stubs]
    392 
--> 393     concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    394 
    395 def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs):

RuntimeError: 

get_timestep_embedding(Tensor timesteps, int embedding_dim, bool flip_sin_to_cos=False, float downscale_freq_shift=1., float scale=1., int max_period=10000) -> Tensor:
Expected a value of type 'float' for argument 'downscale_freq_shift' but instead found type 'int'.
:
  File "/usr/local/lib/python3.8/dist-packages/diffusers/models/embeddings.py", line 99
    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
                ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            timesteps,
            self.num_channels,
```
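The root cause is that the module stores the frequency-shift value as a Python int while the helper's annotation declares float, and TorchScript does not reconcile the two at the call site. A minimal, library-free sketch of the pattern and the explicit-cast workaround (`scale_by` and `TinyEmbed` are hypothetical names, not diffusers code):

```python
import torch

def scale_by(x: torch.Tensor, factor: float) -> torch.Tensor:
    # The annotation fixes the TorchScript schema of this function to float.
    return x * factor

class TinyEmbed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.shift = 1  # stored as an int, like diffusers' freq_shift default

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TorchScript infers self.shift as int; without the float(...) cast,
        # scripting can fail with the same "Expected a value of type 'float'
        # ... found type 'int'" error seen above.
        return scale_by(x, float(self.shift))

scripted = torch.jit.script(TinyEmbed())
out = scripted(torch.ones(2))
```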

ChenchaoZhao avatar Jan 21 '23 23:01 ChenchaoZhao

Thanks a lot for the repro @ChenchaoZhao :-) I'll try to have a look at this soon. Also cc @patil-suraj as it's related to torch.compile

patrickvonplaten avatar Jan 23 '23 08:01 patrickvonplaten

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Feb 16 '23 15:02 github-actions[bot]

@ChenchaoZhao Have you figured out how to export the UNet with jit.script? I need this to try https://github.com/alibaba/BladeDISC, which only works on scripted models.

bonlime avatar May 26 '23 16:05 bonlime

Nope, I couldn't TorchScript it; there are too many things to change. Maybe torch.compile might work.

ChenchaoZhao avatar Jul 14 '23 14:07 ChenchaoZhao