
TRAK on CUDA 12.1 raises a CUDA error

Open • enkeejunior1 opened this issue 1 year ago • 0 comments

  • Minimal code to reproduce the error:
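```python
import torch
import trak
from trak.projectors import CudaProjector, ProjectionType

# Run TRAK's installation self-test with the fast_jl CUDA projector enabled.
print("trak.test_install:", trak.test_install(use_fast_jl=True))

grad_dim = int(1e6)  # length of each flattened gradient
projector = CudaProjector(
    grad_dim=grad_dim,
    proj_dim=32768,  # target dimension of the random projection
    seed=42,
    proj_type=ProjectionType.normal,  # Gaussian random projection
    device='cuda:0',
    max_batch_size=8,
)

# Project a batch of 8 random "gradients" on the GPU.
grad = torch.randn(8, grad_dim, device='cuda:0')
proj = projector.project(grad, model_id=0)
print(proj)
```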
  • Environment installation commands:
```shell
pip install scikit-learn matplotlib einops ipykernel
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121

conda install cuda=12.1 -c nvidia -y
conda install cuda-nvcc=12.1 -c nvidia -y
conda install cuda-toolkit=12.1 -c nvidia -y
export CUDA_HOME=$CONDA_PREFIX
# replace 3.x with the environment's Python minor version
export PYTHONPATH=$CONDA_PREFIX/lib/python3.x/site-packages:$PYTHONPATH
# quoted so shells such as zsh don't treat [fast] as a glob pattern
pip install "traker[fast]"
```
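The issue does not include the actual error message, so this is only one thing worth ruling out, not a diagnosis. As far as we know, the fast_jl extension pulled in by `traker[fast]` is compiled at install time with whatever `nvcc` is found (here via `CUDA_HOME=$CONDA_PREFIX`), while the torch wheel above is built for CUDA 12.1. The stdlib-only sketch below reports `CUDA_HOME`, the `nvcc` release, and `torch.version.cuda`, and compares their major.minor versions. The helper names (`parse_nvcc_release`, `versions_match`, etc.) are hypothetical, not part of TRAK.

```python
"""Hypothetical sanity check: does the nvcc used to build fast_jl match torch's CUDA?

Stdlib only; torch and nvcc are probed optionally in the __main__ block.
"""
import os
import re
import shutil
import subprocess
from typing import Optional, Tuple


def parse_nvcc_release(nvcc_output: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from `nvcc --version` text, e.g. '... release 12.1, V12.1.105'."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def parse_cuda_version(version: Optional[str]) -> Optional[Tuple[int, int]]:
    """Turn a string such as torch.version.cuda ('12.1') into (12, 1); None if unavailable."""
    if not version:
        return None
    parts = version.split(".")
    if len(parts) < 2:
        return None
    return int(parts[0]), int(parts[1])


def versions_match(nvcc_output: str, torch_cuda: Optional[str]) -> bool:
    """True only when nvcc and torch report the same CUDA major.minor."""
    nvcc_version = parse_nvcc_release(nvcc_output)
    return nvcc_version is not None and nvcc_version == parse_cuda_version(torch_cuda)


if __name__ == "__main__":
    cuda_home = os.environ.get("CUDA_HOME")
    print("CUDA_HOME:", cuda_home)

    # Prefer the nvcc under CUDA_HOME, otherwise whatever is on PATH.
    nvcc_path = shutil.which("nvcc")
    if cuda_home and os.path.exists(os.path.join(cuda_home, "bin", "nvcc")):
        nvcc_path = os.path.join(cuda_home, "bin", "nvcc")
    print("nvcc:", nvcc_path)

    try:
        import torch
        torch_cuda = torch.version.cuda  # None for CPU-only builds
    except (ImportError, OSError):
        torch_cuda = None
    print("torch.version.cuda:", torch_cuda)

    if nvcc_path:
        out = subprocess.run([nvcc_path, "--version"], capture_output=True, text=True).stdout
        print("nvcc release:", parse_nvcc_release(out))
        print("match:", versions_match(out, torch_cuda))
```

A mismatch here would not prove it is the cause, but it is cheap to check before digging into the kernel itself.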
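For scale, the arithmetic below (plain Python, no GPU needed) estimates the memory involved in the repro's `CudaProjector` settings, assuming float32 storage. A dense 1e6 × 32768 projection matrix would need about 122 GiB, which is why, as we understand it, TRAK's fast_jl kernel generates the random matrix on the fly rather than storing it; the input batch itself is only about 30.5 MiB and the output about 1 MiB. The function names are illustrative only.

```python
# Back-of-the-envelope sizes for the repro's CudaProjector settings (float32 assumed).
GRAD_DIM = int(1e6)  # grad_dim in the repro
PROJ_DIM = 32768     # proj_dim in the repro
BATCH = 8            # rows of `grad` (equal to max_batch_size)
BYTES_FP32 = 4


def dense_matrix_bytes(grad_dim: int, proj_dim: int, itemsize: int = BYTES_FP32) -> int:
    """Bytes needed to materialize a full grad_dim x proj_dim projection matrix."""
    return grad_dim * proj_dim * itemsize


def tensor_bytes(rows: int, cols: int, itemsize: int = BYTES_FP32) -> int:
    """Bytes for a dense rows x cols tensor."""
    return rows * cols * itemsize


if __name__ == "__main__":
    mib, gib = 2 ** 20, 2 ** 30
    print(f"dense projection matrix: {dense_matrix_bytes(GRAD_DIM, PROJ_DIM) / gib:.1f} GiB")
    print(f"input grads:             {tensor_bytes(BATCH, GRAD_DIM) / mib:.1f} MiB")
    print(f"output projections:      {tensor_bytes(BATCH, PROJ_DIM) / mib:.1f} MiB")
```

Running it prints roughly 122.1 GiB, 30.5 MiB, and 1.0 MiB, so the repro's tensors themselves are small for a GPU.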

enkeejunior1 • Feb 04 '25 08:02