Problems when generating continuations
Hi! I am trying to generate continuations with DExperts using the following script.
OUTPUT_DIR=output/toxicity/dexperts
PROMPTS_DATASET=prompts/nontoxic_prompts-10k.jsonl
MODEL_DIR=models/experts/toxicity/large
API_RATE=25
export CUDA_VISIBLE_DEVICES=6
python -m scripts.run_toxicity_experiment \
--use-dataset \
--dataset-file $PROMPTS_DATASET \
--model-type dexperts \
--model gpt2-large \
--perspective-rate-limit $API_RATE \
--nontoxic-model $MODEL_DIR/finetuned_gpt2_nontoxic \
--toxic-model $MODEL_DIR/finetuned_gpt2_toxic \
--alpha 2.0 \
--filter_p 0.9 \
$OUTPUT_DIR
But I hit a problem I don't understand. The error message follows.
Prompts:
0 For more information on what is currently avai...
1 Instead, these women watched the man they say ...
2 Because what really matters is that you are in...
3 There's another controversial Hollywood racial...
4 That's compared to the $20 million that San Fr...
...
9995 Get The Start-Up Israel's Daily Start-Up
9996 There was once a tiny rectangle of land, devoi...
9997 Michael Hanley, Bishop of Oregon
9998 Seriously, how in the world did the
9999 The 21-year-old attacker, who was not immediat...
Name: text, Length: 10000, dtype: object
Process Process-1:
Traceback (most recent call last):
File "/data0/xp/anaconda3/envs/dexperts/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/data0/xp/anaconda3/envs/dexperts/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/data0/xp/ctg/DExperts/utils/perspective_api.py", line 168, in perspective_worker
api = PerspectiveAPI(rate_limit=rate_limit)
File "/data0/xp/ctg/DExperts/utils/perspective_api.py", line 42, in __init__
self.service = self._make_service(api_key)
File "/data0/xp/ctg/DExperts/utils/perspective_api.py", line 117, in _make_service
return discovery.build('commentanalyzer', 'v1alpha1', developerKey=api_key)
File "/data0/xp/anaconda3/envs/dexperts/lib/python3.8/site-packages/googleapiclient/_helpers.py", line 131, in positional_wrapper
return wrapped(*args, **kwargs)
File "/data0/xp/anaconda3/envs/dexperts/lib/python3.8/site-packages/googleapiclient/discovery.py", line 287, in build
content = _retrieve_discovery_doc(
File "/data0/xp/anaconda3/envs/dexperts/lib/python3.8/site-packages/googleapiclient/discovery.py", line 404, in _retrieve_discovery_doc
raise UnknownApiNameOrVersion("name: %s version: %s" % (serviceName, version))
googleapiclient.errors.UnknownApiNameOrVersion: name: commentanalyzer version: v1alpha1
Generation: 0%| | 0/7813 [00:00<?, ?it/s, batch_size=32]/data0/xp/anaconda3/envs/dexperts/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:2263: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
warnings.warn(
Generation: 0%| | 0/7813 [00:00<?, ?it/s, batch_size=32]
Traceback (most recent call last):
File "/data0/xp/anaconda3/envs/dexperts/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/data0/xp/anaconda3/envs/dexperts/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/data0/xp/ctg/DExperts/scripts/run_toxicity_experiment.py", line 187, in <module>
main()
File "/data0/xp/anaconda3/envs/dexperts/lib/python3.8/site-packages/click/core.py", line 1128, in __call__
return self.main(*args, **kwargs)
File "/data0/xp/anaconda3/envs/dexperts/lib/python3.8/site-packages/click/core.py", line 1053, in main
rv = self.invoke(ctx)
File "/data0/xp/anaconda3/envs/dexperts/lib/python3.8/site-packages/click/core.py", line 1395, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/data0/xp/anaconda3/envs/dexperts/lib/python3.8/site-packages/click/core.py", line 754, in invoke
return __callback(*args, **kwargs)
File "/data0/xp/ctg/DExperts/scripts/run_toxicity_experiment.py", line 173, in main
for i, gen in enumerate(generations_iter):
File "/data0/xp/ctg/DExperts/generation/generation.py", line 202, in dexperts
yield from _gpt2_helper(
File "/data0/xp/ctg/DExperts/generation/generation.py", line 159, in _gpt2_helper
batch = generator.generate(prompt, max_len, **generate_kwargs)
File "/data0/xp/ctg/DExperts/generation/dexperts_generation.py", line 96, in generate
base_logits = top_k_top_p_filtering(base_logits, top_p=filter_p)
File "/data0/xp/ctg/DExperts/utils/generation_utils.py", line 29, in top_k_top_p_filtering
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
TypeError: sort() received an invalid combination of arguments - got (str, descending=bool), but expected one of:
* (Tensor input, *, bool stable, int dim, bool descending, tuple of Tensors out)
* (Tensor input, int dim, bool descending, *, tuple of Tensors out)
* (Tensor input, *, bool stable, name dim, bool descending, tuple of Tensors out)
* (Tensor input, name dim, bool descending, *, tuple of Tensors out)
In short, the input to the torch.sort function should be a Tensor, but logits is a string instead, and I don't know why. I look forward to your reply. Thank you.
@Richard88888 I guess the problem is that the model output is not unpacked properly in the code.
This happens because the logits variable receives strings rather than tensors, due to the incorrect unpacking.
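For context, here is a minimal sketch of that failure mode (not the DExperts code itself, just a plain GPT-2 call): in transformers 4.x the forward pass returns a dict-like ModelOutput by default, and tuple-unpacking it iterates over its keys, so the variables end up holding field-name strings instead of tensors.

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

inputs = tokenizer('Hello world', return_tensors='pt')
outputs = model(**inputs)      # a ModelOutput (dict-like), not a plain tuple

# Unpacking iterates over the dict keys, so these are strings like 'logits',
# and any later tensor op (e.g. torch.sort) fails with the TypeError above.
first, *rest = outputs
print(type(first))             # <class 'str'>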
Someone already asked a similar question here: https://github.com/huggingface/transformers/issues/8919
A quick fix that worked for me: call the to_tuple() method before unpacking model outputs.
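In case it helps, here is a hedged sketch of what that fix looks like at the call site (the variable and attribute names are illustrative, not the exact ones in dexperts_generation.py):

# Before (breaks on transformers 4.x): unpacking a ModelOutput yields key strings
# base_logits, base_past = self.base_model(input_ids)

# After: convert the ModelOutput to a plain tuple first, then unpack
base_outputs = self.base_model(input_ids)
base_logits, base_past = base_outputs.to_tuple()[:2]   # tensors, as in transformers 3.x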
@alisawuffles Shall I send a PR for this?
Hey @Kadam-Tushar, thanks for the fix! This codebase was based on transformers==3.3.1, and I believe this error may be due to a breaking change from transformers 3.x to 4.x in how model outputs are unpacked. Sure, please feel free to create a pull request!
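For anyone else running on transformers 4.x, an alternative to calling to_tuple() (again just a sketch; the exact call sites in DExperts may differ) is to pass return_dict=False so the forward call returns plain tuples, as it did in 3.x, or simply to pin transformers==3.3.1 as the codebase expects.

# transformers 3.x returned tuples by default; 4.x lets you opt back in per call
outputs = model(input_ids, return_dict=False)
logits = outputs[0]            # the logits tensor, matching the 3.x behavior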