DINOv2 encoder returns tensors filled with NaN values
I tried to run demo/demo_match.py, and I got an error.
(roma) vgpu@nvlink:~/RoMa$ python demo/demo_match.py
2025-10-27 13:53:02.924 | INFO | romatch.models.model_zoo.roma_models:roma_model:61 - Using coarse resolution (560, 560), and upsample res (864, 1152)
Traceback (most recent call last):
File "/home/vgpu/RoMa/demo/demo_match.py", line 41, in <module>
im2_transfer_rgb = F.grid_sample(
^^^^^^^^^^^^^^
File "/home/vgpu/miniconda3/envs/roma/lib/python3.12/site-packages/torch/nn/functional.py", line 5108, in grid_sample
return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: grid_sampler(): expected grid to have size 2 in last dimension, but got grid with sizes [1, 1, 864, 2302, 4]
After some debugging, I found that the DINOv2 encoder returned tensors filled with NaN values: https://github.com/Parskatt/RoMa/blob/a1494f87ab85485ad8a92a981876ca837e2334e4/romatch/models/encoders.py#L64
(Pdb) dinov2_features_16
{'x_norm_clstoken': tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], device='cuda:0',
dtype=torch.float16), 'x_norm_patchtokens': tensor([[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]]], device='cuda:0',
dtype=torch.float16), 'x_prenorm': tensor([[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]]], device='cuda:0',
dtype=torch.float16), 'masks': None}
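For reference, this is a quick way to check which entries of such a feature dict contain NaNs (a minimal sketch, with numpy arrays standing in for the torch tensors; the helper name and shapes are made up for illustration, the key names are copied from the dump above):

```python
import numpy as np

# Hypothetical helper mirroring the pdb check above: scan a dict of
# feature arrays and report which entries contain any NaNs.
def find_nan_features(features):
    bad = []
    for name, arr in features.items():
        if arr is None:          # e.g. the 'masks' entry is None
            continue
        if np.isnan(arr).any():
            bad.append(name)
    return bad

features = {
    "x_norm_clstoken": np.full((2, 1024), np.nan, dtype=np.float16),
    "x_norm_patchtokens": np.zeros((2, 1225, 1024), dtype=np.float16),
    "masks": None,
}
print(find_nan_features(features))  # ['x_norm_clstoken']
```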
Here are details of my environment.
(roma) vgpu@nvlink:~/RoMa$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Fri_Feb_21_20:23:50_PST_2025
Cuda compilation tools, release 12.8, V12.8.93
Build cuda_12.8.r12.8/compiler.35583870_0
(roma) vgpu@nvlink:~/RoMa$ pip show torch
Name: torch
Version: 2.9.0
(roma) vgpu@nvlink:~/RoMa$ python -c "import torch; print(torch.version.cuda)"
12.8
(roma) vgpu@nvlink:~/RoMa$ python -V
Python 3.12.12
(roma) vgpu@nvlink:~/RoMa$ pip list
Package Version Editable project location
------------------------ ----------- -------------------------
albucore 0.0.24
albumentations 2.0.8
annotated-types 0.7.0
certifi 2025.10.5
charset-normalizer 3.4.4
click 8.3.0
contourpy 1.3.3
cycler 0.12.1
einops 0.8.1
filelock 3.20.0
fonttools 4.60.1
fsspec 2025.9.0
fused-local-corr 0.2.3
gitdb 4.0.12
GitPython 3.1.45
h5py 3.15.1
hf-xet 1.1.10
huggingface-hub 0.36.0
idna 3.11
Jinja2 3.1.6
kiwisolver 1.4.9
kornia 0.8.1
kornia_rs 0.1.9
loguru 0.7.3
MarkupSafe 3.0.3
matplotlib 3.10.7
mpmath 1.3.0
networkx 3.5
numpy 2.2.6
nvidia-cublas-cu12 12.8.4.1
nvidia-cuda-cupti-cu12 12.8.90
nvidia-cuda-nvrtc-cu12 12.8.93
nvidia-cuda-runtime-cu12 12.8.90
nvidia-cudnn-cu12 9.10.2.21
nvidia-cufft-cu12 11.3.3.83
nvidia-cufile-cu12 1.13.1.3
nvidia-curand-cu12 10.3.9.90
nvidia-cusolver-cu12 11.7.3.90
nvidia-cusparse-cu12 12.5.8.93
nvidia-cusparselt-cu12 0.7.1
nvidia-nccl-cu12 2.27.5
nvidia-nvjitlink-cu12 12.8.93
nvidia-nvshmem-cu12 3.3.20
nvidia-nvtx-cu12 12.8.90
opencv-python 4.12.0.88
opencv-python-headless 4.12.0.88
packaging 25.0
pillow 12.0.0
pip 25.2
platformdirs 4.5.0
poselib 2.0.5
protobuf 6.33.0
pydantic 2.12.3
pydantic_core 2.41.4
pyparsing 3.2.5
python-dateutil 2.9.0.post0
PyYAML 6.0.3
requests 2.32.5
romatch 0.1.1 /home/vgpu/RoMa
safetensors 0.6.2
scipy 1.16.2
sentry-sdk 2.42.1
setuptools 80.9.0
simsimd 6.5.3
six 1.17.0
smmap 5.0.2
stringzilla 4.2.1
sympy 1.14.0
timm 1.0.20
torch 2.9.0
torchvision 0.24.0
tqdm 4.67.1
triton 3.5.0
typing_extensions 4.15.0
typing-inspection 0.4.2
urllib3 2.5.0
wandb 0.22.2
wheel 0.45.1
How can I solve this problem? Thank you.
Haven't run into this myself; it's possibly due to torch.float16 overflows or underflows. It should (I think) be fine to run DINOv2 in torch.bfloat16. Could you try that?
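The overflow hypothesis can be illustrated in a few lines (numpy used here only as a stand-in for torch float16 arithmetic):

```python
import numpy as np

# float16 has a max value of ~65504, so intermediate activations that
# exceed it overflow to inf, and subsequent ops (e.g. inf - inf inside
# a normalization) then produce NaN.
x = np.float16(60000.0)
y = x * np.float16(2.0)          # exceeds float16 max -> inf
print(np.isinf(y))               # True
z = y - y                        # inf - inf -> nan
print(np.isnan(z))               # True

# bfloat16 (not available in numpy) keeps float32's exponent range,
# which is why switching the encoder to bfloat16 is a plausible fix.
print(np.finfo(np.float16).max)  # 65504.0
```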
I have changed this line https://github.com/Parskatt/RoMa/blob/a1494f87ab85485ad8a92a981876ca837e2334e4/demo/demo_match.py#L28 to
roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152), amp_dtype=torch.bfloat16)
However, the problem is not resolved.
I suspect the DINOv2 dtype is hardcoded.
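If the dtype really is hardcoded somewhere inside the encoder, one quick way to test the bfloat16 hypothesis is to wrap the forward pass in an autocast context (a sketch only; the `encoder` here is a hypothetical stand-in, CPU is used purely for illustration, and on a GPU one would use `torch.autocast("cuda", dtype=torch.bfloat16)` around the encoder call; this won't help if the model casts explicitly via `.half()`):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the DINOv2 encoder forward pass.
encoder = nn.Linear(16, 8)
x = torch.randn(2, 16)

# torch.autocast runs eligible ops (matmul, linear, conv, ...) in the
# requested lower-precision dtype for the duration of the context.
with torch.autocast("cpu", dtype=torch.bfloat16):
    out = encoder(x)

print(out.dtype)  # torch.bfloat16
```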
Also, what images are you sending into the model? What are the pixel values?
I'm just sending in the images in the assets directory (toronto_A.jpg and toronto_B.jpg).
# toronto_A.jpg
np.array(im1)
array([[[130, 165, 194],
[130, 165, 194],
[129, 164, 195],
...,
[ 29, 51, 21],
[103, 128, 89],
[135, 160, 118]],
[[130, 166, 197],
[129, 165, 196],
[127, 163, 196],
...,
[ 91, 116, 74],
[142, 168, 123],
[144, 170, 125]],
[[129, 167, 200],
[126, 165, 199],
[124, 161, 197],
...,
[112, 139, 91],
[143, 170, 120],
[131, 158, 107]],
...,
[[156, 124, 85],
[154, 124, 86],
[169, 142, 104],
...,
[ 80, 73, 38],
[ 58, 52, 15],
[ 56, 50, 9]],
[[153, 121, 82],
[157, 127, 88],
[176, 149, 111],
...,
[101, 91, 51],
[103, 91, 44],
[102, 91, 37]],
[[147, 115, 76],
[156, 126, 88],
[180, 153, 115],
...,
[ 60, 47, 7],
[100, 88, 35],
[131, 118, 57]]], shape=(864, 1152, 3), dtype=uint8)
# toronto_B.jpg
np.array(im2)
array([[[ 14, 79, 147],
[ 14, 79, 147],
[ 14, 79, 147],
...,
[111, 126, 147],
[118, 133, 154],
[118, 133, 154]],
[[ 14, 79, 147],
[ 14, 79, 147],
[ 14, 79, 147],
...,
[120, 135, 154],
[135, 150, 169],
[145, 160, 180]],
[[ 14, 79, 147],
[ 14, 79, 147],
[ 14, 79, 147],
...,
[121, 137, 153],
[147, 163, 179],
[161, 176, 193]],
...,
[[ 65, 85, 36],
[ 85, 106, 57],
[ 80, 100, 51],
...,
[ 72, 96, 36],
[ 65, 89, 29],
[ 44, 68, 8]],
[[ 95, 113, 65],
[ 80, 98, 50],
[ 79, 97, 49],
...,
[ 92, 116, 56],
[ 87, 111, 51],
[ 70, 94, 34]],
[[ 91, 107, 60],
[ 81, 97, 50],
[ 78, 94, 47],
...,
[ 78, 102, 42],
[ 95, 119, 59],
[ 84, 108, 48]]], shape=(864, 1152, 3), dtype=uint8)