
Make LogitsProcessor compatible with torch.compile

Open zucchini-nlp opened this issue 1 year ago • 4 comments

What does this PR do?

Small part of issue #28981. This PR makes sure that LogitsProcessor and StoppingCriteria are compatible with torch.compile when fullgraph=True. The changes were tested with dummy inputs and logits, and also with Llama. For now, only the processors used in generate were checked; those used in the Bark/Whisper models can be checked later if needed.
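To illustrate the kind of change involved, here is a toy, list-based sketch (not the real transformers code; names and signatures are illustrative) of a compile-friendly min-length check. The key point is that the length check compares plain Python ints instead of reading values out of a tensor, so torch.compile can specialize on it without a graph break:

```python
import math

NEG_INF = -math.inf


def min_length_process(scores, cur_len, min_length, eos_token_id):
    """Toy compile-friendly processor: mask the EOS logit until min_length.

    `cur_len` is passed in explicitly (as this PR does for generation
    utils), so the branch below is a static Python comparison rather than
    a data-dependent read from tensor contents.
    """
    if cur_len < min_length:
        scores = scores.copy()  # do not mutate the caller's logits
        scores[eos_token_id] = NEG_INF
    return scores
```

The same idea carries over to tensor code: conditions should depend on shapes and static ints, not on tensor values.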

The processors below are not compatible; exceptions will be added later:

  • EncoderNoRepeatNGramLogitsProcessor and NoRepeatNGramLogitsProcessor -> try to get a value from a dict, which is input-dependent
  • PrefixConstrainedLogitsProcessor -> relies on user-provided functions, which most likely are also input-dependent
  • SequenceBiasLogitsProcessor will not work at the same time as NoBadWordsLogitsProcessor; only one may be defined -> both call the same _prepare_bias_variables, which leads to recompilation the second time it is called with new arguments. Can be fixed by either merging them into one processor or separating them into two distinct ones.
  • UnbatchedClassifierFreeGuidanceLogitsProcessor -> calls the model forward; the current Llama with SDPA failed when a non-None attention_mask was provided.
  • MaxTimeCriteria -> uses the built-in time.time()
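For context on the first bullet, this pure-Python sketch mirrors the dict-based pattern behind NoRepeatNGramLogitsProcessor (it is not the real code). The dict keys are built from the *values* of the generated tokens, so under torch.compile(fullgraph=True) the lookup is input-dependent and forces a graph break:

```python
def no_repeat_ngram_banned(prev_tokens, n=2):
    """Tokens that would complete an already-seen n-gram.

    The dict below is keyed by tuples of generated token values, so the
    control flow depends on runtime data -- this is the pattern that is
    incompatible with fullgraph compilation.
    """
    seen = {}
    for i in range(len(prev_tokens) - n + 1):
        prefix = tuple(prev_tokens[i : i + n - 1])
        seen.setdefault(prefix, set()).add(prev_tokens[i + n - 1])
    # The current prefix is also derived from token values, not shapes.
    cur_prefix = tuple(prev_tokens[len(prev_tokens) - n + 1 :])
    return sorted(seen.get(cur_prefix, set()))
```

A compile-friendly variant would have to express the same logic with fixed-shape tensor operations instead of a Python dict, which is why these processors are deferred here.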

FYI @gante

zucchini-nlp avatar Feb 14 '24 13:02 zucchini-nlp

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante I went through the comments and fixed where possible. I am wondering if it is a good idea to add warnings as I did? Maybe there is a better way to do it, so that users do not see lots of unrelated warnings. I guess not everyone will use compile to generate.
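One way to keep the warnings from spamming users would be to emit each message only once per process. This is a hypothetical helper (not part of this PR or the transformers API), using functools.lru_cache as a simple dedup guard:

```python
import warnings
from functools import lru_cache


@lru_cache(maxsize=None)
def _warn_once(message):
    # lru_cache makes repeat calls with the same message no-ops, so each
    # distinct warning is shown at most once per process.
    warnings.warn(message)


def check_compile_compat(processor_name):
    """Hypothetical helper: warn once that a processor breaks fullgraph."""
    _warn_once(
        f"{processor_name} is not compatible with "
        "torch.compile(fullgraph=True)."
    )
```

Users who never compile would see at most one line per incompatible processor instead of one per generation step.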

zucchini-nlp avatar Feb 19 '24 10:02 zucchini-nlp

@gante Ready to review. I fixed the tests and the generation utils to work with "cur_len"; everything runs successfully on my machine.

zucchini-nlp avatar Feb 20 '24 17:02 zucchini-nlp

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Mar 22 '24 08:03 github-actions[bot]