
[AMP] refine AMP and the corresponding tests for bfloat16


This PR fixes issue https://github.com/apache/tvm/issues/12763, where some ops are marked to keep their original dtype while some of their inputs are bfloat16, but the required Cast is missing. The AMP tests have also been refined to cover bfloat16, without accuracy checking.
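
For context, here is a minimal sketch (not code from this PR) of how the AMP pass is driven; relay.transform.ToMixedPrecision and its "bfloat16" target are real TVM APIs, while the toy module is purely illustrative:

  import tvm
  from tvm import relay

  # Build a toy float32 module. nn.dense is on the AMP "always convert" list,
  # so its inputs are cast to bfloat16; a downstream op registered to keep its
  # original dtype must then receive a Cast back -- the Cast this PR fixes.
  x = relay.var("x", shape=(1, 16), dtype="float32")
  w = relay.var("w", shape=(16, 16), dtype="float32")
  mod = tvm.IRModule.from_expr(relay.Function([x, w], relay.nn.dense(x, w)))
  mod = relay.transform.InferType()(mod)
  amp_mod = relay.transform.ToMixedPrecision("bfloat16")(mod)
  print(amp_mod)  # dense now runs in bfloat16, with Casts at the boundaries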

Update: The bf16-vs-fp32 accuracy checking in test_dnnl.py is unstable and error-prone. Accuracy checking is therefore skipped when only one bf16 result is present, i.e. we only compare bf16 against bf16 and fp32 against fp32.
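
A hypothetical sketch of that comparison policy (check_results and the (dtype_tag, array) layout are illustrative, not the actual test_dnnl.py code):

  import numpy as np

  def check_results(results, rtol=1e-5, atol=1e-5):
      """results: list of (dtype_tag, np.ndarray) pairs, e.g. ("bf16", arr)."""
      groups = {}
      for tag, arr in results:
          groups.setdefault(tag, []).append(arr)
      for arrs in groups.values():
          if len(arrs) < 2:
              continue  # a lone bf16 (or fp32) result has nothing to compare to
          for other in arrs[1:]:
              np.testing.assert_allclose(arrs[0], other, rtol=rtol, atol=atol)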

yangulei avatar Sep 15 '22 01:09 yangulei

@tvm-bot rerun

billishyahao avatar Sep 15 '22 07:09 billishyahao

Thanks for the patch, Youlei! I found a bunch of statements like "op->dtype.is_float() || op->dtype.is_bfloat16()" in the tvm folder.

Shall we simply add a new float-type predicate in tvm/include/tvm/runtime/data_type.h to eliminate those statements?

  /*! \return whether type is a general float type, including float/float16/bfloat16. */
  bool is_general_float() const { return is_float() || is_bfloat16(); }

billishyahao avatar Sep 15 '22 07:09 billishyahao

@billishyahao I agree that is_general_float() could make the code cleaner, but not clearer. "General float" is a broader concept than IEEE floating point plus bfloat16; TensorFloat-32, for example, is also a general float. I prefer to keep the expression op->dtype.is_float() || op->dtype.is_bfloat16(), as it is clearer and more specific.

yangulei avatar Sep 16 '22 00:09 yangulei

@masahi Could you help review this? Thanks.

yangulei avatar Oct 27 '22 08:10 yangulei