TorchSharp

Support for non-Tensor types for `ScriptModule`

Open · kaiidams opened this issue · 2 comments

#771 added support for some of these types, but not all. This is a minor issue.

Hugging Face Transformers has a TorchScript option (`torchscript=True`). The model it returns accepts only positional arguments and returns a tuple, instead of accepting keyword arguments and returning a dict.

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", torchscript=True)

For reference, the PyTorch Mobile Java bindings expose a wrapper around IValue (https://pytorch.org/javadoc/1.9.0/).
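To illustrate what such a wrapper does, here is a minimal Python sketch of a tagged union: every value carries a tag naming its type, and unwrapping fails if the caller asks for the wrong tag. The class `IValueSketch` and the helper constructors are hypothetical and do not mirror the actual Java or TorchSharp API; the tag names are taken from the type list in the test script, and tensors would simply be stored as an opaque payload.

```python
from dataclasses import dataclass
from typing import Any, Dict, List

# Hypothetical sketch of an IValue-style wrapper; NOT the real PyTorch or
# TorchSharp API. Tag names follow the type list in the test script.
TAGS = frozenset({
    "bool", "bool_list", "dict_long_key", "dict_string_key", "double",
    "double_list", "list", "long", "long_list", "null", "string",
    "tensor", "tensor_list", "tuple",
})


@dataclass(frozen=True)
class IValueSketch:
    tag: str      # which member of the union this value holds
    payload: Any  # the raw value; a tensor would be stored opaquely

    def __post_init__(self) -> None:
        if self.tag not in TAGS:
            raise ValueError(f"unknown tag: {self.tag}")

    def has_tag(self, tag: str) -> bool:
        return self.tag == tag

    def to(self, tag: str) -> Any:
        # Unwrap only when the caller asks for the tag the value was built with.
        if self.tag != tag:
            raise TypeError(f"expected {tag}, got {self.tag}")
        return self.payload


# A few typed constructors. Containers hold IValueSketch elements, so
# nested values (e.g. a dict of lists) are wrapped at every level.
def from_long(v: int) -> IValueSketch:
    return IValueSketch("long", v)


def from_string(v: str) -> IValueSketch:
    return IValueSketch("string", v)


def from_list(items: List[IValueSketch]) -> IValueSketch:
    return IValueSketch("list", list(items))


def from_dict_string_key(d: Dict[str, IValueSketch]) -> IValueSketch:
    return IValueSketch("dict_string_key", dict(d))


def null() -> IValueSketch:
    return IValueSketch("null", None)
```

Keeping the tag explicit, rather than re-deriving it from the host-language value, means that an empty list, for instance, can still be identified as a `long_list` or a `tensor_list`, which the bare Python value `[]` cannot express.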

The Python script below generates a TorchScript file that exercises more of these types, for testing.

import torch
from torch import nn

# Used types:
#   bool
#   bool_list
#   dict_long_key
#   dict_string_key
#   double
#   double_list
#   list
#   long
#   long_list
#   null
#   string
#   tensor
#   tensor_list
#   tuple


class MyModule(nn.Module):
    @torch.jit.export
    def forward(
        self,
        x: torch.Tensor,
        d: float,
        n: int,
        s: str,
    ):
        return (
            True,
            {
                1: [False, True],
                2: 1.2,
                n: [2.3, d],
            },
            {
                "abc": [123, torch.arange(n), None],
                s: [456, 789]
            },
            [x, torch.arange(n)]
        )


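# Compile with TorchScript and print the generated code for forward()
module = torch.jit.script(MyModule())
print(module.forward.code)
# Save to disk, reload, and print the code of the loaded module
module.save("ivalue_test.pt")
module = torch.jit.load("ivalue_test.pt")
print(module.code)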
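To check that the value returned by `forward` really covers every entry in the `# Used types` comment, one option is to walk it recursively and collect the tag each node corresponds to. `collect_tags` below is a hypothetical helper, not part of TorchScript or TorchSharp. The rule that a list counts as `bool_list`, `long_list`, `double_list` or `tensor_list` only when all of its elements share that scalar type (and is a generic `list` otherwise) is an assumption about how the categories are meant. `bool` is checked before `int` because `bool` is a subclass of `int` in Python.

```python
from typing import Any, Optional, Set

try:
    import torch
    _TENSOR_TYPES: tuple = (torch.Tensor,)
except ImportError:  # without PyTorch, nothing is classified as a tensor
    _TENSOR_TYPES = ()

# Assumed mapping from a shared element tag to the specialized list tag.
_HOMOGENEOUS_LIST = {"bool": "bool_list", "long": "long_list",
                     "double": "double_list", "tensor": "tensor_list"}


def _scalar_tag(v: Any) -> Optional[str]:
    """Tag for a non-container value, or None if v is a container/unknown."""
    if v is None:
        return "null"
    if isinstance(v, bool):  # must precede int: bool is a subclass of int
        return "bool"
    if isinstance(v, int):
        return "long"
    if isinstance(v, float):
        return "double"
    if isinstance(v, str):
        return "string"
    if isinstance(v, _TENSOR_TYPES):
        return "tensor"
    return None


def collect_tags(value: Any, tags: Optional[Set[str]] = None) -> Set[str]:
    """Recursively record the tag of `value` and of everything nested in it."""
    if tags is None:
        tags = set()
    scalar = _scalar_tag(value)
    if scalar is not None:
        tags.add(scalar)
        return tags
    if isinstance(value, tuple):
        tags.add("tuple")
        children = list(value)
    elif isinstance(value, list):
        elem_tags = {_scalar_tag(e) for e in value}
        only = next(iter(elem_tags)) if len(elem_tags) == 1 else None
        tags.add(_HOMOGENEOUS_LIST.get(only, "list"))
        children = value
    elif isinstance(value, dict):
        key_types = {type(k) for k in value}
        if key_types == {int}:
            tags.add("dict_long_key")
        elif key_types == {str}:
            tags.add("dict_string_key")
        else:
            raise TypeError(f"unsupported dict key types: {key_types}")
        children = list(value.values())
    else:
        raise TypeError(f"unsupported type: {type(value).__name__}")
    for child in children:
        collect_tags(child, tags)
    return tags
```

With the scripted module from above, a call such as `collect_tags(module(torch.ones(2), 0.5, 3, "xyz"))` returns the set of tags actually produced, which can be compared against the list in the script.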

kaiidams, Oct 08 '22 16:10

Let's track progress:

  • [x] bool
  • [x] bool_list
  • [ ] dict_long_key
  • [ ] dict_string_key
  • [x] double
  • [x] double_list
  • [x] list
  • [x] long
  • [x] long_list
  • [ ] null
  • [ ] string
  • [x] tensor
  • [x] tensor_list
  • [x] tuple (tensor)
  • [x] tuple (general)

NiklasGustafsson, Oct 21 '22 18:10