pipeline 'text-classification' in >=4.40.0 throwing TypeError: Got unsupported ScalarType BFloat16
System Info
- transformers version: 4.40.1
- Platform: Linux-5.14.0-362.24.1.el9_3.x86_64-x86_64-with-glibc2.34
- Python version: 3.10.13
- Huggingface_hub version: 0.21.4
- Safetensors version: 0.4.2
- Accelerate version: 0.25.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.2+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes, NVIDIA A10G
- Using distributed or parallel set-up in script?: No
Who can help?
No response
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
Does not occur in 4.39.3 - happens in >=4.40.0 and main. Appears to be related to PR #30518.
Test code is below (please ignore that the model is not a pre-trained sequence classification model):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

model_id = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Loading the weights in bfloat16 means the pipeline receives bf16 logits.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    label2id={"LABEL_0": 0, "LABEL_1": 1},
    num_labels=2,
)

classification_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
predictions = classification_pipeline("test")  # raises TypeError on >=4.40.0
```
Traceback:

```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[5], line 1
----> 1 predictions = pipeline("test")
File ~/miniconda3/envs/mlenv/lib/python3.10/site-packages/transformers/pipelines/text_classification.py:156, in TextClassificationPipeline.__call__(self, inputs, **kwargs)
122 """
123 Classify the text(s) given as inputs.
124
(...)
153 If `top_k` is used, one such dictionary is returned per label.
154 """
155 inputs = (inputs,)
--> 156 result = super().__call__(*inputs, **kwargs)
157 # TODO try and retrieve it in a nicer way from _sanitize_parameters.
158 _legacy = "top_k" not in kwargs
File ~/miniconda3/envs/mlenv/lib/python3.10/site-packages/transformers/pipelines/base.py:1242, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
1234 return next(
1235 iter(
1236 self.get_iterator(
(...)
1239 )
1240 )
1241 else:
-> 1242 return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
File ~/miniconda3/envs/mlenv/lib/python3.10/site-packages/transformers/pipelines/base.py:1250, in Pipeline.run_single(self, inputs, preprocess_params, forward_params, postprocess_params)
1248 model_inputs = self.preprocess(inputs, **preprocess_params)
1249 model_outputs = self.forward(model_inputs, **forward_params)
-> 1250 outputs = self.postprocess(model_outputs, **postprocess_params)
1251 return outputs
File ~/miniconda3/envs/mlenv/lib/python3.10/site-packages/transformers/pipelines/text_classification.py:205, in TextClassificationPipeline.postprocess(self, model_outputs, function_to_apply, top_k, _legacy)
202 function_to_apply = ClassificationFunction.NONE
204 outputs = model_outputs["logits"][0]
--> 205 outputs = outputs.numpy()
207 if function_to_apply == ClassificationFunction.SIGMOID:
208 scores = sigmoid(outputs)
TypeError: Got unsupported ScalarType BFloat16
```
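The last frame shows where this fails: `postprocess` calls `.numpy()` on the logits, and NumPy has no bfloat16 dtype, so PyTorch raises instead of converting. The sketch below reproduces that with plain PyTorch (no model download needed) and shows the straightforward workaround of upcasting to float32 before converting. `logits_to_numpy` is a hypothetical helper for illustration only; it is not a transformers function and not necessarily how the eventual fix is written.

```python
import torch


def logits_to_numpy(logits: torch.Tensor):
    """Hypothetical helper: convert logits to a NumPy array.

    NumPy has no bfloat16 dtype, so bf16 tensors are upcast to float32 first;
    other dtypes are converted unchanged.
    """
    logits = logits.detach().cpu()
    if logits.dtype == torch.bfloat16:
        logits = logits.float()
    return logits.numpy()


# bf16 logits, like those returned by a model loaded with torch_dtype=torch.bfloat16
logits = torch.tensor([[0.25, -1.5]], dtype=torch.bfloat16)

try:
    logits[0].numpy()  # the same call TextClassificationPipeline.postprocess makes
except TypeError as err:
    print("direct .numpy() failed:", err)

print(logits_to_numpy(logits[0]))  # float32 array
```

Upcasting bf16 to float32 is lossless, so the resulting scores are the same values the model produced.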
Expected behavior
`predictions = classification_pipeline("test")` should return predictions, as it does in 4.39.3.
cc @ArthurZucker
same question!
This is related to #28109. But it's a bit weird that we go to `numpy()`. Do you want to open a PR for this?
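On the point about `numpy()`: one alternative is to compute the scores in PyTorch and convert to plain Python floats only at the end, which sidesteps NumPy's dtype support entirely. A minimal sketch, assuming the same sigmoid/softmax/none choices as `ClassificationFunction` in the traceback; `scores_from_logits` is a hypothetical function, not transformers' API and not the code merged in #30999.

```python
import torch


def scores_from_logits(logits: torch.Tensor, function_to_apply: str = "softmax") -> list[float]:
    """Hypothetical torch-only postprocess step that never calls .numpy().

    Logits are moved to CPU and cast to float32 so the arithmetic below runs in
    float32 regardless of the model dtype (bf16, fp16, fp32).
    """
    logits = logits.detach().to(device="cpu", dtype=torch.float32)
    if function_to_apply == "sigmoid":
        scores = torch.sigmoid(logits)
    elif function_to_apply == "softmax":
        scores = torch.softmax(logits, dim=-1)
    else:  # "none": return the raw logits
        scores = logits
    return scores.tolist()
```

For example, `scores_from_logits(torch.tensor([0.0, 0.0], dtype=torch.bfloat16))` returns `[0.5, 0.5]`.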
Any update on this? I'm getting the same error when using the pipeline class for question answering.
A fix was merged and will be included in the next release.
@ArthurZucker which PR fixes this issue?
same problem
#30999 is the fix!