aihwkit

AnalogConv2d fails when using TT-v2

Open · Zhaoxian-Wu opened this issue 1 year ago · 3 comments

Description

When I tried to use the TT-v2 algorithm to train a convolutional network, I got a CUDA error.

How to reproduce

After running the following main.py file, I got the error RuntimeError: CUDA_CALL Error 'an illegal memory access was encountered' at cuda_util.cu:653.

```python
# main.py
import torch
from aihwkit.nn import AnalogConv2d
from aihwkit.optim import AnalogSGD
from aihwkit.simulator.configs import build_config
from aihwkit.simulator.configs.devices import SoftBoundsReferenceDevice
DEVICE = 'cuda:0'

rpu_config = build_config('ttv2', device=SoftBoundsReferenceDevice())
model = AnalogConv2d(
    in_channels=1, out_channels=3, kernel_size=5, rpu_config=rpu_config
).to(DEVICE)

optimizer = AnalogSGD(model.parameters(), lr=0.1)
optimizer.regroup_param_groups(model)
# if I use images = torch.empty((128, 1, 32, 32)), even the forward fails
images = torch.ones((128, 1, 32, 32))
images = images.to(DEVICE)
output = model(images)
loss = output.norm()**2
loss.backward()
optimizer.step()
```
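CUDA kernels run asynchronously, so an illegal memory access can be reported by a later API call than the one that actually faulted (here, a check at cuda_util.cu:653). A common way to narrow this down is to set the CUDA_LAUNCH_BLOCKING=1 environment variable and synchronize after each step. The sketch below assumes nothing about aihwkit internals; `first_failing_stage` is a hypothetical helper, not part of aihwkit or PyTorch.

```python
import os

# Make CUDA kernel launches synchronous so an error is raised by the call that
# caused it. It must be set before CUDA is initialized; setting it before
# importing torch is the simplest way to guarantee that.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # noqa: E402  (imported after setting the environment variable)


def first_failing_stage(stages):
    """Run (name, callable) pairs in order, synchronizing the GPU after each.

    Returns the name of the first stage that raises RuntimeError (the error in
    this issue surfaces as a RuntimeError), or None if every stage succeeds.
    """
    for name, fn in stages:
        try:
            fn()
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # wait for pending kernels so errors show up here
        except RuntimeError as err:
            print(f"stage {name!r} failed: {err}")
            return name
    return None
```

With the repro above, the stages would be the forward call, loss.backward(), and optimizer.step(), each wrapped as a zero-argument callable (for example, lambdas that store intermediate results in a shared dict). After an illegal memory access the CUDA context is usually unusable, which is why the helper stops at the first failure.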

Besides, if I create the input with torch.empty() instead of torch.ones(), the forward call model(images) never returns. I guess some endless loop is happening.
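torch.empty() only allocates memory: its contents are whatever happened to be there, so the input may hold arbitrary values, including NaN or infinity, and can differ between runs. The thread does not establish why the forward pass hangs on such input, but a deterministic, finite input removes one variable from the reproducer. A minimal sketch, with hypothetical helper names `is_finite` and `check_finite`:

```python
import torch


def is_finite(t: torch.Tensor) -> bool:
    """True if `t` contains no NaN or +/-inf values (hypothetical helper)."""
    return bool(torch.isfinite(t).all())


def check_finite(t: torch.Tensor, name: str = "input") -> torch.Tensor:
    """Raise ValueError if `t` has non-finite values; otherwise return it unchanged."""
    if not is_finite(t):
        bad = int((~torch.isfinite(t)).sum().item())
        raise ValueError(f"{name} has {bad} non-finite element(s)")
    return t


# Deterministic alternative to torch.empty() for a reproducer:
torch.manual_seed(0)
images = check_finite(torch.rand(128, 1, 32, 32))  # uniform values in [0, 1)
# images = torch.empty(128, 1, 32, 32)             # uninitialized: contents are arbitrary
```

A finiteness check does not catch very large finite values, which uninitialized memory can also contain.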

Other information

  • PyTorch version: 2.1.2
  • Package version: aihwkit-gpu 0.9.0
  • OS: Linux
  • Python version: 3.10.13
  • Conda version (or N/A): 23.11.0

Zhaoxian-Wu, Apr 08 '24 21:04

Hi @Zhaoxian-Wu, thanks for reporting this issue. What GPU were you using when you encountered it?

kaoutar55, May 08 '24 14:05
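To answer the GPU question and capture the rest of the environment details in one place, a small script along these lines could be run. It uses only standard-library and PyTorch calls; the distribution names aihwkit-gpu and aihwkit are the ones assumed here. PyTorch also ships python -m torch.utils.collect_env, which prints a fuller report.

```python
import platform
from importlib import metadata

import torch


def package_version(name: str) -> str:
    """Installed version of a distribution, or 'not installed'."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "not installed"


def environment_report() -> dict:
    """Collect the details listed under 'Other information', plus the GPU model."""
    return {
        "Python version": platform.python_version(),
        "PyTorch version": torch.__version__,
        "CUDA version (PyTorch build)": torch.version.cuda,  # None on CPU-only builds
        "aihwkit-gpu version": package_version("aihwkit-gpu"),
        "aihwkit version": package_version("aihwkit"),
        "GPU": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none",
        "OS": platform.platform(),
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```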

Hi @Zhaoxian-Wu, the CUDA memory problem seems to come from how DEVICE is set. When I set it with DEVICE = torch.cuda.set_device(0) instead of DEVICE = torch.device('cuda:0'), I no longer saw the problem. I found the solution in this issue.

kkvtran, May 08 '24 21:05

I tried the same technique with torch.empty and did not see the hanging (looping) issue either. So this is a PyTorch problem. Let us know if you have any questions.

kkvtran, May 08 '24 21:05
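One detail worth checking before attributing the crash to PyTorch: torch.cuda.set_device() changes the current device as a side effect and returns None, so DEVICE = torch.cuda.set_device(0) binds DEVICE to None. Calling .to(None) on a module or tensor moves nothing, so in the repro the model and images would likely stay on the CPU, which may be why neither the CUDA error nor the hang reappears. The sketch below shows this with a plain torch.nn.Linear standing in for AnalogConv2d (an assumption: aihwkit's analog layers are expected, but not verified here, to behave the same way under .to(None)), plus an explicit alternative via a hypothetical pick_device helper.

```python
import torch

# torch.cuda.set_device() returns None, so this pattern leaves DEVICE as None:
#     DEVICE = torch.cuda.set_device(0)
# Assigned directly here so the sketch also runs on CPU-only machines.
DEVICE = None

layer = torch.nn.Linear(4, 2).to(DEVICE)  # stand-in for AnalogConv2d (assumption)
x = torch.ones(8, 4).to(DEVICE)           # .to(None) moves nothing

print(next(layer.parameters()).device, x.device)  # cpu cpu


def pick_device(index: int = 0) -> torch.device:
    """Hypothetical helper: return an explicit device instead of set_device()'s None."""
    if torch.cuda.is_available():
        torch.cuda.set_device(index)  # optional: also make it the current device
        return torch.device("cuda", index)
    return torch.device("cpu")


device = pick_device()
layer = layer.to(device)
# Fail loudly if the model is not where we expect it to be.
assert next(layer.parameters()).device.type == device.type
```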

@Zhaoxian-Wu, do you still have this issue? If not, we can close it. Please let us know.

kaoutar55, Jun 07 '24 09:06

I tried the DEVICE = torch.cuda.set_device(0) solution and it worked! It seems to be a PyTorch issue. Thank you for your help, @kkvtran! It is strange that the PyTorch community has left this issue open for so long.

Zhaoxian-Wu, Jun 07 '24 21:06