
Fix prefill assertion

Open • RissyRan opened this issue 1 year ago • 0 comments

Description

We should assert on the true (unpadded) length of the prompt instead of the padded token length, so that a prompt longer than max_prefill_predict_length is rejected. Otherwise, decoding produces unexpected output without raising an error.
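A minimal sketch of why checking the padded length can miss an overlong prompt. It assumes, hypothetically, that the tokenizer pads a prompt up to the smallest prefill bucket that fits and keeps only the last `max(prefill_lengths)` tokens when no bucket fits. `pad_to_bucket`, `check_prefill_length` and `PAD_ID` are illustrative names, not MaxText or JetStream APIs.

```python
from typing import List, Sequence, Tuple

PAD_ID = 0  # hypothetical padding token id


def pad_to_bucket(tokens: Sequence[int], prefill_lengths: Sequence[int]) -> Tuple[List[int], int]:
    """Hypothetical padding step (assumed behavior, not MaxText's code):
    pad to the smallest bucket >= the prompt length; if none fits,
    keep only the last max(prefill_lengths) tokens."""
    true_length = len(tokens)
    fitting = [n for n in sorted(prefill_lengths) if n >= true_length]
    padded_length = fitting[0] if fitting else max(prefill_lengths)
    kept = list(tokens[-padded_length:])  # truncates only when the prompt is too long
    padded = kept + [PAD_ID] * (padded_length - len(kept))
    return padded, true_length


def check_prefill_length(true_length: int, max_prefill_predict_length: int) -> None:
    # Check the unpadded token count: len(padded) can never exceed the bucket,
    # so a check on it would pass even for a truncated prompt.
    assert true_length <= max_prefill_predict_length, "can't take too many tokens"
```

Under this assumption, a 4-token prompt with `max_prefill_predict_length=2` is truncated to a padded length of 2, so a check on the padded length passes, while the check on the true length (4) correctly trips the assertion.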

Test

Run decoding for the llama2-7b model.

Example command:

 python3 MaxText/decode.py MaxText/configs/base.yml run_name=test_quick base_output_directory=gs://runner-maxtext-logs per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=2  max_target_length=16 prompt="I love to" attention=dot_product 
  • With max_prefill_predict_length=2, decoding now fails with the expected assertion error:
Traceback (most recent call last):
  File "/home/ranran/mixtral/maxtext/MaxText/decode.py", line 75, in <module>
    main(cfg)
  File "/home/ranran/mixtral/maxtext/MaxText/decode.py", line 39, in main
    assert true_length <= config.max_prefill_predict_length, "can't take too many tokens"
  • With max_prefill_predict_length=4, no assertion error is raised.
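To repeat both test runs, the command above can be generated with only max_prefill_predict_length varying. This is a small sketch; `decode_cmd` is a hypothetical helper, and the arguments are copied from the example command, with the shell quotes around the model name and prompt resolved into plain argv entries.

```python
import shlex
from typing import List


def decode_cmd(max_prefill_predict_length: int) -> List[str]:
    """Hypothetical helper: argv for the issue's decode command,
    varying only max_prefill_predict_length."""
    return [
        "python3", "MaxText/decode.py", "MaxText/configs/base.yml",
        "run_name=test_quick",
        "base_output_directory=gs://runner-maxtext-logs",
        "per_device_batch_size=1",
        "model_name=llama2-7b",
        "ici_autoregressive_parallelism=4",
        f"max_prefill_predict_length={max_prefill_predict_length}",
        "max_target_length=16",
        "prompt=I love to",
        "attention=dot_product",
    ]


# Print a shell-safe version of each run (2: assertion expected, 4: no error).
for n in (2, 4):
    print(shlex.join(decode_cmd(n)))
```

Inside a MaxText checkout, each list can be passed directly to `subprocess.run` instead of being printed.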

RissyRan, May 21 '24 21:05