Greedy sampling gives a warning message
System Info
With the new version, starting from 4.39, performing greedy search gives a warning: "You should set `do_sample=True` or unset `temperature`". I am loading the pretrained `Llama-2-7b-chat-hf` model. I understand that by default the temperature is set to 0.6, so I explicitly set it to 0 when calling the `generate` function, something like this: `model.generate(do_sample=False, temperature=0)`. But I still get the warning message to either set `do_sample` to `True` or unset `temperature`.

Is the warning hampering the greedy decoding process, or can I ignore it? (FYI, I have tried `do_sample=True` with `top_k=1`, which should ideally be the same as greedy decoding, but I wanted to confirm whether `do_sample=True` really gives greedy decoding results.) If we cannot ignore it, how should I unset the temperature?
Who can help?
@ArthurZucker
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
Step 1: Load the model

```python
model = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model, token=access_token)
model = AutoModelForCausalLM.from_pretrained(model, load_in_4bit=True, token=access_token)
```

Step 2: Run the generation

```python
encoded_input = encoded_input.to(device)
generated_text = model.generate(
    **encoded_input,
    max_new_tokens=4096,
    do_sample=False,
    top_k=1,
    temperature=0,
    top_p=0,
    return_dict_in_generate=True,
    output_scores=True,
)
```
Expected behavior
The model should output the word with the highest probability at each generation step.
cc @gante
While using the `generate` method, if you set `do_sample=False` you should not set the `temperature`, because temperature is redundant when `do_sample=False`. Either skip the `temperature` parameter or set it to 1.0. Take a look at the code snippet in the link below to see how it is validated and how the warning is generated when the conditions are not met: https://github.com/huggingface/transformers/blob/594c1610fa6243b2ffb670c49faf389ca5121939/src/transformers/generation/configuration_utils.py#L533-L543
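The shape of that check is simple enough to sketch in isolation. Below is a minimal, stdlib-only imitation of the linked validation logic, not the actual transformers code; the function name `validate_sampling_flags` is made up for illustration:

```python
import warnings

def validate_sampling_flags(do_sample: bool, temperature: float) -> None:
    """Toy imitation of the linked check: warn when a sampling-only
    flag carries a non-default value while sampling is disabled."""
    if not do_sample and temperature is not None and temperature != 1.0:
        warnings.warn(
            f"`do_sample` is set to `False`. However, `temperature` is set to "
            f"{temperature} -- this flag is only used in sample-based generation "
            f"modes. You should set `do_sample=True` or unset `temperature`.",
            UserWarning,
        )

validate_sampling_flags(do_sample=False, temperature=0.6)  # warns
validate_sampling_flags(do_sample=False, temperature=1.0)  # silent
```

This is why passing `temperature=0` with `do_sample=False` still warns: 0 is a non-default value for a flag that greedy decoding never reads.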
Once the model is loaded, the temperature is set to 0.6 by default and `do_sample=True`; I verified this by looking at the output of `model.__dict__`. So the warning gets triggered even when I skip the `temperature` parameter (as it takes the default value of 0.6). My doubts, after I set `do_sample=False`, are:
- How do I unset the temperature (which is different from setting the temperature to 0)?
- Does the warning message actually mess up the greedy decoding? I believe `validate` is called from the `__init__` of `GenerationConfig`.
When do_sample is set to False, the temperature parameter is not used at all. Even though it seems to be set to 0.6, it would not be used. However, if you don't want to see the warning, you can set do_sample=True and temperature=1.0. In this case you would not see the warning and temperature would not be used either.
Also, I don't think the Warning messes up the greedy decoding.
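To see why `temperature=1.0` is effectively unused, note that dividing the logits by 1.0 leaves the softmax distribution unchanged. A small stdlib-only illustration with a hand-rolled softmax (not the transformers implementation):

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over temperature-scaled logits; dividing by 1.0 is a no-op."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 0.5, -1.0]
print(softmax(logits, temperature=1.0) == softmax(logits))  # True: identical distribution
# A lower temperature, by contrast, sharpens the distribution:
print(max(softmax(logits, temperature=0.6)) > max(softmax(logits)))  # True
```

So `do_sample=True, temperature=1.0` samples from the model's unmodified distribution, which is why no warning is needed for that combination.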
Thanks for the quick reply.

> When do_sample is set to False, the temperature parameter is not used at all. Even though it seems to be set to 0.6, it would not be used.

This solution is all right for my purpose, but a code refactoring would be good to remove the warning.

> However, if you don't want to see the warning, you can set do_sample=True and temperature=1.0. In this case you would not see the warning and temperature would not be used either.

AFAIK, in theory, setting `temperature=1.0` pushes the sampling distribution to be a bit random (as opposed to `temperature=0`, which is essentially greedy decoding).

> Also, I don't think the Warning messes up the greedy decoding.

If that's the case, then the warning message can be ignored.
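On the `do_sample=True, top_k=1` question raised earlier: after top-k filtering with k=1, only the highest-logit token keeps nonzero probability, so sampling from the filtered distribution is deterministic and matches greedy argmax (up to tie-breaking between equal logits). A toy, stdlib-only sketch of the selection logic, not the HF implementation:

```python
import math
import random

def pick_next_token(logits, do_sample, temperature=1.0, top_k=None):
    """Toy next-token selection mirroring the knobs discussed in this thread."""
    if not do_sample:
        # Greedy decoding: plain argmax over the logits.
        return max(range(len(logits)), key=logits.__getitem__)
    scaled = [x / temperature for x in logits]
    if top_k is not None:
        # Keep only the top_k highest logits; mask out the rest.
        kth = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [s if s >= kth else float("-inf") for s in scaled]
    exps = [math.exp(s - max(scaled)) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    return random.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 0.5, -1.0]
greedy = pick_next_token(logits, do_sample=False)
# With top_k=1 every draw collapses onto the argmax, whatever the temperature:
samples = {pick_next_token(logits, do_sample=True, temperature=0.6, top_k=1) for _ in range(100)}
print(samples == {greedy})  # True
```

This supports the expectation in the original question: `do_sample=True` with `top_k=1` should reproduce greedy decoding results.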
Could you provide a reproducer if you are still seeing the warning when you set `do_sample` to `False` or `temperature` to 0?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.