
[QST] CUDA free failed when executing example 59_ampere_gather_scatter_conv

Open · IzanCatalan opened this issue 1 year ago · 8 comments

What is your question?

Hi, I am modifying the predefined variables of example 59_ampere_gather_scatter_conv. These are my new values:

  using D = _1;
  using H = _4;
  using W = _4;

  using T = _1;
  using R = _1;
  using S = _1;

  using Z = _1;
  using P = _2;
  using Q = _2;

  using C = _32;
  using K = _32;

And here is the output of a run with the default arguments:

izcagal@cmts10:~/cutlass/build$ ./examples/59_ampere_gather_scatter_conv/59_ampere_gather_scatter_conv --no-check
Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.

Filter layout     ( K,        (C,T,R,S)) = (_32,(_32,_1,_1,_1)):(_32,(_1,_0,_0,_0))
Activation layout ((N,D,H,W), (C,1,1,1)) = ((4320,_1,_4,_4),(_32,_1,_1,_1)):((_512,_512,_128,_32),(_1,_0,_0,_0))
Output layout     ( K,        (N,Z,P,Q)) = (_32,(4320,_4,_2,_2)):(_1,(_512,_128,_64,_32))
Allocating tensors ... done.
Initializing data ... done.

Running dense fprop kernel
xformed act layout ((N,Z,P,Q), (C,T,R,S)) = ((4320,_4,_2,_2),(_32,_1,_1,_1)):((_512,_512,_128,_32),(_1,_512,_128,_32))
CUDA error at  (/mnt/beegfs/gap/izcagal/cutlass/examples/59_ampere_gather_scatter_conv/ampere_gather_scatter_conv.cu,155)
        700 -- an illegal memory access was encountered
Conv TFLOP count = 0.000142
Conv dense perf: 0.000000ms | TFLOP/s = inf
terminate called after throwing an instance of 'thrust::system::system_error'
  what():  CUDA free failed: cudaErrorIllegalAddress: an illegal memory access was encountered
Aborted (core dumped)

However, when I run it with --n=128, it does not fail. I would like to know why this happens. I assume there are some restrictions on these variables, but the only one I found in the Convolution Implementation documentation was that C or K must be a multiple of 32.

Any help would be appreciated.

Thanks.

IzanCatalan, Nov 27 '24

Your config does not make sense to me.

T, R, and S are all set to 1, so how could Z, P, Q be different from D, H, W?

"...but the only one I found in the Convolution Implementation was that C or K must be a multiple of 32."

That is the 2.x implementation of conv and has nothing to do with example 59. This example has many more preconditions on its inputs, as documented in the example README.

thakkarV, Nov 27 '24

@thakkarV Where can I find the preconditions for conv3d? Also, you said that this is the 2.x implementation of conv. Is there any other conv2d example for CUTLASS 3.x? Maybe example 16?

IzanCatalan, Nov 28 '24

Examples of CUTLASS 3.x-based conv can be found here: https://github.com/NVIDIA/cutlass/blob/main/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu

The preconditions of the main conv API are much more relaxed than those of example 59.

thakkarV, Nov 30 '24

@thakkarV Those examples are for the NVIDIA H100 Tensor Core GPU (SM90). My GPU is an NVIDIA A100 Tensor Core GPU (SM80), and they do not work on it; that is why I use example 16. Any help?

IzanCatalan, Dec 09 '24

We don't have first-class API support for Ampere-class GPUs via the 3.x API. The most we have is the custom one-off kernel written in example 59 for Ampere-class chips.

thakkarV, Dec 09 '24

@thakkarV Maybe not a first-class API, but is there some source code, as explained in the Convolution README and more specifically under device convolution?

IzanCatalan, Dec 09 '24

This example is the source code for Ampere convolutions via the 3.x API.

thakkarV, Dec 17 '24

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot], Jan 16 '25

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.

github-actions[bot], Apr 16 '25