Add Watermarking LogitsProcessor and WatermarkDetector
What does this PR do?
Adds the watermarking technique proposed in this paper as a `transformers` logits processor. I added only the simple method (algorithm 2 from the paper) and the robust one (algorithm 3), both with a context width of one token only. I am not sure if we should support a larger context width, defined by the user in the generation config.
In contrast to the original repo, masking is now done in a batched manner. However, I could not make `_get_greenlist_ids` batched, so we still loop over the batch elements one by one...
Note that this is only a processor for generation with watermarking. Anyone who uses it and later wants to detect the watermarked text has to use the detector from the original repo, with their own private hashing key.
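For illustration, here is a minimal sketch of the algorithm-2 style biasing step with a context width of one. The defaults mirror the values used in the tests; `watermark_logits` is a hypothetical helper, not the actual processor class added in this PR:

```python
import torch

def watermark_logits(scores, input_ids, greenlist_ratio=0.25, bias=2.0, hashing_key=15485863):
    """Soft watermark (algorithm 2): bias a pseudorandom green list of tokens, seeded by the previous token."""
    batch_size, vocab_size = scores.shape
    greenlist_size = int(vocab_size * greenlist_ratio)
    for b in range(batch_size):  # the green-list lookup is the part that is still per-sequence
        rng = torch.Generator(device=scores.device)
        # context width of 1: the seed depends only on the last generated token
        rng.manual_seed(hashing_key * int(input_ids[b, -1]))
        greenlist = torch.randperm(vocab_size, generator=rng, device=scores.device)[:greenlist_size]
        scores[b, greenlist] += bias  # greedy decoding / sampling then favors green tokens
    return scores
```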
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [X] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [X] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [X] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@gante
cc @jonasgeiping @jwkirchenbauer
I would love to have your feedback on the default values we chose to use 🤗
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi, great work! Supporting a default context length of 4 would be something I would really advocate for, based on results such as Fig.2 in https://arxiv.org/abs/2306.04634.
A separate concern is execution speed. We hooked the watermark into cuda.rng.manual_seed, which is convenient but not actually efficient, and it cannot be batched. A future-proof way of doing this would probably circumvent CUDA altogether and use a different implementation of [list_of_integers + salt] -> hash -> pseudorandom green/red partition table, but we didn't do that either, and I am not sure what your time frame for this feature is.
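For reference, a minimal sketch of what such a batched hash-based partition could look like (illustrative constants and a hypothetical `batched_green_mask` helper; this is not the approach used in the current PR):

```python
import torch

def batched_green_mask(input_ids, vocab_size, greenlist_ratio=0.25, hashing_key=15485863, context_width=1):
    """Return a (batch, vocab_size) boolean mask of green tokens without touching the CUDA RNG."""
    prime = 2147483647  # 2**31 - 1, keeps every intermediate product inside int64
    # fold the last `context_width` tokens and the salt into one integer seed per sequence
    seed = torch.full((input_ids.shape[0],), hashing_key, dtype=torch.int64, device=input_ids.device)
    for i in range(context_width):
        seed = (seed * 1103515245 + input_ids[:, -context_width + i].to(torch.int64)) % prime
    vocab = torch.arange(vocab_size, dtype=torch.int64, device=input_ids.device)
    # per-(seed, token) hash decides membership: green iff the hash falls below the ratio threshold
    h = (seed.unsqueeze(1) * 48271 + vocab * 69621) % prime
    return h < int(greenlist_ratio * prime)
```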
@JonasGeiping Thanks for the feedback! Yes, indeed a larger context width performs better. I was reluctant to add more complexity to the code when opening the PR, but now that we are okay with adding the whole watermarking functionality, I will add the possibility for users to set their own context width.
Yes, a different RNG implementation that works in batched form would be very nice to have. I am not planning to work on it right now, and I prefer to leave it as a future plan if we see active usage of the watermarking feature 😄
Complexity could be reduced a bit by removing self-hashing for now. That setting has several implementation complexities, and without an efficient RNG it is quite slow to use during text generation for any purpose other than testing watermark quality.
@gante, where can we add a doc for the detector? The tests are failing otherwise.
@zucchini-nlp here -- https://github.com/huggingface/transformers/blob/main/docs/source/en/internal/generation_utils.md
(ping me again when it's ready for a review :) )
@gante this should be ready for review
@zucchini-nlp ping me when it's ready for a re-review 🤗
@gante sorry, forgot to tag. Yes, ready to review. Added a config for watermark args, changed cache size and rewrote some docs.
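For context, a rough sketch of how this might be used end to end. The `watermarking_args` keys are taken from the test log further down in this thread; the exact argument name and defaults are still in flux, so treat this as illustrative:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("I will be", return_tensors="pt")
# hypothetical kwarg: the keys mirror the `watermarking_args` dict visible in the test log below
out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=20,
    watermarking_args={"bias": 2.0, "context_width": 1, "greenlist_ratio": 0.25, "hashing_key": 15485863},
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```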
@zucchini-nlp I'd also edit the PR header, it is outdated :) (for instance, it says users should use the detector from the original repo)
I'll review next week! 🤗
Hi, @zucchini-nlp
Running
RUN_SLOW=1 TF_FORCE_GPU_ALLOW_GROWTH=true python3 -m pytest -v tests/generation/test_utils.py::GenerationIntegrationTests::test_watermark_generation
gives
ValueError: The following `model_kwargs` are not used by the model: ['watermarking_args'] (note: typos in the generate arguments will also show up in this list)
Could you look into this?
Full error log
self = GPT2LMHeadModel(
(transformer): GPT2Model(
(wte): Embedding(50257, 768)
(wpe): Embedding(1024, 768)
(dro...((768,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=768, out_features=50257, bias=False)
)
model_kwargs = {'attention_mask': tensor([[1, 1, 1]], device='cuda:0'), 'input_ids': tensor([[ 40, 481, 307]], device='cuda:0'), 'watermarking_args': {'bias': 2.0, 'context_width': 1, 'greenlist_ratio': 0.25, 'hashing_key': 15485863, ...}}
    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
        # If a `Cache` instance is passed, checks whether the model is compatible with it
        if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:
            raise ValueError(
                f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
                "check the model documentation for supported cache formats."
            )
        # Excludes arguments that are handled before calling any model function
        if self.config.is_encoder_decoder:
            for key in ["decoder_input_ids"]:
                model_kwargs.pop(key, None)
        unused_model_args = []
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
        # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
        if "kwargs" in model_args or "model_kwargs" in model_args:
            model_args |= set(inspect.signature(self.forward).parameters)
        # Encoder-Decoder models may also need Encoder arguments from `model_kwargs`
        if self.config.is_encoder_decoder:
            base_model = getattr(self, self.base_model_prefix, None)
            # allow encoder kwargs
            encoder = getattr(self, "encoder", None)
            # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`.
            # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder`
            # TODO: A better way to handle this.
            if encoder is None and base_model is not None:
                encoder = getattr(base_model, "encoder", None)
            if encoder is not None:
                encoder_model_args = set(inspect.signature(encoder.forward).parameters)
                model_args |= encoder_model_args
            # allow decoder kwargs
            decoder = getattr(self, "decoder", None)
            if decoder is None and base_model is not None:
                decoder = getattr(base_model, "decoder", None)
            if decoder is not None:
                decoder_model_args = set(inspect.signature(decoder.forward).parameters)
                model_args |= {f"decoder_{x}" for x in decoder_model_args}
        # allow assistant_encoder_outputs to be passed if we're doing assisted generating
        if "assistant_encoder_outputs" in model_kwargs:
            model_args |= {"assistant_encoder_outputs"}
        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)
        if unused_model_args:
>           raise ValueError(
                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
                " generate arguments will also show up in this list)"
            )
E           ValueError: The following `model_kwargs` are not used by the model: ['watermarking_args'] (note: typos in the generate arguments will also show up in this list)

src/transformers/generation/utils.py:1136: ValueError
@ydshieh my bad, did not fix tests after latest changes
No worries. (But it is always nice to check the tests once we think a PR is ready 😄) Thanks for fixing!
The following objects' docstrings do not match their signature. Run
`make fix-copies` to fix this. In some cases, this error may be raised incorrectly by the docstring checker. If you think this is the case, you can manually check the docstrings:
- WatermarkDetector
We need to check the docstrings for WatermarkDetector. You can first run `make fix-copies`, but don't apply the changes blindly - check them and make a decision 🙏
@ArthurZucker ping
Thanks for the ping, reviewing now!
Oups, on it today, sorry @zucchini-nlp
The PR seems ready to me: all comments are addressed and the tests are passing. I see that I got two approvals, but I will leave it open until next week (May 13) in case anyone wants to add something 😄