trak icon indicating copy to clipboard operation
trak copied to clipboard

TRAK on CUDA version 11.8 has numerical issues

Open enkeejunior1 opened this issue 11 months ago • 2 comments

As the title says, TRAK has numerical issues on CUDA 11.8: in the example below, `CudaProjector.project` returns a tensor of all zeros.

  • Minimal reproducible example
import torch
from trak.projectors import ProjectionType, AbstractProjector, CudaProjector

grad_dim = int(1e6)
projector = CudaProjector(
    grad_dim=grad_dim,
    proj_dim=32768,
    seed=42, 
    proj_type=ProjectionType.normal,
    device='cuda:0',
    max_batch_size=8,
)
grad = torch.randn(8, grad_dim, device='cuda:0')
proj = projector.project(grad, model_id=0)
print(proj)

>>> tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
  • Q: How did I install TRAK?
conda create -n trak python=3.10.16
conda activate trak
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
conda install cuda-nvcc=11.8 -c nvidia -y
conda install -c nvidia cuda-toolkit=11.8 -y
export CUDA_HOME=$CONDA_PREFIX
export PYTHONPATH=$CONDA_PREFIX/lib/python3.x/site-packages:$PYTHONPATH
pip install traker[fast]==0.3.2
Environment settings
### Environment Info
Python 3.8.20
Package                  Version
------------------------ ------------
fast_jl                  0.1.3
filelock                 3.13.1
fsspec                   2024.6.1
Jinja2                   3.1.4
MarkupSafe               2.1.5
mpmath                   1.3.0
networkx                 3.0
numpy                    1.24.1
nvidia-cublas-cu11       11.11.3.6
nvidia-cuda-cupti-cu11   11.8.87
nvidia-cuda-nvrtc-cu11   11.8.89
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cudnn-cu11        9.1.0.70
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.3.0.86
nvidia-cusolver-cu11     11.4.1.48
nvidia-cusparse-cu11     11.7.5.86
nvidia-nccl-cu11         2.20.5
nvidia-nvtx-cu11         11.8.86
pillow                   10.2.0
pip                      24.2
setuptools               75.1.0
sympy                    1.13.1
torch                    2.4.1+cu118
torchvision              0.19.1+cu118
tqdm                     4.67.1
traker                   0.3.2
triton                   3.0.0
typing_extensions        4.12.2
wheel                    0.44.0
# packages in environment at /data/yonghyun/anaconda3/envs/trak_118:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
bzip2                     1.0.8                h5eee18b_6  
ca-certificates           2024.12.31           h06a4308_0  
cuda-cccl_linux-64        12.8.55                       0    nvidia
cuda-command-line-tools   12.8.0                        0    nvidia
cuda-compiler             12.6.2                        0    nvidia
cuda-cudart               12.8.57                       0    nvidia
cuda-cudart-dev           12.8.57                       0    nvidia
cuda-cudart-dev_linux-64  12.8.57                       0    nvidia
cuda-cudart-static        12.8.57                       0    nvidia
cuda-cudart-static_linux-64 12.8.57                       0    nvidia
cuda-cudart_linux-64      12.8.57                       0    nvidia
cuda-cuobjdump            12.8.55                       0    nvidia
cuda-cupti                12.8.57                       0    nvidia
cuda-cupti-dev            12.8.57                       0    nvidia
cuda-cuxxfilt             12.8.55                       0    nvidia
cuda-documentation        12.4.127                      0    nvidia
cuda-driver-dev           12.8.57                       0    nvidia
cuda-driver-dev_linux-64  12.8.57                       0    nvidia
cuda-gdb                  12.8.55                       0    nvidia
cuda-libraries            12.8.0                        0    nvidia
cuda-libraries-dev        12.8.0                        0    nvidia
cuda-nsight               12.8.55                       0    nvidia
cuda-nvcc                 11.8.89                       0    nvidia
cuda-nvdisasm             12.8.55                       0    nvidia
cuda-nvml-dev             12.8.55                       0    nvidia
cuda-nvprof               12.8.57                       0    nvidia
cuda-nvprune              12.8.55                       0    nvidia
cuda-nvrtc                12.8.61                       0    nvidia
cuda-nvrtc-dev            12.8.61                       0    nvidia
cuda-nvtx                 12.8.55                       0    nvidia
cuda-nvvp                 12.8.57                       0    nvidia
cuda-opencl               12.8.55                       0    nvidia
cuda-opencl-dev           12.8.55                       0    nvidia
cuda-profiler-api         12.8.55                       0    nvidia
cuda-sanitizer-api        12.8.55                       0    nvidia
cuda-toolkit              11.8.0                        0    nvidia
cuda-tools                12.8.0                        0    nvidia
cuda-version              12.8                          3    nvidia
cuda-visual-tools         12.8.0                        0    nvidia
dbus                      1.13.18              hb2f20db_0  
expat                     2.6.4                h6a678d5_0  
fast-jl                   0.1.3                    pypi_0    pypi
filelock                  3.13.1                   pypi_0    pypi
fontconfig                2.14.1               h55d465d_3  
freetype                  2.12.1               h4a9f257_0  
fsspec                    2024.6.1                 pypi_0    pypi
gds-tools                 1.13.0.11                     0    nvidia
glib                      2.78.4               h6a678d5_0  
glib-tools                2.78.4               h6a678d5_0  
gmp                       6.3.0                h6a678d5_0  
icu                       73.1                 h6a678d5_0  
jinja2                    3.1.4                    pypi_0    pypi
ld_impl_linux-64          2.40                 h12ee557_0  
libcublas                 12.8.3.14                     0    nvidia
libcublas-dev             12.8.3.14                     0    nvidia
libcufft                  11.3.3.41                     0    nvidia
libcufft-dev              11.3.3.41                     0    nvidia
libcufile                 1.13.0.11                     0    nvidia
libcufile-dev             1.13.0.11                     0    nvidia
libcurand                 10.3.9.55                     0    nvidia
libcurand-dev             10.3.9.55                     0    nvidia
libcusolver               11.7.2.55                     0    nvidia
libcusolver-dev           11.7.2.55                     0    nvidia
libcusparse               12.5.7.53                     0    nvidia
libcusparse-dev           12.5.7.53                     0    nvidia
libffi                    3.4.4                h6a678d5_1  
libgcc-ng                 11.2.0               h1234567_1  
libglib                   2.78.4               hdc74915_0  
libgomp                   11.2.0               h1234567_1  
libiconv                  1.16                 h5eee18b_3  
libnpp                    12.3.3.65                     0    nvidia
libnpp-dev                12.3.3.65                     0    nvidia
libnvfatbin               12.8.55                       0    nvidia
libnvfatbin-dev           12.8.55                       0    nvidia
libnvjitlink              12.8.61                       1    nvidia
libnvjitlink-dev          12.8.61                       1    nvidia
libnvjpeg                 12.3.5.57                     0    nvidia
libnvjpeg-dev             12.3.5.57                     0    nvidia
libpng                    1.6.39               h5eee18b_0  
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
libxcb                    1.15                 h7f8727e_0  
libxkbcommon              1.0.1                h097e994_2  
libxml2                   2.13.5               hfdd30dd_0  
markupsafe                2.1.5                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
networkx                  3.0                      pypi_0    pypi
nsight-compute            2025.1.0.14                   0    nvidia
nspr                      4.35                 h6a678d5_0  
nss                       3.89.1               h6a678d5_0  
numpy                     1.24.1                   pypi_0    pypi
nvidia-cublas-cu11        11.11.3.6                pypi_0    pypi
nvidia-cuda-cupti-cu11    11.8.87                  pypi_0    pypi
nvidia-cuda-nvrtc-cu11    11.8.89                  pypi_0    pypi
nvidia-cuda-runtime-cu11  11.8.89                  pypi_0    pypi
nvidia-cudnn-cu11         9.1.0.70                 pypi_0    pypi
nvidia-cufft-cu11         10.9.0.58                pypi_0    pypi
nvidia-curand-cu11        10.3.0.86                pypi_0    pypi
nvidia-cusolver-cu11      11.4.1.48                pypi_0    pypi
nvidia-cusparse-cu11      11.7.5.86                pypi_0    pypi
nvidia-nccl-cu11          2.20.5                   pypi_0    pypi
nvidia-nvtx-cu11          11.8.86                  pypi_0    pypi
ocl-icd                   2.3.2                h5eee18b_1  
openssl                   3.0.15               h5eee18b_0  
pcre2                     10.42                hebb0a14_1  
pillow                    10.2.0                   pypi_0    pypi
pip                       24.2             py38h06a4308_0  
python                    3.8.20               he870216_0  
readline                  8.2                  h5eee18b_0  
setuptools                75.1.0           py38h06a4308_0  
sqlite                    3.45.3               h5eee18b_0  
sympy                     1.13.1                   pypi_0    pypi
tk                        8.6.14               h39e8969_0  
torch                     2.4.1+cu118              pypi_0    pypi
torchvision               0.19.1+cu118             pypi_0    pypi
tqdm                      4.67.1                   pypi_0    pypi
traker                    0.3.2                    pypi_0    pypi
triton                    3.0.0                    pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
wheel                     0.44.0           py38h06a4308_0  
xz                        5.4.6                h5eee18b_1  
zlib                      1.2.13               h5eee18b_1  

### NVIDIA Info
Wed Feb  5 15:12:37 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.216.01             Driver Version: 535.216.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 PCIe               Off | 00000000:2D:00.0 Off |                    0 |
| N/A   40C    P0              52W / 350W |     17MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

enkeejunior1 avatar Feb 04 '25 06:02 enkeejunior1

I encounter the same error

Haruka1307 avatar Nov 07 '25 03:11 Haruka1307

I encounter the same error

In the end I used BasicProjector instead; it has almost the same speed.

Haruka1307 avatar Nov 08 '25 11:11 Haruka1307