transformers-stream-generator
transformers-stream-generator copied to clipboard
generate 函数中 is_greedy_gen_mode 和 is_sample_gen_stream_mode 同时为 True
pip install transformers_stream_generator==0.0.4 后调试 llama 时,发现若使用如下命令
tokens = None
for token in torch_model.generate(
input_ids=input_ids,
max_length=1024,
num_beams=1,
num_return_sequences=1,
no_repeat_ngram_size=15,
repetition_penalty=1,
temperature=0.65,
do_stream=True):
if tokens is None:
tokens = token
else:
tokens = torch.cat((tokens, token)) # pylint: disable=no-member
words = tokenizer.decode(tokens)
yield words
会使得 NewGenerationMixin.generate(……) 函数中 is_greedy_gen_mode 和 is_sample_gen_stream_mode 同时为 True,这会使得 ~/.local/lib/python3.8/site-packages/transformers_stream_generator/main.py 里直接进入 382 行的 if is_greedy_gen_mode 块中 return self.greedy_search(……),导致无法正常流式输出。
为解决该问题,将~/.local/lib/python3.8/site-packages/transformers_stream_generator/main.py 里 292 行之后各个非 stream 的 is_xxx_mode 后添加 “and generation_config.do_stream is False”,如下图所示,就可以了
可以在下一个版本中进行修改
谢谢,如果方便的话,您可以帮我提一个PR么