WEIGHT_FORMAT_WARN in RNN.cpp does not get set on rocm
🚀 The feature, motivation and pitch
I am working on enabling test_nn.py's test_cudnn_weight_format on ROCm, and observed that the test only passes if I apply the following change:
diff --git a/test/test_nn.py b/test/test_nn.py
index c8311c91d7..85b391e880 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8460,8 +8460,9 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
             with warnings.catch_warnings(record=True) as w:
                 output_noncontig = rnn(input, hx)
             if first_warn:
-                self.assertEqual(len(w), 1)
-                self.assertIn('weights are not part of single contiguous chunk of memory', w[0].message.args[0])
+                if (torch.version.hip is None):
+                    self.assertEqual(len(w), 1)
+                    self.assertIn('weights are not part of single contiguous chunk of memory', w[0].message.args[0])
                 first_warn = False
                 warnings.resetwarnings()
             output_noncontig[0].sum().backward()
This warning is generated from aten/src/ATen/native/cudnn/RNN.cpp, which is only reached on the cuDNN code path; on ROCm the MIOpen RNN implementation is used instead and never emits it.
How can this test pass on ROCm without bypassing the above checks?
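For context, here is a minimal standalone sketch (not the test itself) of how the warning is expected to surface on a cuDNN build; the LSTM shapes and the parameter reassignment are illustrative choices, not taken from test_nn.py:

import warnings
import torch
import torch.nn as nn

# Illustrative repro sketch: an LSTM whose weights have been pulled out of
# the single flat buffer that the cuDNN path expects.
rnn = nn.LSTM(10, 20, num_layers=2).cuda()
inp = torch.randn(5, 3, 10, device='cuda')
hx = (torch.randn(2, 3, 20, device='cuda'),
      torch.randn(2, 3, 20, device='cuda'))

# Reassigning a parameter gives it its own allocation, so the module's
# weights are no longer one contiguous chunk of memory.
rnn.weight_hh_l0 = nn.Parameter(rnn.weight_hh_l0.detach().clone())

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    rnn(inp, hx)

# On CUDA/cuDNN this list contains the "weights are not part of single
# contiguous chunk of memory" warning; on ROCm/MIOpen it comes back empty,
# which is the discrepancy this issue is about.
print([str(warning.message) for warning in w])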
Alternatives
No response
Additional context
No response
This might simply be a result of the corresponding MIOpen file (aten/src/ATen/native/miopen/RNN.cpp) being out-of-sync with the cuDNN file.
@alugorey has been looking into the RNN implementation for a different issue, but this issue might also get resolved if and when we sync up the MIOpen version with the cuDNN version, so co-assigning this to Andy.
@jithunnair-amd @bmedishe Jithun is correct. The missing warning is a direct side effect of the lack of weight flattening in our version of RNNs.
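To illustrate what weight flattening means here, a rough sketch (it relies on the private _flat_weights attribute and on untyped_storage(), so it is version-dependent and illustrative only): after flatten_parameters() succeeds, every weight and bias is a view into one allocation, which is exactly the property the cuDNN path checks before warning.

import torch
import torch.nn as nn

def weights_are_flattened(rnn: nn.RNNBase) -> bool:
    # All weights/biases live in one allocation iff they share a storage.
    # _flat_weights is a private attribute; this check is illustrative only.
    storages = {w.untyped_storage().data_ptr() for w in rnn._flat_weights}
    return len(storages) == 1

rnn = nn.LSTM(10, 20).cuda()
rnn.flatten_parameters()
# True when cuDNN-style flattening ran; per this issue, the MIOpen path
# lacks that flattening step, so there is no flat-buffer expectation to
# violate and the warning is never emitted on ROCm.
print(weights_are_flattened(rnn))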