Add state to logits processing
Logits processing is a powerful tool, particularly for using smaller language models for tasks such as named entity recognition. @seanmor5 started work in this area with https://github.com/elixir-nx/bumblebee/pull/354.
Whatever the approach, it will require some kind of state.
This pull request is a proposal to allow logits processors to be stateful.
This would enable the use of deterministic finite automata (DFAs) or pushdown automata (PDAs) for processing constrained grammars in logits processing. https://github.com/bitcrowd/bumblebee/pull/6 shows how this would be used. We will follow up on this PR if this approach is favoured.
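To make the idea concrete, here is a minimal sketch of one DFA-constrained step written in plain Nx. It is not the API proposed in this PR: the module name DfaSketch, the step/3 function, the toy transition table and the -1 encoding for disallowed tokens are all assumptions made for illustration. The point is only that the processor has to carry the current DFA state from one step to the next.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule DfaSketch do
  # Hypothetical helper, not part of Bumblebee.
  #
  # `transitions` has shape {num_states, vocab_size}. Entry [s, t] is the
  # state reached by emitting token t in state s, or -1 if t is not allowed.
  def step(logits, dfa_state, transitions) do
    # Row of the transition table for the current state.
    row = Nx.take(transitions, dfa_state)
    allowed = Nx.greater_equal(row, 0)

    # Suppress every token the DFA does not allow in this state.
    neg_inf = Nx.Constants.neg_infinity(Nx.type(logits))
    masked = Nx.select(allowed, logits, neg_inf)

    # Greedy pick for illustration; the state advances with the chosen token.
    token = Nx.argmax(masked)
    next_state = Nx.take(row, token)

    {masked, token, next_state}
  end
end

# Toy DFA with 2 states over a vocabulary of 3 tokens (assumed values).
transitions = Nx.tensor([[1, -1, -1], [-1, 1, 0]])
logits = Nx.tensor([0.1, 2.0, 0.3])

# In state 0 only token 0 is allowed, even though token 1 has the highest logit.
{_masked, token, next_state} = DfaSketch.step(logits, Nx.tensor(0), transitions)
```

Here next_state is exactly the value that a stateless processor would have nowhere to keep between generation steps.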
@jonatanklosko Before we add more tests and do further refactoring: do you think this goes in the right direction? Please let me know if you have concerns or see anything that could be improved.
We might not want to vectorize all of the logits processors' state. For example, when we want to read from a shared state tensor while processing the vectorized logits, we would otherwise have to duplicate that shared tensor across the vectorized axis, right? Instead, we can vectorize only the state that needs vectorization, inside the logits processor.
That's basically the reason for 2ba5e0adb32eaeda42a957a38048363acc21ea57. I'm not entirely sure whether this is alright or whether it has negative implications for defn.
> We might not want to vectorize all of the logits processors' state. For example, when we want to read from a shared state tensor while processing the vectorized logits, we would otherwise have to duplicate that shared tensor across the vectorized axis, right? Instead, we can vectorize only the state that needs vectorization, inside the logits processor.
> That's basically the reason for 2ba5e0a. I'm not entirely sure whether this is alright or whether it has negative implications for defn.
Correct, Bumblebee should not call vectorize on the logits processor state. Ideally we want vectorization to happen automatically.
For example, schedulers have a similar init, here's one of them:
https://github.com/elixir-nx/bumblebee/blob/bc1b4525620e9afe4b2542324b3520762e37fdc5/lib/bumblebee/diffusion/pndm_scheduler.ex#L97-L108
alpha_bars is generated as a flat tensor and is shared state (not duplicated across the batch). On the other hand, the caller (Bumblebee) can pass sample_template with a vectorized axis, and then empty = Nx.fill(sample_template, 0) would be vectorized state. What's nice is that the scheduler is not aware of the vectorization, and a non-vectorized input works just fine.
For this to work automatically though, we need something to derive the state from (like sample_template), so that it gets vectorized automatically. I'm not yet sure how that would look for the processor; I need to think more about this.
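As a rough illustration of that pattern, the sketch below derives per-entry state from a template while keeping other state flat and shared. InitSketch, its init/1 function and the field names are made up for this example and only loosely follow the scheduler init linked above; the one assumption carried over from that description is that Nx.fill inherits the vectorized axes of its first argument.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule InitSketch do
  # Hypothetical init, loosely modelled on the scheduler pattern.
  def init(template) do
    # Flat tensor, shared by every batch entry (plays the role of alpha_bars).
    shared = Nx.iota({4})

    # Derived from the template, so it picks up the template's shape and,
    # if the template is vectorized, its vectorized axes as well.
    per_entry = Nx.fill(template, 0.0)

    %{shared: shared, per_entry: per_entry}
  end
end

# The same init works for a plain and for a vectorized template.
plain = Nx.broadcast(0.0, {3})
batched = Nx.vectorize(Nx.broadcast(0.0, {2, 3}), :batch)

plain_state = InitSketch.init(plain)
batched_state = InitSketch.init(batched)
```

With the vectorized template, only per_entry carries the :batch axis; shared stays a single flat tensor in both cases.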
Sorry for the late reply, I was off last week :)
> For this to work automatically though, we need something to derive the state from (like sample_template), so that it gets vectorized automatically. I'm not yet sure how that would look for the processor; I need to think more about this.
Let's just pass sequence: Nx.vectorize(state.sequences, :batch) in the init context too. Depending on what per-sequence state the user creates, they may need to take special care to make it vectorization friendly (e.g. Nx.iota({2, 2}, vectorized_axes: sequence.vectorized_axes), or using Nx.broadcast_vectors), but I think it's fine.
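A small sketch of the two options mentioned here, assuming a stand-in sequences tensor in place of state.sequences. The variable names table and counter are hypothetical; Nx.vectorize, Nx.iota with the :vectorized_axes option, and Nx.broadcast_vectors are existing Nx functions.

```elixir
Mix.install([{:nx, "~> 0.9"}])

# Stand-in for state.sequences: a batch of 2 sequences of length 4.
sequences = Nx.tensor([[1, 2, 0, 0], [3, 4, 5, 0]])
sequence = Nx.vectorize(sequences, :batch)

# Option 1: create per-sequence state with the same vectorized axes up front.
table = Nx.iota({2, 2}, vectorized_axes: sequence.vectorized_axes)

# Option 2: create flat state first, then align it with the sequence.
flat_counter = Nx.tensor(0)
[_sequence, counter] = Nx.broadcast_vectors([sequence, flat_counter])
```

Both table and counter end up with the :batch vectorized axis, so each sequence gets its own copy, while any tensor left untouched stays shared.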
@jonatanklosko thank you for the late night review. Please let me know what you think. I added two livebooks about logits processing in the last commit. They are not strictly related to state, but I found them useful to explain logits processing in talks. I could open up a separate PR for them if you like, it was just too tempting to include them :)