[T5] Adding `model_parallel = False` to `T5ForTokenClassification` and `MT5ForTokenClassification`
What does this PR do?
Added `self.model_parallel = False` to `T5ForTokenClassification` and `MT5ForTokenClassification` (similar to #24684).
This resolves an `AttributeError` that `Trainer` raises because `model_parallel` is never set on these models (similar to #24682).
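The failure follows from the condition `Trainer.__init__` evaluates (line 425 of the traceback): because `and` short-circuits, `model.model_parallel` is only read when `model.is_parallelizable` is truthy, and the attribute exists only if `__init__` assigns it. The minimal sketch below reproduces that logic with plain Python stand-in classes. The class and function names are hypothetical and no torch or transformers import is needed; it illustrates the mechanism and is not the actual patch.

```python
# Hypothetical stand-ins: plain Python classes, NOT the real transformers models.


class PreTrainedStandIn:
    # Class attribute that makes Trainer go on to read `model_parallel`.
    is_parallelizable = True


class TokenClassifierBefore(PreTrainedStandIn):
    def __init__(self):
        # `model_parallel` is never assigned, mirroring the bug.
        pass


class TokenClassifierAfter(PreTrainedStandIn):
    def __init__(self):
        # The change this PR makes: initialize the flag explicitly.
        self.model_parallel = False


def trainer_is_model_parallel(model):
    # Same boolean condition as the traceback's Trainer.__init__ line 425.
    return bool(
        hasattr(model, "is_parallelizable")
        and model.is_parallelizable
        and model.model_parallel
    )


for cls in (TokenClassifierBefore, TokenClassifierAfter):
    try:
        print(cls.__name__, trainer_is_model_parallel(cls()))
    except AttributeError as err:
        print(cls.__name__, "raised AttributeError:", err)
```

The "before" class fails on attribute lookup exactly as in the report, while the "after" class makes the check evaluate to `False`, so `Trainer` treats the model as not model-parallel.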
To reproduce:
```python
from transformers import AutoModelForTokenClassification, Trainer
model = AutoModelForTokenClassification.from_pretrained("google-t5/t5-small")
trainer = Trainer(model=model)
```
This produces the following error:
```
AttributeError Traceback (most recent call last)
Cell In [1], line 3
1 from transformers import AutoModelForTokenClassification, Trainer
2 model = AutoModelForTokenClassification.from_pretrained("google-t5/t5-small")
----> 3 trainer = Trainer(model=model)
File ~/.pyenv/versions/3.9.14/lib/python3.9/site-packages/transformers/trainer.py:425, in Trainer.__init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
417 if model.__class__.__name__ in MODEL_MAPPING_NAMES:
418 raise ValueError(
419 f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
420 "computes hidden states and does not accept any labels. You should choose a model with a head "
421 "suitable for your task like any of the `AutoModelForXxx` listed at "
422 "https://huggingface.co/docs/transformers/model_doc/auto"
423 )
--> 425 if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
426 self.is_model_parallel = True
427 else:
File ~/.pyenv/versions/3.9.14/lib/python3.9/site-packages/torch/nn/modules/module.py:1269, in Module.__getattr__(self, name)
1267 if name in modules:
1268 return modules[name]
-> 1269 raise AttributeError("'{}' object has no attribute '{}'".format(
1270 type(self).__name__, name))
AttributeError: 'T5ForTokenClassification' object has no attribute 'model_parallel'
```
A similar error also occurs with `MT5ForTokenClassification`.
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
@ArthurZucker and @younesbelkada