
Add Classifier-Free Guidance sampling

Vermeille opened this issue 2 years ago • 6 comments

Feature request

Hello! I wish to contribute CFG sampling. I'm working with EleutherAI and @StellaAthena and will have a paper about it by Friday. CFG brings non-trivial improvements on many standard benchmarks. It contrasts the logits for the next token, $P(w_t|w_{..t}, prompt)$, with those of the input deprived of the prompt, $P(w_t|w_{..t})$, by defining

$$ \log P_{\text{cfg}}(w|w_{..t}, prompt) = \log P(w|w_{..t}) + \text{cfg} \cdot (\log P(w|w_{..t}, prompt) - \log P(w|w_{..t})) $$

And then we can blend $\log P_{\text{cfg}}$ with $\log P(w|w_{..t}, prompt)$ to smooth that distribution a bit, but this step is optional.
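A tiny numeric sketch of the formula (toy logits and a made-up guidance scale of 1.5; the numbers are illustrative only):

```python
import torch
import torch.nn.functional as F

# Toy next-token log-probabilities over a 4-word vocabulary (made-up numbers).
cond = F.log_softmax(torch.tensor([2.0, 1.0, 0.5, 0.1]), dim=-1)    # log P(w | w_.., prompt)
uncond = F.log_softmax(torch.tensor([1.0, 1.0, 1.0, 0.1]), dim=-1)  # log P(w | w_..)
cfg = 1.5

# log P_cfg = log P_uncond + cfg * (log P_cond - log P_uncond)
log_p_cfg = uncond + cfg * (cond - uncond)

# Tokens the prompt favors get pushed further up relative to the unconditional branch.
assert torch.argmax(log_p_cfg) == torch.argmax(cond)
```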

Motivation

My current implementation is:

import torch
from torch.nn import functional as F
from transformers import LogitsWarper

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class CFGLogits(LogitsWarper):

    def __init__(self, cfg, inputs, model, verbose=True):
        self.cfg = cfg
        self.inputs = inputs
        self.model = model
        self.out = None
        self.verbose = verbose

    def __call__(self, input_ids, scores):
        if self.cfg == 1:
            return F.log_softmax(scores, dim=-1)
        scores = F.log_softmax(scores, dim=-1)
        if self.out is None:
            # First step: run the unconditional branch on its own inputs.
            self.out = self.model(self.inputs.to(device), use_cache=True)
        else:
            # Subsequent steps: feed only the last sampled token, reusing the KV cache.
            self.out = self.model(input_ids[:, -1:],
                                  use_cache=True,
                                  past_key_values=self.out.past_key_values)
        unconditional_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)
        out = self.cfg * (scores - unconditional_logits) + unconditional_logits
        out = F.log_softmax(out, dim=-1)
        # Blend back with the conditional scores to smooth the distribution.
        return 0.7 * out + 0.3 * scores

# usage:

outputs = model.generate(
    input_ids=inputs['input_ids'].to(device),
    attention_mask=inputs['attention_mask'].to(device),
    max_new_tokens=l,
    logits_processor=LogitsProcessorList([
        # inputs_cfg usually is the last token of the prompt but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(cfg, inputs_cfg, model),
        TemperatureLogitsWarper(0.8),
        TopPLogitsWarper(0.95),
    ]),
    do_sample=True,
)

I am not familiar enough with the design guidelines of HF to know if this implementation as a LogitsWarper is satisfactory.

Just a few figures supporting the claims (FLOPs comparison and benchmark plots; images not reproduced here).


Your contribution

I can contribute the code but I need to be guided as I don't know the exact design guidelines and overall architecture of HF.

Thank you for your time!

Vermeille avatar Jun 28 '23 02:06 Vermeille

cc @gante But let's see if the community requests this added feature before implementing it in the library proper :-)

sgugger avatar Jun 28 '23 12:06 sgugger

Hey @Vermeille 👋

I have the impression that our MusicGen PR (still open, expected to get merged soon) introduces the bulk of the logic to make it happen -- see this file

It is the same thing with a slightly different code implementation, correct? In the MusicGen PR, the model does a forward pass with 2x the batch size, where half of the batch corresponds to the unprompted tokens

gante avatar Jun 28 '23 14:06 gante

Indeed @gante !

I don't fully get how the 2x batch size thing works, but if it does, it's cool. The paper makes some more additions to that base implementation:

  1. the uncond_logits might in fact have a different prompt than the cond_logits, which is commonly called a "negative prompt".
  2. the comment says "usually at the expense of poorer quality". This can be mitigated by linearly interpolating the CFG scores back with the initial scores.
  3. We had better results log_softmaxing both sets of scores before CFG, which normalizes both logits sets to a common "scale".

Vermeille avatar Jun 28 '23 15:06 Vermeille

cc @sanchit-gandhi, who's probably better equipped to comment on potential differences :)

gante avatar Jun 29 '23 10:06 gante

Hey @Vermeille - thanks for the comprehensive write-up! Just a clarifying question: in your implementation, how do you construct the token ids for the model based on the conditional ids and the un-conditional ones? You mention:

inputs_cfg usually is the last token of the prompt but there are

Which suggests you concatenate them together in the same batch item?

In MusicGen (and also the HF Diffusers library for models like Stable Diffusion), we construct our input ids by concatenating the input ids for the conditional prompt and the un-conditional prompt along the batch dimension (dim=0):

input_ids = torch.concatenate([conditional_ids, unconditional_ids], dim=0)

This is what's referred to by the 2x batch size 'trick' (concatenating the conditional prompt and unconditional prompt over the batch dim). There's no restriction to how these unconditional ids are formed - they can be from a 'null' input, or from a negative prompt. So we can do negative prompting in exactly the way you've described.

When we run our model forward, the logits for the first half of the batch correspond to the conditional prompt, and the second half to the unconditional prompt (or the negative prompt if we use one).

By splitting along the batch dim, we can partition the conditional logits and the unconditional ones:

conditional_logits, unconditional_logits = torch.split(logits, batch_size // 2)

-> we then perform our weighted sum over the conditional and unconditional logits for CFG.

Hope that explains how the 2x batch size trick works - would be keen to hear whether this aligns with how you've run CFG in your experiments.

Regarding implementing a new logits processor, we'd probably want to add this new logits processor when the time comes to integrate the model you've worked on into transformers, rather than adding it solely as a standalone logits processor. transformers is less a modular toolbox for building new models and more a library housing the most popular open-source ML models.

Have you trained a new model that uses this processor? Or built on-top of an existing one? (if it's the latter, then adding the CFG logits processor standalone makes sense, otherwise let's integrate it all in one go)

sanchit-gandhi avatar Jun 30 '23 15:06 sanchit-gandhi

Thank you for your detailed answer @sanchit-gandhi !

The part I'm most unclear on regarding the 2x batch trick is how the sampling happens. Do you actually sample the same continuation token for the conditional and unconditional branch, or do they each diverge in their own direction (which would be weird imho)?

Regarding the integration, there is no need to train models to support CFG; it works out of the box. The paper will be out in a few days, but as you can see in the figures, we employed it with LLaMA models, all Pythias, the GPT-2 family, and even GPT4All. We don't train a new model. It's meant to be an addition to the .generate() method that is totally model-agnostic and needs neither training nor fine-tuning. Hence the PR with the standalone logits processor :)

Vermeille avatar Jun 30 '23 23:06 Vermeille

The paper is out

Vermeille avatar Jul 03 '23 10:07 Vermeille

Maybe this helps!

Pre-processing:

  • conditional text -> conditional_ids (bsz)
  • negative text -> unconditional_ids (bsz)
  • input_ids = [conditional_ids, unconditional_ids] (2 * bsz since we've done a concat)

Forward pass:

  • logits (2 * bsz since they come from the input_ids)

CFG:

  • conditional_logits, unconditional_logits = logits[:bsz], logits[bsz:] (so each one is bsz since we've done a split)
  • scores = weighted_sum(conditional_logits, unconditional_logits; guidance_scale) (bsz)

Sampling:

  • next token = sample(scores) (bsz num tokens -> we combined the cond/uncond logits to get the scores, so we only have bsz scores, and thus bsz num tokens)
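The per-step arithmetic above can be sketched like this (a minimal sketch; `logits_fn` stands in for the model's forward pass and is not a transformers API):

```python
import torch

def cfg_step(logits_fn, conditional_ids, unconditional_ids, guidance_scale):
    bsz = conditional_ids.shape[0]
    # 2x batch size trick: one forward pass, conditional rows first.
    input_ids = torch.cat([conditional_ids, unconditional_ids], dim=0)
    logits = logits_fn(input_ids)              # (2 * bsz, vocab)
    cond, uncond = logits[:bsz], logits[bsz:]  # split back to bsz each
    # Weighted sum -> bsz scores, so sampling yields bsz next tokens.
    return uncond + guidance_scale * (cond - uncond)

# Toy usage with a lookup table standing in for the model.
table = torch.arange(40, dtype=torch.float32).reshape(10, 4)
logits_fn = lambda ids: table[ids[:, -1]]
scores = cfg_step(logits_fn, torch.tensor([[1, 2]]), torch.tensor([[0, 5]]), 1.5)
```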

How have you been getting the conditional and unconditional logits in your experiments? Through two forward passes? (one with the conditional inputs and then a second with the unconditional ones). This batch size concatenation trick means you only have to run one forward pass, but with 2x the batch size

The only pain point I see with getting this to work in transformers is the batch size change as we go from the forward pass to the sampling loop. But we can add some logic to change the batch size on the fly if we're doing CFG (kind of like we did for MusicGen @gante - we need to trick the forward pass into using 2 * bsz, then the decoder ids to use bsz).

There is no need to train models to support CFG, it works out of the box

Very cool indeed! Would be nice to have this as a standalone PR then as suggested

sanchit-gandhi avatar Jul 03 '23 15:07 sanchit-gandhi

Thank you! Yeah, if the cond and uncond prompts get the same next token sampled, it's good with respect to our experiments! It's how you manage to loop around in .generate() to grow the continuation token by token, zigzagging between bsz and 2*bsz, that I'm not 100% clear on. I totally see how it works for one forward pass. Totally an implementation detail :) But apparently that's a new trick you had to implement for MusicGen too, so it makes sense that I'm not perfectly clear on it.

Would be nice to have this as a standalone PR then as suggested

I'm happy to address the changes that have to be made to contribute this into the lib :)

Vermeille avatar Jul 03 '23 15:07 Vermeille

Awesome - feel free to open a PR and tag myself and @gante! How do you do it without the 2x batch size trick? Do you do two forward passes? Just asking in case there's a simpler way we can integrate this!

sanchit-gandhi avatar Jul 03 '23 15:07 sanchit-gandhi

(catching up on the paper and thinking a bit about usage experience -- will comment tomorrow with specific suggestions, but I think @Vermeille's suggested implementation above will be pretty close to a great user experience with minimal compute overhead)

gante avatar Jul 03 '23 17:07 gante

Here is an alternative implementation we used for some of our other experiments in the paper, for your consideration.

It was designed with huggingface's typical *ModelFor* code style in mind, which just puts the base model in the init and extends the forward() method: https://github.com/Vermeille/lm-evaluation-harness-cfg/blob/cfg-alex/log_logits_on_p3.py#L30-L97

alex2awesome avatar Jul 03 '23 17:07 alex2awesome

Awesome - feel free to open a PR and tag myself and @gante! How do you do it without the 2x batch size trick? Do you do two forward passes? Just asking in case there's a simpler way we can integrate this!

Yes. Two consecutive passes. Which is indeed not that great wrt latency.

Vermeille avatar Jul 03 '23 21:07 Vermeille

Would be great to have both the 2x batch size and the two forward passes, since 2x batch size is better for throughput but two forward passes are much better for VRAM usage, as the paper outlines.

(unless I misunderstood)

elikoga avatar Jul 03 '23 21:07 elikoga

So given you already have this ( https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L1070 )

What do you want me to add / change in the PR?

Vermeille avatar Jul 03 '23 23:07 Vermeille

Would be great to have both the 2x batch size and the two forward passes, since 2x batch size is better for throughput but two forward passes are much better for VRAM usage, as the paper outlines.

(unless I misunderstood)

This is correct: our focus was on getting the best results for a fixed amount of VRAM in our experiments, hence it didn't occur to us to simply 2x the batch size. I agree that making this toggleable is a good idea and don't have any preference about the default.

StellaAthena avatar Jul 03 '23 23:07 StellaAthena

The application to LLMs seems more of a situational sampling technique. With smaller conditional generative models like MusicGen, trained from scratch with (explicit) condition dropout, it's practically part of the model. MusicGen isn't the first AR Transformer here; last year's DALL-E Mega already did it (itself inspired by https://twitter.com/RiversHaveWings/status/1478093658716966912), and in these models it's essential for performance.

So I'd expect "batch size 1 dramatically underutilizes available resources" to be the more common case.

Since 2x batch size is better for throughput but the two forward passes are much better for VRAM usage, as the Paper outlines

Depending on model and hardware, "biggest batch size that fits" isn't necessarily optimal. On decent hardware, you can hit optimal compute utilisation before VRAM limits with batched inference in smaller models.


Normalizing the summands, then interpolating with the original scores is intriguing. If adding this to the CFG implementation that's now in Transformers is still being considered, this would be unexpected as default behavior though. In diffusion models, it's not applicable, and in sequence prediction, I've only seen people combine the unnormalized scores.

drdaxxy avatar Jul 04 '23 03:07 drdaxxy

@drdaxxy

Normalizing the summands, then interpolating with the original scores is intriguing. [...] In diffusion models, it's not applicable

This is a technique we borrowed from Common Diffusion Noise Schedules and Sample Steps are Flawed, where it's called CFG Rescale. You can see Imagen doing a similar normalizing trick too.

in sequence prediction, I've only seen people combine the unnormalized scores.

That's what we started with, and our results were a little bit worse.
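For concreteness, the two combination variants being compared can be sketched as follows (a sketch of the idea, not the exact code used in the paper; `mix=0.7` mirrors the 0.7/0.3 blend in the snippet above):

```python
import torch
import torch.nn.functional as F

def cfg_unnormalized(cond, uncond, scale):
    # Combine raw logits directly (the variant usually seen in sequence models).
    return uncond + scale * (cond - uncond)

def cfg_normalized(cond, uncond, scale, mix=0.7):
    # log_softmax both sets first so they share a common scale, then
    # interpolate back toward the conditional scores (as discussed in this thread).
    cond = F.log_softmax(cond, dim=-1)
    uncond = F.log_softmax(uncond, dim=-1)
    out = F.log_softmax(uncond + scale * (cond - uncond), dim=-1)
    return mix * out + (1 - mix) * cond
```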

Vermeille avatar Jul 04 '23 08:07 Vermeille

This method is interesting to implement from an engineering and maintenance point of view!

The simplest approach would be to proceed as @Vermeille suggested: add a logits processor that calls a model forward pass for the unconditional part of the input. It would be a small self-contained piece of code, which means low long-term maintenance on our end. On the negative side, we have the 2x latency, which is more impactful than the extra VRAM (IMO).

If we go the 2x batch size route, we need to implement a function like greedy_search or sample -- a long function with non-negligible maintenance costs on our end. I believe this would be the best form of CFG sampling. However, we are severely constrained by maintenance bandwidth: keeping the machinery up and running at a good pace is what lets us quickly add new features like CFG sampling :D

We have a plan to reorganize generate such that it is entirely made of small functions, making it much more composable. In the way I'm envisioning it, the 2x batch size version of CFG sampling would need a few extra lines of code, as opposed to a new large function.

How about we go with @Vermeille's proposal now, which will make CFG sampling available this week with low overhead on our end, and we implement the 2x batch size version after the generate refactor is complete? The new logits processor class would need a different name, as we already have ClassifierFreeGuidanceLogitsProcessor for the 2x batch size case (perhaps UnbatchedClassifierFreeGuidanceLogitsProcessor?)

gante avatar Jul 04 '23 09:07 gante

Expect a PR in a few hours.

Thank you for your interest and answers!

Vermeille avatar Jul 04 '23 09:07 Vermeille

@gante There is a name clash in the arguments to .generate(). For this PR, unless instructed otherwise before I submit it, cfg_scale (mine) will live next to guidance_scale (MusicGen's). Idk how to resolve this clash, given that .generate() does not seem ready to use the 2x batch trick yet.

Vermeille avatar Jul 04 '23 12:07 Vermeille

@Vermeille Adding more (and partially redundant) parameterization is highly undesirable, and we'd want to favor the more general case (yours). You also have the additional requirement of renormalizing the logits before applying your logits processor. Fortunately, we haven't officially released a transformers version with MusicGen, so we still have some wiggle room!

Let's try to fit everything together -- here's my suggestion:

  • your logits processor uses the same parameter, guidance_scale, and it's triggered by its presence
  • EDIT: this is not needed ~~your logits processor is added after the normalization one (after this if), and the normalization step is now also triggered when guidance_scale is non-None~~
  • ClassifierFreeGuidanceLogitsProcessor (MusicGen's) is removed from the function that prepares the logits processors, and we modify MusicGen's generation function to handle its special processor: if guidance_scale is present when we generate with MusicGen, we pop it and manually add its CFG processor. I can take care of this part if you don't feel comfortable touching MusicGen :)

This way the two strategies can coexist, share the argument, and not clash 🤗

gante avatar Jul 04 '23 13:07 gante

Great! Thank you for the walkthrough.

On it.

Vermeille avatar Jul 04 '23 13:07 Vermeille

Wait @gante, integrating it after the LogitNormalization is not something we want: all the prior processing (temperature, top_p, etc.) will be applied only to the conditional branch and not to the unconditional one, and will be executed before computing the CFG logits. To be fair, we haven't tested this transformation order, but being asymmetrical like this scares me.

And it is even invalid: top-k/p may not select the same tokens in both branches, so that will misbehave.

I'm afraid I can't do that. CFG has to happen as one of the first logits processors.

Vermeille avatar Jul 04 '23 13:07 Vermeille

@Vermeille looking at your code example above, I didn't notice it already had normalization inside the processor. My bad -- feel free to add it as the 1st one :)

(will edit my comment above accordingly, for clarity)

gante avatar Jul 04 '23 13:07 gante

So this is the code I got it working with. It is just a hack, but if you want to play with it, just use this code:

from transformers import LogitsWarper
import torch
from torch.nn import functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class CFGLogits(LogitsWarper):

    def __init__(self, cfg, inputs, model, verbose=True):
        self.cfg = cfg
        self.inputs = inputs
        self.model = model
        self.out = None
        self.verbose = verbose

    def __call__(self, input_ids, scores):
        if self.cfg == 1:
            return F.log_softmax(scores, dim=-1)
        scores = F.log_softmax(scores, dim=-1)
        if self.out is None:
            self.out = self.model(self.inputs.to(device), use_cache=True)
        else:
            self.out = self.model(input_ids[:, -1:],
                                  use_cache=True,
                                  past_key_values=self.out.past_key_values)
        unconditional_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)
        out = self.cfg * (scores - unconditional_logits) + unconditional_logits
        out = F.log_softmax(out, dim=-1)
        return 0.7 * out + 0.3 * scores
    
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

prompt = "Salve, dispiculi."
inputs = tokenizer(prompt, return_tensors='pt')
model.to(device)
outputs = model.generate(
    input_ids=inputs['input_ids'].to(device),
    attention_mask=inputs['attention_mask'].to(device),
    max_new_tokens=125,
    logits_processor=LogitsProcessorList([
        # inputs_cfg usually is the last token of the prompt but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(3, inputs['input_ids'], model),
        TemperatureLogitsWarper(0.8),
        TopPLogitsWarper(0.95),
    ]),
    do_sample=True,
)

print(tokenizer.decode(outputs[0]))

This worked on my end

grantCelley avatar Jul 05 '23 04:07 grantCelley

@grantCelley 's code works for me.

With CFG (pythia 160m): [output screenshot omitted]

Without CFG: [output screenshot omitted]

chris-aeviator avatar Jul 05 '23 09:07 chris-aeviator

@grantCelley @chris-aeviator The line CFGLogits(3, inputs['input_ids'], model), should really be CFGLogits(3, inputs['input_ids'][:, -1:], model),

Vermeille avatar Jul 05 '23 10:07 Vermeille

Thanks for pointing it out, my 30 was a typo, but your prev. code doesn't seem to mention the [:, -1:]?!

chris-aeviator avatar Jul 05 '23 10:07 chris-aeviator

@chris-aeviator notice how it uses inputs_cfg:

        # inputs_cfg usually is the last token of the prompt but there are
        # possibilities of negative prompting that are explored in the paper
        CFGLogits(cfg, inputs_cfg, model),

Vermeille avatar Jul 05 '23 10:07 Vermeille