transformers-stream-generator
sample_stream fails when eos_token_id is a list with more than one element
main.py line 987 should be
The original code multiplies unfinished_sequences by an array whose elements are not restricted to 0 or 1.
For example, with eos_token_id = [1,2] and next_tokens = torch.LongTensor([1,2,3,4,5]), the expression sum(next_tokens != i for i in eos_token_id).long() evaluates to [1,1,2,2,2], which is wrong: a 0/1 mask should give [0,0,1,1,1], with 0 for every token that matches any EOS id.
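The snippet below is a minimal sketch that reproduces the values from the example and shows one possible way to build a proper 0/1 mask. The names buggy_mask, fixed_mask and eos_tensor are made up for illustration, and the broadcasting approach is an assumption, not necessarily the exact replacement intended for main.py.

```python
import torch

eos_token_id = [1, 2]
next_tokens = torch.LongTensor([1, 2, 3, 4, 5])

# Expression from the report: for each token, this counts how many
# EOS ids it differs from, so values can exceed 1.
buggy_mask = sum(next_tokens != i for i in eos_token_id).long()

# One possible fix (assumption): a sequence stays unfinished only if its
# next token differs from every EOS id. Broadcasting compares each token
# (shape [5, 1]) against all EOS ids (shape [2]).
eos_tensor = torch.tensor(eos_token_id)
fixed_mask = (next_tokens.unsqueeze(-1) != eos_tensor).all(dim=-1).long()

unfinished_sequences = torch.ones_like(next_tokens)

print(buggy_mask.tolist())  # [1, 1, 2, 2, 2]
print(fixed_mask.tolist())  # [0, 0, 1, 1, 1]
print(unfinished_sequences.mul(fixed_mask).tolist())  # [0, 0, 1, 1, 1]
```

With a single EOS id both expressions give the same result, which may be why the problem only shows up when eos_token_id has more than one element.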