RuntimeError: Could not find response key token IDs when training with the bloom model and tokenizer
It works fine if I use gpt-j.
I guess this is because of the tokenizer and this line: https://github.com/databrickslabs/dolly/blob/03bf3852daa42e6091a39483dda0714c02de7382/training/trainer.py#L52
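For context, here's roughly what I understand that linked line to be doing (a sketch based on this thread, not the verbatim repo code). It searches each example's labels (a 1-D array of token IDs) for the response key's token IDs and raises the error in the title when they never appear contiguously:

```python
import numpy as np

def find_response_start(labels: np.ndarray, response_token_ids: list) -> int:
    # Scan for the first position where the response key's token IDs
    # appear as a contiguous subsequence of the label IDs.
    for idx in np.where(labels == response_token_ids[0])[0]:
        if list(labels[idx : idx + len(response_token_ids)]) == response_token_ids:
            return idx
    # Mirrors the error in the issue title: the key was never found.
    raise RuntimeError("Could not find response key token IDs")
```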
Any tips on adjusting it so it can use models other than gpt-j?
Thanks!
Yes, the code as currently written probably depends on how this particular tokenizer works. I suggest printing out response_token_ids and batch["labels"][i] to help identify why it isn't finding the sequence. You can separately run tokenizer = load_tokenizer() and then run tokenizer.decode on parts of the sequence to debug. It would actually be helpful if the error included this in the formatted string.
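Something along these lines (a sketch; load_tokenizer is from training/trainer.py, and "### Response:\n" is the marker discussed below):

```python
from training.trainer import load_tokenizer

tokenizer = load_tokenizer()

# Encode the response key on its own and inspect the IDs the collator
# will be searching for.
response_token_ids = tokenizer.encode("### Response:\n")
print(response_token_ids)
print(tokenizer.convert_ids_to_tokens(response_token_ids))

# Then, inside the collator, compare against the actual labels, e.g.:
# print(batch["labels"][i])
# (filter out any -100 masking entries before calling tokenizer.decode)
# print(tokenizer.decode([t for t in batch["labels"][i] if t >= 0]))
```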
I think we might be able to make the code more robust so it works with other tokenizers as well. I'll look into it. Can you share the repro steps? For example, which model are you using?
@matthayes
I'm just changing this line to the model that I want: https://github.com/databrickslabs/dolly/blob/03bf3852daa42e6091a39483dda0714c02de7382/training/trainer.py#L35
So far I've tried bloom and xglm; I got the same error.
I hit this too, but I'm not sure if it was for the same reason. I have different input, and I didn't format it exactly like the alpaca dataset. In particular, the input will have to include ### Response:\n, for example (see the sketch below).
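For reference, this is roughly the shape the formatted record needs. The preamble wording here is just illustrative, not the repo's exact template; the important part is that the ### Response:\n marker appears verbatim so the collator's search can find it:

```python
# Hypothetical alpaca-style record format; exact wording is assumed.
PROMPT_FORMAT = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
{response}"""

example = PROMPT_FORMAT.format(
    instruction="Name three primary colors.",
    response="Red, yellow, and blue.",
)
print(example)
```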
I believe this is resolved in Matt's changes from a few days ago anyway.
I just tested this with bloom on the recent code and was still able to reproduce the error.
I see what is causing the issue. For the bloom model, the tokenizer combines the newline at the end of ### Response:\n with the next character, producing a different token. This doesn't happen with the gpt-j tokenizer. As a result, the token IDs for ### Response:\n are not found exactly.
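You can see the difference directly with something like this (the checkpoint names are just examples; bloom-560m is a small public bloom checkpoint, and exact splits may vary by tokenizer version):

```python
from transformers import AutoTokenizer

KEY = "### Response:\n"
TEXT = "### Response:\nHello"

for name in ["EleutherAI/gpt-j-6B", "bigscience/bloom-560m"]:
    tok = AutoTokenizer.from_pretrained(name)
    key_ids = tok.encode(KEY)
    text_ids = tok.encode(TEXT)
    # If "\n" merges with the following "H" into a single token, key_ids
    # is no longer a contiguous subsequence of text_ids and the
    # collator's search fails.
    print(name)
    print("  key :", tok.convert_ids_to_tokens(key_ids))
    print("  text:", tok.convert_ids_to_tokens(text_ids))
```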
I actually just merged a change today that I think will make this easier to fix. Currently ### Response: becomes a single token, since I made it a special token. I think I can update this so that the single token is ### Response:\n instead. This should prevent the newline from being combined with what follows.
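Roughly like this (a sketch of the approach, not the actual repo change; RESPONSE_KEY_NL is a placeholder name):

```python
from transformers import AutoTokenizer

RESPONSE_KEY_NL = "### Response:\n"

tok = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
tok.add_special_tokens({"additional_special_tokens": [RESPONSE_KEY_NL]})

# The marker now maps to a single ID, so the tokenizer can no longer
# merge its trailing newline into whatever token follows it.
ids = tok.encode(RESPONSE_KEY_NL + "Hello")
print(tok.convert_ids_to_tokens(ids))

# Remember to grow the model's embedding matrix to match:
# model.resize_token_embeddings(len(tok))
```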
I've merged in the fix, which enables this to work with bloom.