fix: enable passing more parameters to the generation model

## Related Issues

N/A

## Proposed Changes

When running the following code with `PromptNode`, I noticed that the `generate_kwargs` passed to `prompt_node.run()` are not passed down to the underlying Hugging Face invocation layer (line 258 in `hugging_face.py`).
```python
import torch
from haystack.nodes import PromptNode
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList

model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-chat",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-chat")
prompt_node = PromptNode("mosaicml/mpt-7b-chat", model_kwargs={"model": model, "tokenizer": tokenizer})

stop_tokens = ["<|endoftext|>", "<|im_end|>"]
stop_token_ids = tokenizer.convert_tokens_to_ids(stop_tokens)

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

generate_kwargs = {
    "max_length": 2048,
    "temperature": 0.3,
    "top_p": 1.0,
    "top_k": 0,
    "use_cache": True,
    "do_sample": True,
    "eos_token_id": 0,
    "pad_token_id": 0,
    "stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
}

prompt_template = """<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>\n{query}"""
results = prompt_node.run(query="Hello", prompt_template=prompt_template, generation_kwargs=generate_kwargs)
print(results)
```
The reason is that these kwargs are filtered out by the following lines in the `invoke` method of `hugging_face.py`:
```python
model_input_kwargs = {
    key: kwargs[key]
    for key in [
        "return_tensors",
        "return_text",
        "return_full_text",
        "clean_up_tokenization_spaces",
        "truncation",
        "generation_kwargs",
        "max_new_tokens",
        "num_beams",
        "do_sample",
        "num_return_sequences",
        "max_length",
    ]
    if key in kwargs
}
```
The proposed solution is to extend this list with more kwargs, as shown below:
```python
model_input_kwargs = {
    key: kwargs[key]
    for key in [
        "return_tensors",
        "return_text",
        "return_full_text",
        "clean_up_tokenization_spaces",
        "truncation",
        "generation_kwargs",
        "max_new_tokens",
        "num_beams",
        "do_sample",
        "num_return_sequences",
        "max_length",
        "temperature",
        "eos_token_id",
        "pad_token_id",
        "stopping_criteria",
        "use_cache",
        "top_p",
        "top_k",
    ]
    if key in kwargs
}
```
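To see the pattern in isolation, here is a minimal, self-contained sketch of the whitelist filter (the allowed list is truncated for brevity, and `filter_model_kwargs` is an illustrative name, not the actual Haystack function):

```python
# Sketch of the whitelist filtering pattern used in invoke:
# only keys that appear in the allowed list survive into model_input_kwargs.
ALLOWED_KEYS = ["max_length", "do_sample", "temperature", "top_p", "top_k"]

def filter_model_kwargs(kwargs, allowed=ALLOWED_KEYS):
    """Keep only whitelisted generation parameters."""
    return {key: kwargs[key] for key in allowed if key in kwargs}

passed = {"max_length": 2048, "temperature": 0.3, "query": "Hello"}
print(filter_model_kwargs(passed))  # {'max_length': 2048, 'temperature': 0.3}
```

Any key not on the list, such as `stopping_criteria` before this PR, is silently dropped, which is exactly the behavior observed above.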
Running the code above now produces the correct result:

```
({'results': ['! How can I help you today?'], 'invocation_context': {'query': 'Hello', 'max_length': 2048, 'temperature': 0.3, 'top_p': 1.0, 'top_k': 0, 'use_cache': True, 'do_sample': True, 'eos_token_id': 0, 'pad_token_id': 0, 'stopping_criteria': [<__main__.StopOnTokens object at 0x7f93dca787c0>], 'results': ['! How can I help you today?'], 'prompts': ['<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>\nHello']}}, 'output_1')
```
## How did you test it?
By adding `print(model_input_kwargs)` right before line 251 in `hugging_face.py`, we can verify that the parameters defined in `generate_kwargs` are passed correctly to the actual model. In addition, the correct result returned by `PromptNode.run()` confirms that the model call works; without the fix, the code above cannot run through `PromptNode.run()` at all.
## Notes for the reviewer
Since the changes are rather straightforward, I didn't add a unit test. If one is needed, please let me know.
## Checklist
- I have read the contributors guidelines and the code of conduct
- I have updated the related issue with new insights and changes
- I added unit tests and updated the docstrings
- I've used one of the conventional commit types for my PR title: `fix:`, `feat:`, `build:`, `chore:`, `ci:`, `docs:`, `style:`, `refactor:`, `perf:`, `test:`
- I documented my code
- I ran pre-commit hooks and fixed any issue
Pull Request Test Coverage Report for Build 5865704156
- 0 of 0 changed or added relevant lines in 0 files are covered.
- 22 unchanged lines in 3 files lost coverage.
- Overall coverage decreased (-0.001%) to 48.062%
| Files with Coverage Reduction | New Missed Lines | % |
|---|---|---|
| utils/context_matching.py | 1 | 95.7% |
| errors.py | 10 | 77.67% |
| nodes/prompt/invocation_layer/hugging_face.py | 11 | 93.08% |
| Total: | 22 | |
| Totals | |
|---|---|
| Change from base Build 5859099232: | -0.001% |
| Covered Lines: | 11370 |
| Relevant Lines: | 23657 |
💛 - Coveralls
@vblagoje any feedback on this PR?
@vblagoje FYI, I removed the parameters that would modify the model config files and kept only those that affect generation behavior; see this blog post: https://huggingface.co/blog/how-to-generate.
@faaany , thanks for your patience. If possible, I'd like to remove this filter altogether. It was put in there back in the day when a single unrecognized parameter would blow up the generation. We should try if today's transformers are more tolerant first and if so remove it altogether because it has become a maintenance nightmare. Let's try that first, wdyt?
Yeah! That's a brilliant suggestion. I think today's transformers can handle this. Let me give it a try and update my PR.
@vblagoje I have updated my PR and removed the key argument filter. But I need to pop `documents` and `query` from the kwargs, otherwise the unit tests `` in `test_prompt_node.py` will fail. Please have a review.
@vblagoje In my opinion, it is the user's responsibility to pass the correct parameters to the pipeline. If we want to assist users in passing the correct parameters, we could add a whitelist as we did before, but then we would need to continuously maintain that whitelist, which is exactly what we want to get rid of now. Adding a blacklist is not a good idea either, because it could basically contain anything. However, we should at least exclude the parameters introduced by our own code, e.g. `documents` and `query`.
@vblagoje As for your proposal to add a more generic method that removes all prompt template parameters, I am not sure about it: usually `documents` and `query` are filled into the prompt template before it is passed to `invoke`, so normally these two parameters won't reach `invoke` at all. My assumption is that there might be a bug somewhere in the `run_batch` method or in the test method, which is out of scope for this PR.
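The exclusion described above could be sketched roughly like this (the names are illustrative; only the `documents` and `query` keys come from the discussion):

```python
# Parameters injected by the pipeline itself, not meant for the model.
PIPELINE_PARAMS = ("documents", "query")

def strip_pipeline_params(kwargs):
    """Return a copy of kwargs without pipeline-injected parameters."""
    cleaned = dict(kwargs)
    for name in PIPELINE_PARAMS:
        cleaned.pop(name, None)  # tolerate absence of either key
    return cleaned

print(strip_pipeline_params({"query": "Hello", "max_length": 2048}))
# {'max_length': 2048}
```

Working on a copy keeps the caller's kwargs intact, which matters if the same dict is reused later in the pipeline run.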
Ok @faaany let me dust out the debugger and recall the details of how everything works here. With some concrete findings, we can make better decisions.
@faaany I looked at this issue in detail and concluded the following. We currently pass all template input kwargs to the invocation layer invoke method. And we shouldn't do that. In the PromptNode prompt method, right after the template is filled with args and kwargs, we should remove the template-related kwargs from kwargs_copy passed to model invocation. After the invocation, we should put them back in. These prompt template kwargs are not needed in the invocation layer. We will then most likely be able to safely remove these kwargs filters that are being repeated in each invocation layer and have PromptNode users bear the responsibility for what they pass as LLM kwargs. @silvanocerza knows this code base well; we can expand the conversation both here and internally.
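The set-aside-then-restore flow proposed above could look roughly like this (a sketch with placeholder names, not the actual PromptNode code):

```python
def invoke_without_template_kwargs(invoke, template_params, kwargs):
    """Set aside template-related kwargs, call the model, then restore them."""
    set_aside = {k: kwargs.pop(k) for k in list(kwargs) if k in template_params}
    try:
        result = invoke(**kwargs)
    finally:
        kwargs.update(set_aside)  # put template kwargs back after the invocation
    return result

# Usage with a stand-in for the invocation layer's invoke method:
received = []
def fake_invoke(**model_kwargs):
    received.append(model_kwargs)
    return ["generated text"]

call_kwargs = {"documents": ["Berlin is an amazing city."], "max_length": 100}
invoke_without_template_kwargs(fake_invoke, {"documents"}, call_kwargs)
print(received[0])   # only {'max_length': 100} reached the model
print(call_kwargs)   # the template kwargs are back after the call
```

Restoring in a `finally` block ensures the kwargs are put back even if the model call raises.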
Thanks a lot for the investigation! I agree with your proposal. Just one more comment from my side: for a better user experience, it might be a good idea to do a bit of translation work for those parameters that are shared by each model provider but with different names. For example, for stop words, openai uses stop, and hugging face uses stop_words. Let's say we want to stick with stop_words, then it might be good to translate the argument stop_words to stop in the OpenAIInvocationLayer.invoke(). This would still be meaningful in my opinion.
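The translation suggested above could be sketched as follows (the mapping is illustrative; only the `stop_words` → `stop` pair comes from the comment, and `translate_kwargs` is a hypothetical helper, not an existing Haystack function):

```python
# Illustrative mapping from a shared Haystack-facing name to OpenAI's native name.
OPENAI_PARAM_MAP = {"stop_words": "stop"}

def translate_kwargs(kwargs, param_map=OPENAI_PARAM_MAP):
    """Rename provider-agnostic kwargs to a provider's native parameter names."""
    return {param_map.get(key, key): value for key, value in kwargs.items()}

print(translate_kwargs({"stop_words": ["<|im_end|>"], "temperature": 0.3}))
# {'stop': ['<|im_end|>'], 'temperature': 0.3}
```

Each invocation layer would hold its own mapping, so users can stick to one parameter name regardless of the model provider.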
@vblagoje I have updated the code. Could you take a look?
How about we update the main with this simple and clear change first https://github.com/deepset-ai/haystack/compare/main...remove_template_params_from_kwargs add unit tests, and then after that, we can continue on with different changes on all all layers - where applicable? cc @silvanocerza
sure, I will wait till the other PR is merged.
@vblagoje it seems that the template-related params are not removed entirely.
When I run the following code snippet:

```python
node = PromptNode(default_prompt_template="sentiment-analysis", output_variable="out")
pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
result = pipe.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
print(result)
```
it raises errors, and when I print out the params passed to the transformers pipeline, this is the result:

```
{'query': 'not relevant', 'num_return_sequences': 1, 'num_beams': 1, 'max_length': 100}
```

As you can see, the `query` parameter is still there and gets passed to the transformers library.
OK, the problem is that we only remove the parameters that exist in the `PromptTemplate`, and `query` is not one of those parameters, as can be seen from the `sentiment-analysis` prompt template shown below:

```
PromptTemplate(name=sentiment-analysis, prompt_text=Please give a sentiment for this context. Answer with positive, negative or neutral. Context: {documents}; Answer:, prompt_params=['documents'])
```

Since we actually pass two params, `query="not relevant"` and `documents=[Document("Berlin is an amazing city.")]`, to `pipe.run()`, we get the error.
Right, I see. You don't have to pass a query, I think. I'd rather not pass a query than forcefully remove specific parameters in code.
@vblagoje Yeah, I agree. Then we would need to modify the existing unit test: instead of `sentiment-analysis`, we can use a question-answering prompt template. I already pushed my code. Could you take a look?
As this change might have some unintended consequences @faaany , let me look through everything once again and perhaps add some tests. I'll update PR with my comments tomorrow.
@anakin87 would you please have a look at this PR? The objective is to remove kwargs filters in HFLocalInvocationLayer invoke. They change from release to release and often block users from using the latest pipeline parameters they need. The downside is that we remove "helper wheels" and shift the burden of responsibility to users - they become responsible for the params they pass from PromptNode.
@vblagoje @anakin87 I have updated the patch based on your comments. Could you take a final look? Many Thanks!
Thanks for these updates @faaany . I'm consulting internally about the potentially unforeseen impact of these changes. One of the changes we are introducing with this PR is the following: in pipelines where users pass query (or documents) parameter but query (or documents) is not used in the prompt template - we'll have an issue in the HF invocation layer. One could argue this is a user's responsibility - not to pass the query parameter, but that's debatable. Give us another day or two to assess the impact. 🙏
No problem. We need to assess how often this case occurs. If it is indeed common, then we need to decide whether to manually remove these parameters in the `invoke` method, or to leave the filter in place and keep adding parameters to the filter list.
@vblagoje any update regarding this PR?