The `transition_parser` in spaCy is not compatible with CUDA inference
I am running a spaCy-based pipeline using the en_core_web_trf:3.7.3 model, and the `transition_parser` appears to place tensors on the CPU instead of the GPU, as can be seen in the logs below:
```
2024-04-26 10:31:25,319 [mlserver.parallel] ERROR - An error occurred calling method 'predict' from model 'exemplar-relation-extraction-service'.
Traceback (most recent call last):
  File "/home/adarga/app/server.py", line 191, in predict
    for sentence_spacy, request in zip(sentences_spacy, requests, strict=False):
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy/language.py", line 1618, in pipe
    for doc in docs:
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy/util.py", line 1703, in _pipe
    yield from proc.pipe(docs, **kwargs)
  File "spacy/pipeline/transition_parser.pyx", line 245, in pipe
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy/util.py", line 1650, in minibatch
    batch = list(itertools.islice(items, int(batch_size)))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy/util.py", line 1703, in _pipe
    yield from proc.pipe(docs, **kwargs)
  File "spacy/pipeline/pipe.pyx", line 55, in pipe
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy/util.py", line 1703, in _pipe
    yield from proc.pipe(docs, **kwargs)
  File "spacy/pipeline/pipe.pyx", line 55, in pipe
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy/util.py", line 1703, in _pipe
    yield from proc.pipe(docs, **kwargs)
  File "spacy/pipeline/transition_parser.pyx", line 245, in pipe
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy/util.py", line 1650, in minibatch
    batch = list(itertools.islice(items, int(batch_size)))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy/util.py", line 1703, in _pipe
    yield from proc.pipe(docs, **kwargs)
  File "spacy/pipeline/trainable_pipe.pyx", line 73, in pipe
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy/util.py", line 1650, in minibatch
    batch = list(itertools.islice(items, int(batch_size)))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy/util.py", line 1703, in _pipe
    yield from proc.pipe(docs, **kwargs)
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy_curated_transformers/pipeline/transformer.py", line 210, in pipe
    preds = self.predict(batch)
            ^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy_curated_transformers/pipeline/transformer.py", line 242, in predict
    return self.model.predict(docs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/thinc/model.py", line 334, in predict
    return self._func(self, X, is_train=False)[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy_curated_transformers/models/architectures.py", line 651, in transformer_model_forward
    Y, backprop_layer = model.layers[0](docs, is_train=is_train)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/thinc/model.py", line 310, in __call__
    return self._func(self, X, is_train=is_train)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy_curated_transformers/models/with_non_ws_tokens.py", line 72, in with_non_ws_tokens_forward
    Y_no_ws, backprop_no_ws = inner(tokens, is_train)
                              ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/thinc/model.py", line 310, in __call__
    return self._func(self, X, is_train=is_train)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/thinc/layers/chain.py", line 54, in forward
    Y, inc_layer_grad = layer(X, is_train=is_train)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/thinc/model.py", line 310, in __call__
    return self._func(self, X, is_train=is_train)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/spacy_curated_transformers/models/with_strided_spans.py", line 108, in with_strided_spans_forward
    output, bp = transformer(cast(TorchTransformerInT, batch), is_train=is_train)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/thinc/model.py", line 310, in __call__
    return self._func(self, X, is_train=is_train)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/thinc/layers/pytorchwrapper.py", line 225, in forward
    Ytorch, torch_backprop = model.shims[0](Xtorch, is_train)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/thinc/shims/pytorch.py", line 97, in __call__
    return self.predict(inputs), lambda a: ...
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/thinc/shims/pytorch.py", line 115, in predict
    outputs = self._model(*inputs.args, **inputs.kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/curated_transformers/models/curated_transformer.py", line 37, in forward
    return self.curated_encoder.forward(input_ids, attention_mask, token_type_ids)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/curated_transformers/models/roberta/encoder.py", line 46, in forward
    embeddings = self.embeddings(input_ids, token_type_ids, None)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/curated_transformers/models/roberta/embeddings.py", line 42, in forward
    return self.inner(input_ids, token_type_ids, position_ids)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/curated_transformers/models/bert/embeddings.py", line 61, in forward
    input_embeddings = self.word_embeddings(input_ids)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/torch/nn/functional.py", line 2206, in embedding
    return handle_torch_function(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/torch/overrides.py", line 1604, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/torch/utils/_device.py", line 77, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/torch/nn/functional.py", line 2237, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/pysetup/.venv/lib/python3.11/site-packages/mlserver/parallel/worker.py", line 136, in _process_request
    return_value = await method(
                   ^^^^^^^^^^^^^
  File "/home/adarga/app/server.py", line 219, in predict
    raise InferenceError(f"Error during relation extraction: {e}") from e
mlserver.errors.InferenceError: Error during relation extraction: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
```
I tried several fixes, such as calling torch.set_default_device("cuda:0") and torch.set_default_dtype, but neither resolved the issue.
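My (possibly wrong) understanding of why these calls don't help: a default-device setting only affects tensors created after it is set, while arrays that a component deliberately keeps on the CPU stay there, so the mix persists. A torch-free toy sketch of that failure mode — `FakeTensor`, `default_device`, and `index_select` below are illustrative stand-ins, not real torch APIs:

```python
# Toy illustration (no torch required): changing the default device only
# affects tensors created *afterwards*; anything pinned to the CPU earlier,
# or allocated on CPU on purpose, keeps its device, producing a mismatch.

default_device = "cpu"

class FakeTensor:
    def __init__(self, device=None):
        # New tensors land on the current default device unless pinned.
        self.device = device if device is not None else default_device

def index_select(weight, index):
    # Mirrors the device check that raises inside torch's index_select kernel.
    if weight.device != index.device:
        raise RuntimeError(
            f"Expected all tensors to be on the same device, "
            f"found {weight.device} and {index.device}"
        )
    return "ok"

default_device = "cuda:0"        # analogous to torch.set_default_device("cuda:0")
gpu_weight = FakeTensor()        # new tensors now default to the GPU
cpu_index = FakeTensor("cpu")    # an array deliberately allocated on CPU

try:
    index_select(gpu_weight, cpu_index)
except RuntimeError as err:
    print(err)                   # cuda:0 vs cpu, like the traceback above
```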
How to reproduce the behaviour
This error is encountered using the model in an MLServer deployment. It is a bit difficult to provide reproduction code here.
Your Environment
- Operating System: nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
- Python Version Used: 3.11
- spaCy Version Used: 3.7.4[cupy-cuda12x]
- Environment Information: docker container running on an aws g4 instance
This is expected behaviour.
The transition parser predicts an action for each word of the sentence and then applies the corresponding state transition. Each prediction requires features from the current state, so the predictions cannot be made all at once across the sentence.
This sequence of small matrix multiplications is slow on GPU, so it's faster to do the whole-document feature extraction on GPU, and then copy the result over to the CPU to predict the transitions.
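A toy greedy shift-reduce loop (pure Python, not spaCy's actual implementation; the hard-coded `action` rule stands in for the learned scorer) makes the sequential dependency concrete: each step's input is the state left by the previous step, so per-step scoring cannot be batched the way a transformer forward pass can.

```python
# Toy greedy transition-based parse: returns (head, child) arcs over
# token indices. Each iteration reads the state the previous one produced.
def parse(words):
    stack, buffer, arcs = [], list(range(len(words))), []
    while buffer or len(stack) > 1:
        # In a real parser this choice is a small matrix multiplication
        # over features of the *current* stack/buffer; here it's a rule.
        if len(stack) < 2:
            action = "SHIFT"
        else:
            action = "RIGHT-ARC"
        if action == "SHIFT":
            stack.append(buffer.pop(0))
        else:  # RIGHT-ARC: attach the top of the stack to the word below it
            child = stack.pop()
            arcs.append((stack[-1], child))
    return arcs

print(parse(["I", "saw", "her"]))  # → [(0, 1), (0, 2)]
```

Because `stack` and `buffer` are mutated before the next action can even be chosen, the per-step work is many tiny operations rather than one large one, which is exactly the access pattern that CPUs handle better than GPUs.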
We've actually tried pretty extensively to get away from this, but the transition-based model is very good, and we can't match it with a more GPU-friendly approach. A key issue is that the transition-based approach is able to operate on unsegmented documents, so it can do joint sentence segmentation and parsing.
You can find the biaffine parser module in spacy-experimental, but we haven't yet released trained models for it.
Thanks for the explanation, that makes sense. The difficulty, though, is that running spaCy pipelines on GPU becomes awkward whenever this parser is part of the flow, which is exactly our use case. I will close this now.
This thread has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.