doctr
[Bug] SARNet half-precision error.
Bug description
When I use sar_resnet31, I get this error: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half.
Other models (I tried master and crnn_mobilenet_v3_large) work fine.
My guess is that somewhere in the sar_resnet31 architecture a tensor is created or kept in float32 instead of being converted to half precision.
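To illustrate the kind of mismatch suspected here, the following is a minimal sketch, not the actual doctr code. It assumes the decoder builds its initial LSTM inputs with torch.zeros without passing a dtype; the helper name init_states and all sizes are hypothetical. So that it runs on CPU without CUDA, float64 weights stand in for the float16 weights of a .half() model.

```python
import torch
from torch import nn

# Stand-in for the half-precision setup: float64 weights on CPU play the role
# of float16 weights on GPU, so the sketch runs without CUDA.
cell = nn.LSTMCell(input_size=8, hidden_size=16).double()
features = torch.randn(4, 16, dtype=torch.float64)


def init_states(features, input_size, match_dtype):
    """Hypothetical helper: build the first input symbol and the LSTM states."""
    kwargs = {"device": features.device}
    if match_dtype:
        # Proposed fix: inherit the dtype of the incoming features.
        kwargs["dtype"] = features.dtype
    batch, hidden_size = features.size(0), features.size(1)
    prev_symbol = torch.zeros(batch, input_size, **kwargs)
    hidden = torch.zeros(batch, hidden_size, **kwargs)
    cell_state = torch.zeros(batch, hidden_size, **kwargs)
    return prev_symbol, (hidden, cell_state)


# Without a dtype, torch.zeros uses the default (float32), which clashes
# with the float64 weights of the cell.
try:
    x, state = init_states(features, 8, match_dtype=False)
    cell(x, state)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Matching the feature dtype lets the cell run.
x, state = init_states(features, 8, match_dtype=True)
hidden, cell_state = cell(x, state)
```

If the real decoder follows this pattern, passing dtype=features.dtype wherever such tensors are created would be one possible fix.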
Code snippet to reproduce the bug
from doctr.models import ocr_predictor
from doctr.io import DocumentFile
input = DocumentFile.from_images("./four.jpeg")
model = ocr_predictor('db_resnet50', 'sar_resnet31', pretrained=True).cuda().half()
result = model(input)
print(result)
The input image:
Error traceback
Downloading https://doctr-static.mindee.com/models?id=v0.7.0/sar_resnet31-9a1deedf.pt&src=0 to /home/dmytrodronov/.cache/doctr/models/sar_resnet31-9a1deedf.pt
221871104it [00:04, 46983566.62it/s]
Traceback (most recent call last):
File "/home/dmytrodronov/mlocr/doc_tr.py", line 11, in <module>
result = model(input)
File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/dmytrodronov/doctr/doctr/models/predictor/pytorch.py", line 122, in forward
word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)
File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/dmytrodronov/doctr/doctr/models/recognition/predictor/pytorch.py", line 77, in forward
raw = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches]
File "/home/dmytrodronov/doctr/doctr/models/recognition/predictor/pytorch.py", line 77, in <listcomp>
raw = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches]
File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dmytrodronov/doctr/doctr/models/recognition/sar/pytorch.py", line 254, in forward
decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt))
File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dmytrodronov/doctr/doctr/models/recognition/sar/pytorch.py", line 149, in forward
hidden_state_init, cell_state_init = self.lstm_cell(prev_symbol, (hidden_state_init, cell_state_init))
File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/rnn.py", line 1347, in forward
ret = _VF.lstm_cell(
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
Environment
This is what I actually got for the version built from the main branch:
Collecting environment information...
Traceback (most recent call last):
File "/home/dmytrodronov/collect_env.py", line 355, in <module>
main()
File "/home/dmytrodronov/collect_env.py", line 350, in main
output = get_pretty_env_info()
File "/home/dmytrodronov/collect_env.py", line 345, in get_pretty_env_info
return pretty_str(get_env_info())
File "/home/dmytrodronov/collect_env.py", line 241, in get_env_info
doctr_str = doctr.__version__ if DOCTR_AVAILABLE else "N/A"
AttributeError: module 'doctr' has no attribute '__version__'
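The AttributeError above is raised because this build of doctr does not expose __version__. As a side note, here is a hedged sketch of a more defensive version lookup that an environment script could use. The helper name get_version is hypothetical, and the distribution name python-doctr is assumed to match how the package was installed.

```python
import importlib
from importlib import metadata


def get_version(module_name: str, dist_name: str) -> str:
    """Return a version string for an installed package, or "N/A"."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return "N/A"
    # Prefer the module attribute when the build exposes it.
    version = getattr(module, "__version__", None)
    if version is not None:
        return str(version)
    # Otherwise fall back to the installed distribution metadata.
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return "N/A"


print(get_version("doctr", "python-doctr"))
```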
The presumed environment:
Collecting environment information...
DocTR version: 0.8.0a0
TensorFlow version: N/A
PyTorch version: 2.1.2
OpenCV version: 4.9.0.80
OS: Ubuntu 20.04.4 LTS
Python version: 3.10.13
Is CUDA available (TensorFlow): N/A
Is CUDA available (PyTorch): Yes
CUDA runtime version: 12.0
GPU models and configuration: GPU 0: NVIDIA A30
Deep Learning backend
is_tf_available: False
is_torch_available: True