Android: Tensor type is not very friendly to BFloat16
🐛 Describe the bug
After the patch from https://github.com/pytorch/executorch/issues/6284#issuecomment-2423431020 fixed the original invalid UTF-8 character issue, there is a new issue with tensor type support.
Llama3.2 1B/3B Model, BF16 dtype
A large number of the following logcat errors appear during the session, and the output is truncated or not displayed at all:
E 00:00:02.486051 executorch:tensor_util.h:482] Check failed (t.scalar_type() == dtype): Expected to find Half type, but tensor has type BFloat16
E 00:00:02.496141 executorch:tensor_util.h:482] Check failed (t.scalar_type() == dtype): Expected to find Float type, but tensor has type BFloat16
Llama3.2 1B/3B SpinQuant Model, FP32 dtype
Conversations and emoji are output completely, without truncation or missing output, and logcat does not show the above errors.
Versions
master version
cc @kirklandsign
Hi @JamePeng, my understanding is that we don't support the BF16 dtype in Java at all, right? We need to add it. I don't think Java has a native BF16 type, so it would be more like passing the data through a byte buffer?
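Just to illustrate the byte-buffer idea, a minimal sketch only: the `Bf16Buffers` helper below is a made-up name, not existing ExecuTorch API, and BF16 here is plain truncation of float32 to its upper 16 bits (no round-to-nearest).

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical helper: pack float32 data into BF16 bit patterns so they can be
// handed to native code through a direct ByteBuffer (2 bytes per element).
public final class Bf16Buffers {
  private Bf16Buffers() {}

  public static ByteBuffer toBf16Buffer(float[] data) {
    ByteBuffer buf =
        ByteBuffer.allocateDirect(data.length * 2).order(ByteOrder.nativeOrder());
    for (float v : data) {
      int bits = Float.floatToRawIntBits(v);
      // BF16 keeps the sign bit, all 8 exponent bits, and the top 7 mantissa bits
      // of float32, i.e. the upper 16 bits. This is truncation, not rounding.
      buf.putShort((short) (bits >>> 16));
    }
    buf.rewind();
    return buf;
  }
}
```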
Yes, some conversion work may be needed for the BF16 type to make the model work correctly.
Basically we need a Tensor_bf16 subtype like https://github.com/pytorch/executorch/blob/release/0.5/extension/android/src/main/java/org/pytorch/executorch/Tensor.java#L537-L567, and also something like https://github.com/pytorch/executorch/blob/release/0.5/extension/android/src/main/java/org/pytorch/executorch/Tensor.java#L691.
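A rough sketch of what that subtype and its factory could look like, assuming the base Tensor class in the linked file exposes a long[] shape constructor, an abstract dtype() method, and a getRawDataBuffer() accessor like the other Tensor_* subtypes. DType.BFLOAT16 and fromBlobBf16 are assumed new names rather than existing API, and the native dtype mapping would still need a matching BFloat16 case:

```java
// Inside extension/android/.../Tensor.java (file-level imports assumed:
// java.nio.Buffer, java.nio.ShortBuffer, java.util.Arrays).

// BF16 payloads are carried as raw 16-bit bit patterns, one short per element.
static class Tensor_bf16 extends Tensor {
  private final ShortBuffer data;

  private Tensor_bf16(ShortBuffer data, long[] shape) {
    super(shape);
    this.data = data;
  }

  @Override
  public DType dtype() {
    return DType.BFLOAT16; // assumed new enum value mirroring ScalarType::BFloat16
  }

  @Override
  Buffer getRawDataBuffer() {
    return data;
  }

  @Override
  public String toString() {
    return String.format("Tensor(%s, dtype=BFloat16)", Arrays.toString(shape));
  }
}

// Assumed companion factory, mirroring the existing Tensor.fromBlob(...) overloads.
public static Tensor fromBlobBf16(ShortBuffer data, long[] shape) {
  if (data == null || shape == null) {
    throw new IllegalArgumentException("data and shape must not be null");
  }
  return new Tensor_bf16(data, shape);
}
```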
Hi @JamePeng, so it happens not only in the Android app, but also in the llama_runner binary.
Hi @JamePeng is this still an issue?