transformers-stream-generator icon indicating copy to clipboard operation
transformers-stream-generator copied to clipboard

generate 函数中 is_greedy_gen_mode 和 is_sample_gen_stream_mode 同时为 True

Open AIxyz opened this issue 2 years ago • 1 comments

pip install transformers_stream_generator==0.0.4 后调试 llama 时,发现若使用如下命令

    tokens = None
    for token in torch_model.generate(
            input_ids=input_ids,
            max_length=1024,
            num_beams=1,
            num_return_sequences=1,
            no_repeat_ngram_size=15,
            repetition_penalty=1,
            temperature=0.65,
            do_stream=True):
        if tokens is None:
            tokens = token
        else:
            tokens = torch.cat((tokens, token))  # pylint: disable=no-member
        words = tokenizer.decode(tokens)
        yield words

会使得 NewGenerationMixin.generate(……) 函数中 is_greedy_gen_mode 和 is_sample_gen_stream_mode 同时为 True,这会使得 ~/.local/lib/python3.8/site-packages/transformers_stream_generator/main.py 里直接进入 382 行的 if is_greedy_gen_mode 块中 return self.greedy_search(……),导致无法正常流式输出。

为解决该问题,将~/.local/lib/python3.8/site-packages/transformers_stream_generator/main.py 里 292 行之后各个非 stream 的 is_xxx_mode 后添加 “and generation_config.do_stream is False”,如下图所示,就可以了

image

可以在下一个版本中进行修改

AIxyz avatar Jun 25 '23 08:06 AIxyz

谢谢,如果方便的话,您可以帮我提一个PR么

LowinLi avatar Jun 25 '23 10:06 LowinLi