maxtext
Fix prefill assertion
Description
The prefill assertion should check the true length of the prompt rather than the padded token length, so that it fails when max_prefill_predict_length is less than the length of the prompt. Otherwise, decoding produces unexpected output.
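A minimal sketch of why the two lengths behave differently, assuming the prompt tokens are padded or truncated to a fixed prefill size before the check. The helper names `pad_tokens` and `check_prefill`, the pad id, and the token ids are hypothetical and are not taken from MaxText; only the assertion itself mirrors the one in decode.py.

```python
def pad_tokens(token_ids, prefill_length, pad_id=0):
    """Hypothetical helper: pad or truncate token_ids to prefill_length.

    Returns the fixed-size list and the true (unpadded) token count.
    """
    true_length = len(token_ids)
    kept = token_ids[:prefill_length]
    padded = kept + [pad_id] * (prefill_length - len(kept))
    return padded, true_length


def check_prefill(token_ids, max_prefill_predict_length):
    """Hypothetical wrapper around the length check."""
    padded, true_length = pad_tokens(token_ids, max_prefill_predict_length)
    # len(padded) always equals max_prefill_predict_length, so asserting on it
    # could never fail. The true length is what detects an overlong prompt.
    assert true_length <= max_prefill_predict_length, "can't take too many tokens"
    return padded, true_length


# Illustrative token ids only; a real tokenizer would produce different values.
tokens = [1, 11, 12, 13]
padded, true_length = check_prefill(tokens, 4)  # passes: 4 <= 4
```

With these four illustrative tokens, a limit of 2 trips the assertion, while a limit of 4 passes.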
Test
Run decoding for llama2-7b model
Example command:
python3 MaxText/decode.py MaxText/configs/base.yml run_name=test_quick base_output_directory=gs://runner-maxtext-logs per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=2 max_target_length=16 prompt="I love to" attention=dot_product
- Using the command with max_prefill_predict_length=2, we now get the expected assertion failure:
Traceback (most recent call last):
  File "/home/ranran/mixtral/maxtext/MaxText/decode.py", line 75, in <module>
    main(cfg)
  File "/home/ranran/mixtral/maxtext/MaxText/decode.py", line 39, in main
    assert true_length <= config.max_prefill_predict_length, "can't take too many tokens"
- Using the command with max_prefill_predict_length=4, there are no assertion errors.