Add Classifier-Free Guidance sampling
Feature request
Hello! I wish to contribute CFG sampling. I'm working with EleutherAI and @StellaAthena and will have a paper about it by Friday. CFG brings non-trivial improvements on many standard benchmarks. It contrasts the logits for the next token $P(w_t|w_{..t}, prompt)$ with those of the input deprived of the prompt, $P(w_t|w_{..t})$, by defining
$$ \log P_{\text{cfg}}(w|w_{..t}, prompt) = \log P(w|w_{..t}) + \text{cfg} \cdot \left(\log P(w|w_{..t}, prompt) - \log P(w|w_{..t})\right) $$
We can then optionally blend $\log P_{\text{cfg}}$ with $\log P(w|w_{..t}, prompt)$ to smooth the resulting distribution a bit.
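Concretely, writing the mixing weight as $\phi$ (a symbol introduced here just for illustration; the code below uses $\phi = 0.7$ and renormalizes $\log P_{\text{cfg}}$ with a log-softmax before mixing), that optional blend is:

$$ \log P_{\text{final}}(w|w_{..t}, prompt) = \phi \cdot \log P_{\text{cfg}}(w|w_{..t}, prompt) + (1 - \phi) \cdot \log P(w|w_{..t}, prompt) $$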
Motivation
My current implementation is:
```python
import torch
from torch.nn import functional as F
from transformers import LogitsWarper, LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper


class CFGLogits(LogitsWarper):
    """Applies CFG by re-running the model on an unconditional (or negative) prompt."""

    def __init__(self, cfg, inputs, model, verbose=True):
        self.cfg = cfg          # guidance strength
        self.inputs = inputs    # ids of the unconditional / negative prompt
        self.model = model
        self.out = None         # cached output of the unconditional branch
        self.verbose = verbose

    def __call__(self, input_ids, scores):
        if self.cfg == 1:
            return F.log_softmax(scores, dim=-1)
        scores = F.log_softmax(scores, dim=-1)
        # Run the unconditional branch, reusing its KV cache after the first step.
        if self.out is None:
            self.out = self.model(self.inputs.to(device), use_cache=True)
        else:
            self.out = self.model(input_ids[:, -1:],
                                  use_cache=True,
                                  past_key_values=self.out.past_key_values)
        unconditional_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)
        out = self.cfg * (scores - unconditional_logits) + unconditional_logits
        out = F.log_softmax(out, dim=-1)
        # Optional rescale: interpolate back towards the conditional scores.
        return 0.7 * out + 0.3 * scores


# usage:
outputs = model.generate(
    input_ids=inputs['input_ids'].to(device),
    attention_mask=inputs['attention_mask'].to(device),
    max_new_tokens=l,
    logits_processor=LogitsProcessorList([
        # inputs_cfg usually is the last token of the prompt but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(cfg, inputs_cfg, model),
        TemperatureLogitsWarper(0.8),
        TopPLogitsWarper(0.95),
    ]),
    do_sample=True,
)
```
I am not familiar enough with the design guidelines of HF to know if this implementation as a LogitsWarper is satisfactory.
just a few figures supporting the claims:
Your contribution
I can contribute the code but I need to be guided as I don't know the exact design guidelines and overall architecture of HF.
Thank you for your time!
cc @gante But let's see if the community requests this added feature before implementing it in the library proper :-)
Hey @Vermeille 👋
I have the impression that our MusicGen PR (still open, expected to get merged soon) introduces the bulk of the logic to make it happen -- see this file
It is the same thing with a slightly different code implementation, correct? In the MusicGen PR, the model does a forward pass with 2x the batch size, where half of the batch corresponds to the unprompted tokens
Indeed @gante !
I don't fully get how the 2x batch size thing works, but if it does, it's cool. The paper makes some more additions to that base implementation:
- the `uncond_logits` might in fact have a different prompt than the `cond_logits`, which is commonly called "negative prompt".
- the comment says "usually at the expense of poorer quality". This can be mitigated by linearly interpolating the cfg `scores` back with the initial `scores`.
- We had better results `log_softmax`ing both scores before cfg, which normalizes both logits sets to a common "scale".
cc @sanchit-gandhi, who's probably better equipped to comment on potential differences :)
Hey @Vermeille - thanks for the comprehensive write-up! Just a clarifying question: in your implementation, how do you construct the token ids for the model based on the conditional ids and the un-conditional ones? You mention:
inputs_cfg usually is the last token of the prompt but there are
Which suggests you concatenate them together in the same batch item?
In MusicGen (and also the HF Diffusers library for models like Stable Diffusion), we construct our input ids by concatenating the input ids for the conditional prompt and the un-conditional prompt along the batch dimension (dim=0):
```python
input_ids = torch.concatenate([conditional_ids, unconditional_ids], dim=0)
```
This is what's referred to by the 2x batch size 'trick' (concatenating the conditional prompt and unconditional prompt over the batch dim). There's no restriction to how these unconditional ids are formed - they can be from a 'null' input, or from a negative prompt. So we can do negative prompting in exactly the way you've described.
When we run our model forward, the logits for the first half of the batch correspond to the conditional prompt, and the second half to the unconditional prompt (or negative prompt if we use one).
By splitting along the batch dim, we can partition the conditional logits and the unconditional ones:
```python
conditional_logits, unconditional_logits = torch.split(logits, batch_size // 2)
```
-> we then perform our weighted sum over the conditional and unconditional logits for CFG.
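To make that concrete, here is a minimal single-step sketch of the trick (this is not the MusicGen code; the model, prompt, guidance scale, and the choice of the last prompt token as the unconditional context are just illustrative assumptions):

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

guidance_scale = 3.0
cond_ids = tokenizer("Salve, dispiculi.", return_tensors="pt")["input_ids"]  # (1, T)

# "Unconditional" row: same token ids, but only the last prompt token is attended to
# (mirroring "inputs_cfg usually is the last token of the prompt").
uncond_ids = cond_ids.clone()
cond_mask = torch.ones_like(cond_ids)
uncond_mask = torch.zeros_like(cond_ids)
uncond_mask[:, -1:] = 1

# 2x batch trick: concatenate along the batch dim and run a single forward pass.
input_ids = torch.cat([cond_ids, uncond_ids], dim=0)         # (2 * bsz, T)
attention_mask = torch.cat([cond_mask, uncond_mask], dim=0)   # (2 * bsz, T)

with torch.no_grad():
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits[:, -1, :]  # (2 * bsz, vocab)

bsz = cond_ids.shape[0]
conditional_logits, unconditional_logits = logits[:bsz], logits[bsz:]

# CFG weighted sum: uncond + scale * (cond - uncond), giving bsz scores
scores = unconditional_logits + guidance_scale * (conditional_logits - unconditional_logits)
next_token = scores.argmax(dim=-1)  # one next token, shared by both branches
print(tokenizer.decode(next_token))
```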
Hope that explains how the 2x batch size trick works - would be keen to hear whether this aligns with how you've run CFG in your experiments.
Regarding implementing a new logits processor, we'd probably want to add this new logits processor when the time comes for integrating the model you've worked on into transformers, rather than adding it solely as a standalone logits processor. transformers is less of a modular toolbox for building new models and more a library for housing the most popular OS ML models.
Have you trained a new model that uses this processor? Or built on-top of an existing one? (if it's the latter, then adding the CFG logits processor standalone makes sense, otherwise let's integrate it all in one go)
Thank you for your detailed answer @sanchit-gandhi !
The part I'm most unclear about regarding the 2x batch trick is how the sampling happens. Do you actually sample the same continuation token for the conditional and unconditional branch, or do they diverge in their own directions (which would be weird imho)?
Regarding the integration, there is no need to train models to support CFG, it works out of the box. The paper will be out in a few days, but as you can see in the figures, we employed it with LLaMA models, all Pythias, the GPT-2 family, and even GPT4All. We don't train a new model. It's meant to be an addition to the .generate() method that is totally model agnostic and needs neither training nor finetuning. Hence the PR with the standalone logits processor :)
Maybe this helps!
Pre-processing:
- conditional text -> `conditional_ids` (bsz)
- negative text -> `unconditional_ids` (bsz)
- `input_ids = [conditional_ids, unconditional_ids]` (2 * bsz since we've done a concat)

Forward pass:
- `logits` (2 * bsz since they come from the `input_ids`)

CFG:
- `conditional_logits`, `unconditional_logits` = `logits[:bsz]`, `logits[bsz:]` (so each one is bsz since we've done a split)
- `scores` = weighted_sum(`conditional_logits`, `unconditional_logits`; `guidance_scale`) (bsz)

Sampling:
- next token = sample(`scores`) (bsz num tokens -> we combined the cond/uncond logits to get the scores, so we only have bsz `scores`, and thus bsz num tokens)
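A toy, shape-level sketch of that loop (everything here, including the `toy_lm` stand-in, is illustrative and not the MusicGen or transformers implementation; it only shows the bsz / 2 * bsz bookkeeping and how the single sampled token is appended to both branches):

```python
import torch

torch.manual_seed(0)
vocab, bsz, prompt_len, steps = 50, 2, 4, 3

# Toy stand-in for a language model: next-token logits depend only on the last token id.
embedding = torch.randn(vocab, vocab)
def toy_lm(ids):                      # ids: (batch, seq) -> logits: (batch, vocab)
    return embedding[ids[:, -1]]

guidance_scale = 1.5
conditional_ids = torch.randint(vocab, (bsz, prompt_len))
unconditional_ids = torch.randint(vocab, (bsz, prompt_len))   # e.g. a negative prompt

# The forward passes always see the 2 * bsz batch.
input_ids = torch.cat([conditional_ids, unconditional_ids], dim=0)      # (2 * bsz, T)

for _ in range(steps):
    logits = toy_lm(input_ids)                                           # (2 * bsz, vocab)
    cond, uncond = logits[:bsz], logits[bsz:]                            # each (bsz, vocab)
    scores = uncond + guidance_scale * (cond - uncond)                   # (bsz, vocab)
    next_token = torch.multinomial(scores.softmax(-1), 1)                # (bsz, 1): sampled once
    # The same sampled token is appended to *both* halves, so the branches never
    # diverge; the batch goes back to 2 * bsz for the next forward pass.
    input_ids = torch.cat([input_ids, next_token.repeat(2, 1)], dim=1)   # (2 * bsz, T + 1)

print(input_ids[:bsz, prompt_len:])   # the bsz generated continuations
```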
How have you been getting the conditional and unconditional logits in your experiments? Through two forward passes? (one with the conditional inputs and then a second with the unconditional ones). This batch size concatenation trick means you only have to run one forward pass, but with 2x the batch size
The only pain point I see with getting this to work in transformers is this batch size change as we go from our forward pass to our sampling loop. But we can add some logic to change the batch size on the fly if we're doing CFG (kind of like we did for MusicGen @gante - we need to trick the forward pass into using 2 * bsz, then the decoder ids to use bsz).
there is no need to train models to support CFG, it works out of the box
Very cool indeed! Would be nice to have this as a standalone PR then as suggested
Thank you! Yeah, if the cond and uncond prompts get the same next token sampled, it's good wrt our experiments! It's how you manage to loop around in .generate() to grow the continuation token by token, zigzagging between bsz and 2 * bsz, that I'm not 100% clear on. I totally see how it works for one forward pass. Totally an implementation detail :) But apparently that's a new trick you had to implement for MusicGen too, so it makes sense that I'm not perfectly clear on it.
Would be nice to have this as a standalone PR then as suggested
I'm happy to address the changes that have to be made to contribute this into the lib :)
Awesome - feel free to open a PR and tag myself and @gante! How do you do it without the 2x batch size trick? Do you do two forward passes? Just asking in case there's a simpler way we can integrate this!
(catching up on the paper and thinking a bit about usage experience -- will comment tomorrow with specific suggestions, but I think @Vermeille's suggested implementation above will be pretty close to a great user experience with minimal compute overhead)
here is an alternative implementation we used for some of our other experiments in the paper, for your consideration.
it was designed with huggingface's typical *ModelFor* code-style in mind, which just puts the base model in the init and extends the forward() method
https://github.com/Vermeille/lm-evaluation-harness-cfg/blob/cfg-alex/log_logits_on_p3.py#L30-L97
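For readers who don't want to follow the link, here is a rough sketch of that style (this is not the linked code; the class name `CFGModelForCausalLM`, the model, and the numbers are illustrative assumptions): the base model goes in the init, and the extended forward() runs two passes and combines the logits.

```python
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM


class CFGModelForCausalLM(torch.nn.Module):  # hypothetical name, for illustration only
    """Wraps a base causal LM; forward() returns CFG-combined next-token logits."""

    def __init__(self, model, guidance_scale):
        super().__init__()
        self.model = model
        self.guidance_scale = guidance_scale

    @torch.no_grad()
    def forward(self, input_ids, unconditional_ids):
        # Two consecutive passes: one with the full prompt, one with the unconditional context.
        cond = self.model(input_ids).logits[:, -1, :]
        uncond = self.model(unconditional_ids).logits[:, -1, :]
        cond = F.log_softmax(cond, dim=-1)
        uncond = F.log_softmax(uncond, dim=-1)
        return uncond + self.guidance_scale * (cond - uncond)


tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
base = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
wrapper = CFGModelForCausalLM(base, guidance_scale=1.5)

ids = tokenizer("Salve, dispiculi.", return_tensors="pt")["input_ids"]
cfg_logits = wrapper(ids, ids[:, -1:])  # unconditional context = last prompt token
print(tokenizer.decode(cfg_logits.argmax(dim=-1)))
```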
Awesome - feel free to open a PR and tag myself and @gante! How do you do it without the 2x batch size trick? Do you do two forward passes? Just asking in case there's a simpler way we can integrate this!
Yes. Two consecutive passes. Which is indeed not that great wrt latency.
Would be great to have both the 2x batch size and two forward passes. Since 2x batch size is better for throughput but the two forward passes are much better for VRAM usage, as the Paper outlines
(unless I misunderstood)
So given you already have this ( https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L1070 )
What do you want me to add / change in the PR?
Would be great to have both the 2x batch size and two forward passes. Since 2x batch size is better for throughput but the two forward passes are much better for VRAM usage, as the Paper outlines
(unless I misunderstood)
This is correct: our focus was on getting the best results for a fixed amount of VRAM in our experiments. Hence it didn't occur to us to simply 2x the batch size. I agree that having this be togglable is a good idea and don't have any preference about the default.
The application to LLMs seems more of a situational sampling technique. With smaller conditional generative models like MusicGen, trained from-scratch with (explicit) condition dropout, it's practically part of the model. MusicGen isn't the first AR Transformer here, last year's DALL-E Mega already did it (itself inspired by https://twitter.com/RiversHaveWings/status/1478093658716966912 ), and in these models it's essential for performance.
So I'd expect "batch size 1 dramatically underutilizes available resources" to be the more common case.
Since 2x batch size is better for throughput but the two forward passes are much better for VRAM usage, as the Paper outlines
Depending on model and hardware, "biggest batch size that fits" isn't necessarily optimal. On decent hardware, you can hit optimal compute utilisation before VRAM limits with batched inference in smaller models.
Normalizing the summands, then interpolating with the original scores is intriguing. If adding this to the CFG implementation that's now in Transformers is still being considered, this would be unexpected as default behavior though. In diffusion models, it's not applicable, and in sequence prediction, I've only seen people combine the unnormalized scores.
@drdaxxy
Normalizing the summands, then interpolating with the original scores is intriguing. [...] In diffusion models, it's not applicable
This is a technique we borrowed from *Common Diffusion Noise Schedules and Sample Steps are Flawed*, where they call it CFG Rescale. You can see Imagen doing a similar normalizing trick too.
in sequence prediction, I've only seen people combine the unnormalized scores.
That's what we started with, and our results were a little bit worse.
This method is interesting to implement from an engineering and maintenance point of view!
The simplest approach would be to proceed as @Vermeille suggested: add a logits processor that calls a model forward pass for the unconditional part of the input. It would be a small self-contained piece of code, which means low long-term maintenance on our end. On the negative side, we have the 2x latency, which is more impactful than the extra VRAM (IMO).
If we go the 2x batch size route, we need to implement a function like greedy_search or sample -- a long function with non-negligible maintenance costs on our end. I believe this would be the best form of CFG sampling. However, we are severely constrained by our ability to keep the machine up and running at a good pace, so that we can quickly add new features like CFG sampling :D
We have a plan to reorganize generate such that it is entirely made of small functions, making it much more composable. In the way I'm envisioning it, the 2x batch size version of CFG sampling would need a few extra lines of code, as opposed to a new large function.
How about we go with @Vermeille's proposal now, which will make CFG sampling available this week with low overhead on our end, and we implement the 2x batch size version after the generate refactor is complete? The new logits processor class would need a different name, as we already have ClassifierFreeGuidanceLogitsProcessor for the 2x batch size case (perhaps UnbatchedClassifierFreeGuidanceLogitsProcessor?)
Expect a PR in a few hours.
Thank you for your interest and answers!
@gante There is a name clash for the arguments to .generate(). For this PR, unless instructed otherwise before I submit it, cfg_scale (mine) will live next to guidance_scale (MusicGen's). Idk how to resolve this competition, given that .generate() does not seem ready to use the 2x batch trick yet.
@Vermeille Adding more (and partially redundant) parameterization is highly undesirable, and we'd want to favor the more general case (yours). You also have the additional requirement of renormalizing the logits before applying your logits processor. Fortunately, we haven't officially released a transformers version with MusicGen, so we still have some wiggle room!
Let's try to fit everything together -- here's my suggestion:
- your logits processor uses the same parameter, `guidance_scale`, and it's triggered by its presence
- EDIT: this is not needed ~~your logits processor is added after the normalization one (after this if), and the normalization step is now also triggered when `guidance_scale` is non-None~~
- `ClassifierFreeGuidanceLogitsProcessor` (MusicGen's) is removed from the function that prepares the logits processors, and we modify MusicGen's generation function to handle its special processor: if `guidance_scale` is present when we generate with MusicGen, we pop it and manually add its CFG processor. I can take care of this part if you don't feel comfortable touching MusicGen :)
This way the two strategies can coexist, share the argument, and not clash 🤗
Great! Thank you for the walkthrough.
On it.
Wait @gante, integrating it after the LogitNormalization is not something we want: all the prior processing (temperature, top_p, etc.) will be used only on the conditional branch and not the unconditional one, and will be executed before computing the CFG logits. To be fair, we haven't tested this transformation order, but being asymmetrical like this scares me.
And this is even invalid. Top-k/p may not even select the same tokens in both branches, so that will misbehave.
I'm afraid I can't do that. CFG has to happen as one of the first logits processors.
@Vermeille looking at your code example above, I didn't notice it already had normalization inside the processor. My bad -- feel free to add it as the 1st one :)
(will edit my comment above accordingly, for clarity)
So this is the code I got to get it working. It is just a hack, but if you want to play with it, just use this code:
```python
from transformers import LogitsWarper
import torch
from torch.nn import functional as F

device = 'cpu'
if torch.has_cuda:
    device = 'cuda'


class CFGLogits(LogitsWarper):

    def __init__(self, cfg, inputs, model, verbose=True):
        self.cfg = cfg
        self.inputs = inputs
        self.model = model
        self.out = None
        self.verbose = verbose

    def __call__(self, input_ids, scores):
        if self.cfg == 1:
            return F.log_softmax(scores, dim=-1)
        scores = F.log_softmax(scores, dim=-1)
        if self.out is None:
            self.out = self.model(self.inputs.to(device), use_cache=True)
        else:
            self.out = self.model(input_ids[:, -1:],
                                  use_cache=True,
                                  past_key_values=self.out.past_key_values)
        unconditional_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)
        out = self.cfg * (scores - unconditional_logits) + unconditional_logits
        out = F.log_softmax(out, dim=-1)
        return 0.7 * out + 0.3 * scores


from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

prompt = "Salve, dispiculi."
inputs = tokenizer(prompt, return_tensors='pt')

model.to(device)

outputs = model.generate(
    input_ids=inputs['input_ids'].to(device),
    attention_mask=inputs['attention_mask'].to(device),
    max_new_tokens=125,
    logits_processor=LogitsProcessorList([
        # inputs_cfg usually is the last token of the prompt but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(3, inputs['input_ids'], model),
        TemperatureLogitsWarper(0.8),
        TopPLogitsWarper(0.95),
    ]),
    do_sample=True,
)

print(tokenizer.decode(outputs[0]))
```
This worked on my end
@grantCelley 's code works for me.
With CFG (pythia 160m)
Without CFG
@grantCelley @chris-aeviator
The line `CFGLogits(3, inputs['input_ids'], model),` should really be `CFGLogits(3, inputs['input_ids'][:, -1:], model),`
thanks for pointing it out, my 30 was a typo, but your prev. code doesn't seem to mention the `[:, -1:]`?!
@chris-aeviator notice how it uses `inputs_cfg`:

```python
# inputs_cfg usually is the last token of the prompt but there are
# possibilities of negative prompting that are explored in the paper
CFGLogits(cfg, inputs_cfg, model),
```