Fix the embedding dtype.
Description
The embedding attend() didn't cast the query. Refactor the code to fix the issue and make call() and attend() consistent.
Tests
Local tests passed.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [x] I have performed a self-review of my code.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed.
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
View this failed invocation of the CLA check for more information.
For the most up to date status, view the checks section at the bottom of the pull request.
This PR has been automatically marked as stale because it has not had recent activity. It will be closed soon if no further activity occurs. Thank you for your contributions.
This PR was closed because it has been inactive for a while. Please reopen it if you are still working on it.