maxtext
Fix the activation type in the checkpoint conversion script
Replace the hardcoded bfloat16 with the config-defined activation type.
Description
In this script, the weight dtype is hardcoded to bfloat16 when the decoder state is created, so the exported checkpoint is always bf16. The dtype should instead be controlled by the config, which unblocks use cases such as serving with fp16 instead of bf16.
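The change can be pictured with a minimal sketch. It assumes a JAX parameter pytree and a config field that holds the dtype as a string. `cast_params` and the `_DTYPES` table are hypothetical helpers for illustration, not the conversion script's actual code. The sketch casts floating-point leaves to the configured dtype and leaves non-float leaves (e.g. step counters) untouched.

```python
import jax
import jax.numpy as jnp

# Hypothetical mapping from a config dtype string to a JAX dtype.
# The accepted strings are an assumption for illustration.
_DTYPES = {
    "bfloat16": jnp.bfloat16,
    "float16": jnp.float16,
    "float32": jnp.float32,
}


def cast_params(params, dtype_name: str):
    """Cast every floating-point leaf of a parameter pytree to the configured dtype.

    Non-floating leaves (e.g. integer step counters) are returned unchanged.
    """
    target = _DTYPES[dtype_name]

    def _cast(x):
        x = jnp.asarray(x)
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(target)
        return x

    return jax.tree_util.tree_map(_cast, params)


# Usage: with the dtype read from config instead of hardcoded,
# the same conversion can export fp16 or bf16 weights.
params = {"decoder": {"kernel": jnp.ones((2, 2), dtype=jnp.float32)}, "step": jnp.array(0)}
fp16_params = cast_params(params, "float16")
```

Keeping the dtype as a lookup from a config string means switching the exported precision only requires a config change, not an edit to the script.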
Tests
Manually ran the conversion script for Llama. The Gemma script does not work, even without this change.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [x] I have performed a self-review of my code.
- [x] I have added necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed.