[AMP] refine AMP and the corresponding tests for bfloat16
This PR fixes issue https://github.com/apache/tvm/issues/12763, where some ops are marked to keep their original dtype, but when one of their inputs is bfloat16 the required Cast back to the original dtype is not inserted.
The AMP tests have also been refined to cover bfloat16 without accuracy checking.
Update:
The accuracy checking of bf16 vs fp32 in test_dnnl.py is unstable and error-prone. Therefore accuracy checking is skipped when only a single bf16 result is present, i.e. results are only compared bf16 vs bf16 and fp32 vs fp32.
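For reference, here is a minimal sketch of the scenario the fix covers. It assumes the default AMP conversion lists, where nn.dense is converted to the mixed-precision dtype while nn.softmax is kept in its original dtype; the shapes and op choices are illustrative only, not taken from the actual tests:

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 16), dtype="float32")
w = relay.var("w", shape=(16, 16), dtype="float32")
y = relay.nn.dense(x, w)   # converted to bfloat16 by AMP
z = relay.nn.softmax(y)    # kept in float32, so its bfloat16 input needs a Cast
mod = tvm.IRModule.from_expr(relay.Function([x, w], z))
mod = relay.transform.InferType()(mod)

# Before the fix, the op that keeps its original dtype could end up consuming
# the bfloat16 tensor directly; with the fix, a Cast back to float32 is inserted.
with tvm.transform.PassContext(opt_level=3):
    bf16_mod = relay.transform.ToMixedPrecision("bfloat16")(mod)
print(bf16_mod)
```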
@tvm-bot rerun
Thanks for the patch, Youlei! I found a bunch of statements like "op->dtype.is_float() || op->dtype.is_bfloat16()" in the tvm folder.
Shall we simply add a new float-type predicate in tvm/include/tvm/runtime/data_type.h to eliminate those statements?
/*! \return whether type is a general float type, including float/float16/bfloat16. */
bool is_general_float() const { return is_float() || is_bfloat16(); }
@billishyahao
I agree that is_general_float() could make the code cleaner, but not clearer. "General float" is a broader concept than IEEE floating point plus bfloat16; for example, TensorFloat-32 is also a general float. I prefer to keep the expression op->dtype.is_float() || op->dtype.is_bfloat16(), as it is more clear and specific.
@masahi Could you help review this? Thanks.