
Davidm/gpt sft dataset fix

Davood-M opened this pull request 2 years ago • 6 comments

What does this PR do ?

When using small prompts, the GPT_SFT_DATASET generated empty input prompts because of this logic:

if len(ids) < truncation_length:
    logging.warning(f'{key} is not long enough to truncate.')
    truncation_length = len(ids)

if self.truncation_method == 'left':
    window_offset = truncation_length
elif self.truncation_method == 'right':
    window_offset = 0
else:
    raise ValueError(f'{self.truncation_method} is not supported')

window_length = len(ids) - truncation_length
template_ids[i] = ids[window_offset : window_offset + window_length]

In this part of the code, if the truncation method is 'right' and the prompt length is less than truncation_length, then window_length becomes 0 and template_ids becomes empty! I tested this on tuned Llama 2 models and it seems to have solved the issue.
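
For illustration, here is a minimal standalone sketch of that windowing logic (the truncate helper is hypothetical, extracted only to demonstrate the failure mode); with the 'right' method and a short prompt, the clamped truncation_length drives window_length to zero and the field comes back empty:

def truncate(ids, truncation_length, truncation_method='right'):
    # Original behavior: clamp truncation_length to the full field length.
    if len(ids) < truncation_length:
        truncation_length = len(ids)
    # 'left' keeps the tail of the field, 'right' keeps the head.
    window_offset = truncation_length if truncation_method == 'left' else 0
    window_length = len(ids) - truncation_length  # 0 after clamping
    return ids[window_offset : window_offset + window_length]

print(truncate([1, 2, 3], truncation_length=10))  # [] -> empty prompt field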

Collection: [Note which collection this PR will affect]

Changelog

  • Updated the truncation logic above for both left and right truncation

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this 

Jenkins CI

To run Jenkins, a NeMo User with write access must comment jenkins on the PR.

Before your PR is "Ready for review"

Pre checks:

  • [x] Make sure you read and followed Contributor guidelines
  • [ ] Did you write any new necessary tests?
  • [ ] Did you add or update any necessary documentation?
  • [ ] Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex, etc.)
    • [ ] Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • [ ] New Feature
  • [x] Bugfix
  • [ ] Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed. The Contributor guidelines contain specific people who can review PRs in various areas.

Additional Information

  • Related to # (issue)

Davood-M avatar Dec 10 '23 23:12 Davood-M

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot] avatar Dec 25 '23 01:12 github-actions[bot]

This PR was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot] avatar Jan 01 '24 01:01 github-actions[bot]

@Davood-M I am having the same "input is not long enough" error in my SFT experiment. Do you know if it's fixed in the latest version?

premmotgi avatar Jan 31 '24 21:01 premmotgi

jenkins

Davood-M avatar May 01 '24 23:05 Davood-M

Hi @Davood-M, if your entire input length (template + placeholder strings) is greater than the sequence length, then we need to truncate. We truncate only the truncation_field part of your placeholder strings. If the truncation_field is not long enough to truncate, then we simply make this field empty. This situation can happen if the length of your template is greater than the length of your placeholder strings. I think the solutions are either to change your template or to increase the sequence length.
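
For instance, a rough worked example of that arithmetic (illustrative numbers only, not from any real run):

# Illustrative numbers: when the fixed template plus the untruncatable
# fields already exceed the sequence length, emptying the truncation_field
# cannot bring the input under budget.
seq_length = 2048
template_len = 1900        # tokens contributed by the prompt template itself
other_fields_len = 300     # placeholder strings we never truncate
trunc_field_len = 100      # the one field we are allowed to truncate

excess = template_len + other_fields_len + trunc_field_len - seq_length  # 252
kept = max(trunc_field_len - excess, 0)  # 0 -> the field is made empty
print(kept, template_len + other_fields_len + kept - seq_length)  # 0 152, still too long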

hsiehjackson avatar May 01 '24 23:05 hsiehjackson

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions[bot] avatar May 16 '24 01:05 github-actions[bot]

Hi, what is the latest non-dev container that does not have this issue? I'm guessing this is a recent issue.

aditya-malte avatar May 22 '24 00:05 aditya-malte

This error can be resolved by either (1) setting truncation_field="output" or (2) setting truncation_field=None/null.
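
A minimal sketch of applying those settings via OmegaConf (the key path model.data.train_ds.truncation_field is an assumption based on typical NeMo SFT configs; verify against your own YAML):

from omegaconf import OmegaConf

# Stand-in config; in practice this would be loaded from your NeMo SFT YAML.
cfg = OmegaConf.create({'model': {'data': {'train_ds': {'truncation_field': 'input'}}}})

cfg.model.data.train_ds.truncation_field = 'output'  # option 1: truncate the output field
# cfg.model.data.train_ds.truncation_field = None    # option 2: disable the truncation field
print(OmegaConf.to_yaml(cfg))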

arendu avatar May 23 '24 03:05 arendu

Does that resolve the write_predictions_to_file bug?

aditya-malte avatar May 24 '24 00:05 aditya-malte

Hi @arendu, I tried the solution you suggested but got a new error:

11: Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
 9: Traceback (most recent call last):
 9:   File "/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_generate.py", line 164, in main
 9:     trainer.test(model)
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 754, in test
 9:     return call._call_and_handle_interrupt(
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
 9:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
 9:     return function(*args, **kwargs)
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 794, in _test_impl
 9:     results = self._run(model, ckpt_path=ckpt_path)
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 987, in _run
 9:     results = self._run_stage()
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1026, in _run_stage
 9:     return self._evaluation_loop.run()
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
 9:     return loop_run(self, *args, **kwargs)
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
 9:     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
 9:     output = call._call_strategy_hook(trainer, hook_name, *step_args)
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
 9:     output = fn(*args, **kwargs)
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 425, in test_step
 9:     return self.lightning_module.test_step(*args, **kwargs)
 9:   File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 434, in test_step
 9:     return self.inference_step(dataloader_iter, 'test')
 9:   File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 437, in inference_step
 9:     batch, batch_idx, dataloader_idx = next(dataloader_iter)
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fetchers.py", line 207, in __next__
 9:     batch, batch_idx, dataloader_idx = super(_DataLoaderIterDataFetcher, fetcher).__next__()
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fetchers.py", line 60, in __next__
 9:     batch = next(self.iterator)
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/combined_loader.py", line 341, in __next__
 9:     out = next(self._iterator)
 9:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/combined_loader.py", line 142, in __next__
 9:     out = next(self.iterators[0])
 9:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__
 9:     data = self._next_data()
 9:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data
 9:     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
 9:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
 9:     return self.collate_fn(data)
 9:   File "/opt/NeMo/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py", line 479, in collate_fn
 9:     contexts = torch.LongTensor(self._collate_item(contexts, max_length=max_length, pad_id=self.tokenizer.eos_id))
 9: ValueError: expected sequence of length 4096 at dim 1 (got 4512)
 9: 

aditya-malte avatar May 24 '24 00:05 aditya-malte

Hi @aditya-malte, in your setup, what are the values of your truncation_field and prompt_template?

hsiehjackson avatar May 24 '24 17:05 hsiehjackson