dnnl_use_shift in dnnl_batch_normalization_backward_primitive_desc_create causes an illegal argument error
Summary
Passing the dnnl_use_shift flag to dnnl_batch_normalization_backward_primitive_desc_create causes an illegal argument error when the resulting primitive is executed. The dnnl_use_scale flag works as expected.
Version
3.2.1
Environment
Intel(R) Core(TM) i9-9900X CPU @ 3.50GHz, oneDNN used through JavaCPP
Steps to reproduce
I'm calling oneDNN through Clojure -> JavaCPP -> C layers, so there's little point in posting the exact code. However, code that worked with dnnl_use_scaleshift (in 2.9.x) now works only with dnnl_use_scale. As soon as dnnl_use_shift is added, executing the backward step throws an error.
Observed behavior
Adding dnnl_use_shift breaks the code in every example I could think of.
Expected behavior
The documentation (https://oneapi-src.github.io/oneDNN/dev_guide_batch_normalization.html) suggests that both dnnl_use_scale and dnnl_use_shift are supported. However, this comment https://github.com/oneapi-src/oneDNN/pull/1440#discussion_r982974617 suggests that dnnl_use_shift was a problem at some point. I don't know how that was resolved, so I'm not sure what the official expected behavior is at the moment (which is one reason I'm opening this issue).
Note that although comment https://github.com/oneapi-src/oneDNN/pull/1440#discussion_r982974617 relates to aarch64, I encountered this issue on x86_64.
Based on @dzarukin's comment in https://github.com/oneapi-src/oneDNN/pull/1440#discussion_r1318897877, the batch normalization documentation should be updated. Since that hasn't been done yet, I have created an internal Jira ticket to track this documentation fix.