functorch

Support static_argnums for kwargs

Status: Open. hyunwoongko opened this issue 3 years ago (3 comments).

I passed `static_argnums` as follows:

(1,)

and got this error:

RuntimeError: Found an argument of type int at index 1. Non-tensor arguments must be marked static. Please set the static_argnums correctly to mark the argument at index 1 static.
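`static_argnums` refers to positions in the argument list, but the reproduction below passes both arguments by keyword. As a rough sketch (plain Python, not functorch code), `inspect.signature` shows which position a keyword argument would occupy if it were passed positionally. `embeddings_forward` is a hypothetical stand-in whose parameter order is assumed to follow `BertEmbeddings.forward`. Under that assumption, `past_key_values_length` would sit at position 4, not 1.

```python
import inspect

# Hypothetical stand-in: its parameter order is ASSUMED to follow
# BertEmbeddings.forward in transformers. Only the signature matters here.
def embeddings_forward(input_ids=None, token_type_ids=None, position_ids=None,
                       inputs_embeds=None, past_key_values_length=0):
    return input_ids

def positional_indices(fn, **kwargs):
    """Map each keyword argument of a call to its positional index in fn's signature."""
    names = list(inspect.signature(fn).parameters)
    return {name: names.index(name) for name in kwargs}

print(positional_indices(embeddings_forward, input_ids="ids", past_key_values_length=1))
# {'input_ids': 0, 'past_key_values_length': 4}
```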

Code to reproduce:

```python
import torch

from functorch.compile.aot_autograd import aot_function
from functorch.compile.compilers import ts_compile, default_decompositions
from transformers.models.bert.modeling_bert import BertEmbeddings, BertConfig

emb = BertEmbeddings(BertConfig.from_pretrained("bert-base-cased"))

config = {
    "fw_compiler": ts_compile,
    "bw_compiler": ts_compile,
    "hasher_type": "StaticShapeHasher",
    "decompositions": default_decompositions,
    # Intended to mark the int argument past_key_values_length as static
    "static_argnums": (1,),
}

traced = aot_function(emb.forward, **config)
# Both arguments are passed by keyword, not by position
inputs = {"input_ids": torch.tensor([[1, 2, 3, 4]]).long(), "past_key_values_length": 1}
output = traced(**inputs)  # raises the RuntimeError above
```
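What is being requested resembles the `static_argnames` option of `jax.jit`, which marks arguments static by name rather than by position. The sketch below does not use functorch at all; it only illustrates one way a wrapper could support this. It pops the named static keyword arguments, bakes them into the function with `functools.partial`, and keeps one compiled function per distinct combination of static values. `with_static_kwargs`, `fake_compile`, and `forward` are hypothetical names, and `fake_compile` merely stands in for something like a call to `aot_function`.

```python
import functools

def with_static_kwargs(compile_fn, fn, static_argnames):
    """Hypothetical wrapper: treat the named keyword arguments as static.

    Each distinct combination of static values gets its own compiled
    function, built by baking those values into fn with functools.partial.
    Static values must be hashable because they form the cache key.
    """
    cache = {}

    def wrapper(*args, **kwargs):
        # Pull the static keyword arguments out of this call's kwargs.
        static = tuple(
            (name, kwargs.pop(name)) for name in static_argnames if name in kwargs
        )
        if static not in cache:
            cache[static] = compile_fn(functools.partial(fn, **dict(static)))
        # Only the remaining (dynamic) arguments reach the compiled function.
        return cache[static](*args, **kwargs)

    return wrapper


# Usage with a stand-in "compiler" that just records how often it is called.
compiles = []

def fake_compile(f):
    compiles.append(f)
    return f

def forward(input_ids, past_key_values_length=0):
    return [i + past_key_values_length for i in input_ids]

fast = with_static_kwargs(fake_compile, forward, ("past_key_values_length",))
print(fast(input_ids=[1, 2, 3], past_key_values_length=1))  # [2, 3, 4]
print(fast(input_ids=[5], past_key_values_length=1))        # [6]
print(len(compiles))                                        # 1
```

The same idea suggests a workaround that might be worth trying, though it is untested here: pass `functools.partial(emb.forward, past_key_values_length=1)` to `aot_function`, so that only tensor arguments reach it. Static values passed positionally are not handled by this sketch.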

hyunwoongko · Feb 9, 2022 22:02

Thanks @hyunwoongko for filing the issue. We also came across this type of issue quite recently.

I have assigned it to myself.

anijain2305 · Feb 9, 2022 23:02

@anijain2305 Thanks for answering!

hyunwoongko · Feb 10, 2022 01:02

Hello @anijain2305, has this issue been resolved? I'm running into the same problem. Thanks!

yuanandonly · Jul 5, 2022 17:07