[Feature Request] stream output support for mlx_lm
Hi, mlx developers.
First and foremost, I would like to express my sincere gratitude for your efforts in developing this library. Thank you so much. I'm a beginner at programming, but thanks to the code in this repository, I was able to write a script to chat with a local LLM myself.
Currently, I am writing my own stream output function by referring to the code of the "generate" function in utils.py, but I would be very happy if stream output were officially supported. I think streaming is very important for the user experience during chat.
Best regards,
This would definitely be nice to have for mlx-lm, and I think it would be a good first issue for anyone who would like to contribute to the mlx-lm package. :)
Thank you for the reply.
Should I have submitted the source code? If so, I apologize. Below is the function I wrote (almost entirely copied) based on the "generate" function in utils.py. It worked in my environment, but I don't think it should be used as-is, since it ends up maintaining two functions with almost the same source code. I think it would be better to merge streaming into the "generate" function. (Because I'm a beginner at coding, I didn't want to modify the original code, so I added a dedicated function.)
def generate_stream(
    model: nn.Module,
    tokenizer: PreTrainedTokenizer,
    prompt: str,
    temp: float = 0.0,
    max_tokens: int = 100,
):
    """
    Generate text from the model as a stream.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        temp (float): The temperature for sampling (default 0).
        max_tokens (int): The maximum number of tokens (default 100).
    """
    prompt = mx.array(tokenizer.encode(prompt))
    tokens = []
    for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)):
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token.item())
        yield tokenizer.decode(token.item())
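For context, a minimal usage sketch of the function above (the repo id is only an example; load is mlx_lm's standard model/tokenizer loader):

from mlx_lm import load

# Example repo id only; any mlx-converted model should work the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")

for piece in generate_stream(model, tokenizer, "Hello, how are you?", max_tokens=50):
    print(piece, end="", flush=True)
print()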
Best Regards,
P.S. Anyway, mlx_lm updates are very fast. I'm very grateful to all developers.
I haven't had a chance to look into the best way to add it, but if your implementation works and doesn't have too much repetitive code, you could try creating a PR. I can help review it, and maybe @awni could also give some suggestions/feedback.
@mzbac
Please disregard my suggestion to use my code. I found a very basic problem with it.
I was assuming only Japanese output, so I didn't pay attention to the separation between output characters; Japanese sentences have no spaces between words. But for English and other European languages, word separation by spaces has to be taken into account.
For example, suppose the LLM outputs "Hello ! It ' s nice to meet you ." (I have separated the tokens with single spaces). "Hello" and "!" are different tokens, so they come out separated, but this is unwanted output. The desired output is "Hello! It's nice to meet you."
Passing the whole token list to tokenizer.decode produces the desired output, so the current utils.generate function has no problem. But my code decodes and yields one token at a time for streaming, so it needs some way of concatenating the decoded tokens properly.
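A rough illustration of the problem (only a sketch; the exact behavior depends on the tokenizer):

ids = tokenizer.encode("Hello! It's nice to meet you.")

# Decoding one token at a time can lose the spacing information that the
# tokenizer stores in the ids themselves (e.g. SentencePiece's leading "▁").
per_token = "".join(tokenizer.decode([i]) for i in ids)

# Decoding the accumulated list reproduces the intended text.
full = tokenizer.decode(ids)

print(repr(per_token))  # spacing may be wrong or missing here
print(repr(full))       # "Hello! It's nice to meet you." (possibly plus special tokens)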
I'm sorry, but I'm not sure how to solve this problem. Should we close this issue for now?
I was attempting to create a function to enable streaming myself and wrote code very similar to @gitkaz's. I've tested this with microsoft/phi-2 on English-language output and it has been working as expected.
def get_stream():
    for token, prob in generate_step(encoded, model):
        if token == tokenizer.eos_token_id:
            break
        res = tokenizer.decode(token.item())
        yield res

for val in get_stream():
    print(val, end="", flush=True)
Essentially, this section of the utils.py file of the mlx-examples repo serves as the inspiration for this implementation.
Since this issue is now 3+ weeks old, I will refine the function I have written and open a PR to add it to the utils.py file so that others can enable streaming in their responses as soon as possible.
@awni please let me know if this is still relevant. I will be happy to contribute this feature!
There might be an issue with decoding the tokens one by one for the llama tokenizer, and for certain Unicode output more than one token is required to decode a specific character. Other than that, the implementation looks correct to me.
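For example (just a sketch; the exact token split depends on the tokenizer), an emoji is usually encoded as several byte-level tokens, and each piece decoded on its own comes out as the U+FFFD replacement character:

ids = tokenizer.encode("😀")

for i in ids:
    print(repr(tokenizer.decode([i])))  # individual pieces may print as '�'

print(repr(tokenizer.decode(ids)))      # the full list decodes to the emoji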
In that case would it make sense to stream batches of tokens for llama (store and decode multiple at a time)? What would be a generally optimal way?
Yeah, I couldn't find a really good way to work around it. It seems that storing all generated tokens and decoding them on each generation step is the easiest workaround, but it's kind of suboptimal. FYI, here is the related issue in the transformers repo: https://github.com/huggingface/transformers/issues/22710
For llama, this is suboptimal as you mentioned but works for now:
def get_stream():
    res = ""
    tokens = []
    for token, prob in generate_step(encoded, model):
        if token == tokenizer.eos_token_id:
            break
        inl = len(res)
        tokens.append(token.item())
        res = tokenizer.decode(tokens)
        yield res[inl:]

for val in get_stream():
    print(val, end="", flush=True)
Yeah, that would work. Just a small improvement: you don't have to save the decoded string; instead, you can just save the index for len(generated_text). Also, the implementation wouldn't be able to handle Unicode well. Some Unicode characters need more than one token to decode properly. For example, if you ask the LLM to generate an emoji, you will get the replacement character instead of the correct emoji. So, you may need to check whether the token needs to wait for the following token in order to decode correctly.
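A minimal sketch of that suggestion, reusing the names from the snippets above (not tested against every tokenizer): keep the full token list, remember how many characters have already been yielded, and hold output back while the decoded text still ends in the replacement character.

def get_stream():
    tokens = []
    offset = 0  # number of characters already yielded
    for token, _prob in generate_step(encoded, model):
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token.item())
        text = tokenizer.decode(tokens)
        # If the tail is "�", the last character is not fully decoded yet;
        # wait for the next token(s) before yielding anything.
        if text.endswith("\ufffd"):
            continue
        yield text[offset:]
        offset = len(text)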
@mzbac Hi, I was planning to implement something like this:
# (Body of a streaming generator; prompt, model, tokenizer, and the sampling
# parameters are assumed to be defined by the enclosing function.)
encoded_prompt = mx.array(tokenizer.encode(prompt))
tokens = []
for (token, _prob), n in zip(
    generate_step(encoded_prompt, model, temp=temp, repetition_penalty=repetition_penalty,
                  repetition_context_size=repetition_context_size, top_p=top_p),
    range(max_tokens),
):
    if tokenizer.eos_token_id == token:
        break
    tokens.append(token)
    res = tokenizer.decode(tokens)
    if res != "�":
        tokens = []
        yield res
I checked on my machine and was able to print emoji successfully, and it works for English as well. Tested on llama 3.
Thanks to the mlx_lm developers! It appears my request has already been implemented. I propose closing this request ticket.