functorch
Support static_argnums for kwargs
I passed static_argnums as follows:
(1,)
and got this error:
RuntimeError: Found an argument of type int at index 1. Non-tensor arguments must be marked static. Please set the static_argnums correctly to mark the argument at index 1 static.
Code to reproduce:
import torch
from functorch.compile.aot_autograd import aot_function
from functorch.compile.compilers import ts_compile, default_decompositions
from transformers.models.bert.modeling_bert import BertEmbeddings, BertConfig
emb = BertEmbeddings(BertConfig.from_pretrained("bert-base-cased"))
config = {
    "fw_compiler": ts_compile,
    "bw_compiler": ts_compile,
    "hasher_type": "StaticShapeHasher",
    "decompositions": default_decompositions,
    "static_argnums": (1,),
}
traced = aot_function(emb.forward, **config)
inputs = {"input_ids": torch.tensor([[1, 2, 3, 4]]).long(), "past_key_values_length": 1}
output = traced(**inputs)
Thanks @hyunwoongko for filing the issue. We also came across this type of issue quite recently.
I have assigned it to myself.
@anijain2305 Thanks for answering!
Hello @anijain2305, has this issue been resolved? I'm running into the same problem. Thanks!