Issue setting huggingface.prompt_builder = 'llama2' when using sagemaker as client
So I'm building a class that can alternate between the huggingface and sagemaker clients, and I declare all my os.environ variables at the top of the class like so:
os.environ["AWS_ACCESS_KEY_ID"] = "<key_id>"
os.environ["AWS_SECRET_ACCESS_KEY"] = "<access_key>"
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
os.environ["HF_TOKEN"] = "<hf_token>"
os.environ["HUGGINGFACE_PROMPT"] = "llama2"
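Since these are all plain environment variables, one quick way to rule out a silently-missing value before any client is constructed is a small sanity check. This is a hypothetical helper I'm sketching here, not part of easyllm:

```python
import os

# Hypothetical sanity check (not part of easyllm): verify that the
# variables the clients expect are actually set and non-empty.
REQUIRED_VARS = [
    "AWS_ACCESS_KEY_ID",
    "AWS_SECRET_ACCESS_KEY",
    "AWS_DEFAULT_REGION",
    "HF_TOKEN",
    "HUGGINGFACE_PROMPT",
]

def missing_env(required=REQUIRED_VARS):
    """Return the names of required variables that are unset or empty."""
    return [name for name in required if not os.environ.get(name)]
```

Calling missing_env() right before constructing the client returns an empty list when everything is in place, or the names of whatever got dropped.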
Even later on in the class, just to be sure, I declare huggingface.prompt_builder = 'llama2'. I also tried importing build_llama2_prompt directly and passing it as a callable, which didn't work, and I tried setting sagemaker.prompt_builder = 'llama2' just for fun to see if that would do anything... nope.
I still get the warning telling me I haven't set a prompt builder, which is kinda weird. It's also clear that the prompt is occasionally being formatted incorrectly, because the same prompt, formatted as in the example below and passed directly to the sagemaker endpoint, yields a noticeably better response from that same endpoint.
It's no big deal that this doesn't work for me (I might just be doing something wrong); below is how I've worked around it by calling sagemaker's HuggingFacePredictor class manually:
llm = sagemaker.huggingface.model.HuggingFacePredictor('llama-party', sess)

def build_llama2_prompt(messages):
    startPrompt = "<s>[INST] "
    endPrompt = " [/INST]"
    conversation = []
    for index, message in enumerate(messages):
        if message["role"] == "system" and index == 0:
            conversation.append(f"<<SYS>>\n{message['content']}\n<</SYS>>\n\n")
        elif message["role"] == "user":
            conversation.append(message["content"].strip())
        else:
            # assistant turn: close the current [INST] block and open the next one
            conversation.append(f" [/INST] {message['content']}</s><s>[INST] ")
    return startPrompt + "".join(conversation) + endPrompt

prompt = build_llama2_prompt(messages)

payload = {
    "inputs": prompt,
    "parameters": {
        "do_sample": True,
        "top_p": 0.6,
        "temperature": 0.9,
        "top_k": 50,
        "max_new_tokens": 512,
        "repetition_penalty": 1.03,
        "stop": ["</s>"]
    }
}

chat = llm.predict(payload)
print(chat[0]["generated_text"][len(prompt):])
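For anyone comparing outputs, here's a standalone sketch of what the builder produces for a short placeholder conversation (same logic as above, with dict-style content access in the assistant branch):

```python
# Standalone copy of the builder, run on placeholder messages to show
# the Llama 2 chat template it emits.
def build_llama2_prompt(messages):
    startPrompt = "<s>[INST] "
    endPrompt = " [/INST]"
    conversation = []
    for index, message in enumerate(messages):
        if message["role"] == "system" and index == 0:
            conversation.append(f"<<SYS>>\n{message['content']}\n<</SYS>>\n\n")
        elif message["role"] == "user":
            conversation.append(message["content"].strip())
        else:
            conversation.append(f" [/INST] {message['content']}</s><s>[INST] ")
    return startPrompt + "".join(conversation) + endPrompt

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi!"},
]
print(build_llama2_prompt(messages))
# <s>[INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# Hi! [/INST]
```

If the easyllm prompt builder were being picked up, I'd expect the endpoint to receive a string like this rather than the raw message content.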
This code was taken pretty much verbatim from the sagemaker llama deployment blog post here: https://www.philschmid.de/sagemaker-llama-llm
It works fine; I just don't know why the same logic doesn't work right inside the lib (easyllm).
Can you check out this example for sagemaker and see if it works? https://philschmid.github.io/easyllm/examples/sagemaker-chat-completion-api/#1-import-the-easyllm-library
Hi @bcarsley, were you able to resolve this?