
[T5] Adding `model_parallel = False` to `T5ForTokenClassification` and `MT5ForTokenClassification`

Open retarfi opened this issue 1 year ago • 0 comments

What does this PR do?

This PR adds `self.model_parallel = False` to `T5ForTokenClassification` and `MT5ForTokenClassification` (similar to #24684). It fixes the `AttributeError` raised when `Trainer` reads `model_parallel` on a model that never sets it (similar to #24682).

To reproduce:

from transformers import AutoModelForTokenClassification, Trainer
model = AutoModelForTokenClassification.from_pretrained("google-t5/t5-small")
trainer = Trainer(model=model)

This produces the following error:

AttributeError                            Traceback (most recent call last)
Cell In [1], line 3
      1 from transformers import AutoModelForTokenClassification, Trainer
      2 model = AutoModelForTokenClassification.from_pretrained("google-t5/t5-small")
----> 3 trainer = Trainer(model=model)

File ~/.pyenv/versions/3.9.14/lib/python3.9/site-packages/transformers/trainer.py:425, in Trainer.__init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
    417 if model.__class__.__name__ in MODEL_MAPPING_NAMES:
    418     raise ValueError(
    419         f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
    420         "computes hidden states and does not accept any labels. You should choose a model with a head "
    421         "suitable for your task like any of the `AutoModelForXxx` listed at "
    422         "https://huggingface.co/docs/transformers/model_doc/auto"
    423     )
--> 425 if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
    426     self.is_model_parallel = True
    427 else:

File ~/.pyenv/versions/3.9.14/lib/python3.9/site-packages/torch/nn/modules/module.py:1269, in Module.__getattr__(self, name)
   1267     if name in modules:
   1268         return modules[name]
-> 1269 raise AttributeError("'{}' object has no attribute '{}'".format(
   1270     type(self).__name__, name))

AttributeError: 'T5ForTokenClassification' object has no attribute 'model_parallel'

The same error occurs with MT5ForTokenClassification.
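The failure comes from the condition at trainer.py:425 in the traceback above: once `is_parallelizable` is truthy, the `and` chain goes on to read `model.model_parallel`, which raises if the attribute was never set. Below is a minimal sketch of that mechanism. `BrokenModel`, `FixedModel`, and `trainer_is_model_parallel` are hypothetical stand-ins written for illustration, not the real transformers classes, and the sketch needs neither torch nor transformers.

```python
# Hypothetical stand-ins for illustration; NOT the real transformers classes.


class BrokenModel:
    """Advertises parallelism support but never sets `model_parallel`."""

    # The traceback implies this flag is truthy, since evaluation
    # reached `model.model_parallel`.
    is_parallelizable = True


class FixedModel(BrokenModel):
    """Same stand-in with the one-line change this PR describes."""

    def __init__(self):
        self.model_parallel = False  # attribute read by the Trainer check


def trainer_is_model_parallel(model):
    """Mirrors the condition shown at trainer.py:425 in the traceback."""
    if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
        return True
    return False


# Without the attribute, the third operand of the `and` chain raises.
try:
    trainer_is_model_parallel(BrokenModel())
    broken_raised = False
except AttributeError:
    broken_raised = True

# With the attribute set to False, the check falls through to the else branch.
fixed_result = trainer_is_model_parallel(FixedModel())
```

Only `is_parallelizable` is guarded by `hasattr`, so every parallelizable model class has to define `model_parallel` itself. That is why the change is made in each class's `__init__`.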

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

@ArthurZucker and @younesbelkada

retarfi, May 11 '24 19:05