[Bug] SARNet half-precision error.

Open decadance-dance opened this issue 2 years ago • 0 comments

Bug description

When I use sar_resnet31, I get this error: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half. Other models I tried (master and crnn_mobilenet_v3_large) work fine. Maybe a tensor somewhere in the sar_resnet31 architecture is not converted to half precision (imho).

Code snippet to reproduce the bug

from doctr.models import ocr_predictor
from doctr.io import DocumentFile


input = DocumentFile.from_images("./four.jpeg")
model = ocr_predictor('db_resnet50', 'sar_resnet31', pretrained=True).cuda().half()

result = model(input)
print(result)

The input image: four

Error traceback

Downloading https://doctr-static.mindee.com/models?id=v0.7.0/sar_resnet31-9a1deedf.pt&src=0 to /home/dmytrodronov/.cache/doctr/models/sar_resnet31-9a1deedf.pt
221871104it [00:04, 46983566.62it/s]                                                                               
Traceback (most recent call last):
  File "/home/dmytrodronov/mlocr/doc_tr.py", line 11, in <module>
    result = model(input)
  File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/dmytrodronov/doctr/doctr/models/predictor/pytorch.py", line 122, in forward
    word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)
  File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/dmytrodronov/doctr/doctr/models/recognition/predictor/pytorch.py", line 77, in forward
    raw = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches]
  File "/home/dmytrodronov/doctr/doctr/models/recognition/predictor/pytorch.py", line 77, in <listcomp>
    raw = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches]
  File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dmytrodronov/doctr/doctr/models/recognition/sar/pytorch.py", line 254, in forward
    decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt))
  File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dmytrodronov/doctr/doctr/models/recognition/sar/pytorch.py", line 149, in forward
    hidden_state_init, cell_state_init = self.lstm_cell(prev_symbol, (hidden_state_init, cell_state_init))
  File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dmytrodronov/miniconda3/envs/ocr2/lib/python3.10/site-packages/torch/nn/modules/rnn.py", line 1347, in forward
    ret = _VF.lstm_cell(
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
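
The traceback ends in the LSTM cell call, so here is a minimal sketch of how this kind of mismatch can arise. It is not the doctr code: it assumes (unconfirmed) that the decoder creates its initial symbol or state tensors without passing a dtype, so they default to float32 while the module's weights are in a different precision. The names `step`, `HIDDEN` and `match_dtype` are hypothetical, and float64 stands in for half precision so the sketch also runs on CPU.

```python
import torch
from torch import nn

HIDDEN = 16  # arbitrary size, for illustration only

# Stand-in for the decoder's LSTM cell (not the doctr module).
# float64 plays the role of half precision so this runs on CPU.
cell = nn.LSTMCell(input_size=8, hidden_size=HIDDEN).double()
features = torch.rand(2, 8, dtype=torch.float64)


def step(features: torch.Tensor, match_dtype: bool) -> torch.Tensor:
    """Run one LSTM step with freshly created input and state tensors."""
    # Without an explicit dtype, torch.zeros uses the default dtype (float32).
    kwargs = {"dtype": features.dtype, "device": features.device} if match_dtype else {}
    batch = features.size(0)
    prev_symbol = torch.zeros(batch, 8, **kwargs)
    hidden = torch.zeros(batch, HIDDEN, **kwargs)
    cell_state = torch.zeros(batch, HIDDEN, **kwargs)
    hidden, cell_state = cell(prev_symbol, (hidden, cell_state))
    return hidden


# float32 tensors against float64 weights: the matmul inside the cell fails.
try:
    step(features, match_dtype=False)
    mismatch_error = None
except RuntimeError as exc:
    mismatch_error = exc

# Tensors created with the features' dtype and device: the step succeeds.
fixed = step(features, match_dtype=True)
print(type(mismatch_error).__name__, fixed.dtype)
```

If the cause is of this kind, creating those tensors with `dtype=features.dtype` and `device=features.device` would be one way to keep the decoder consistent with `.half()`.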

Environment

Actually, this is what I got for the version built from the main branch:

Collecting environment information...

Traceback (most recent call last):
  File "/home/dmytrodronov/collect_env.py", line 355, in <module>
    main()
  File "/home/dmytrodronov/collect_env.py", line 350, in main
    output = get_pretty_env_info()
  File "/home/dmytrodronov/collect_env.py", line 345, in get_pretty_env_info
    return pretty_str(get_env_info())
  File "/home/dmytrodronov/collect_env.py", line 241, in get_env_info
    doctr_str = doctr.__version__ if DOCTR_AVAILABLE else "N/A"
AttributeError: module 'doctr' has no attribute '__version__'

The environment it should have reported:

Collecting environment information...

DocTR version: 0.8.0a0
TensorFlow version: N/A
PyTorch version: 2.1.2
OpenCV version: 4.9.0.80
OS: Ubuntu 20.04.4 LTS
Python version: 3.10.13
Is CUDA available (TensorFlow): N/A
Is CUDA available (PyTorch): Yes
CUDA runtime version: 12.0
GPU models and configuration: GPU 0: NVIDIA A30

Deep Learning backend

is_tf_available: False
is_torch_available: True

decadance-dance, Jan 26 '24 09:01