RuntimeError for contrastive loss if audio duration is not long enough for num_negatives
Describe the bug
If the duration of an audio file is < 8 s (unsure of the exact threshold, needs further testing), then sample_negatives throws the following error:
File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/models/ssl_models.py", line 491, in validation_step
loss_value, _ = self.decoder_loss_step(spectrograms, spec_masks, encoded, encoded_len, targets, target_lengths)
File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/models/ssl_models.py", line 450, in decoder_loss_step
current_loss_value = current_loss(
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1129, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/nemo/core/classes/common.py", line 963, in __call__
outputs = wrapped(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/losses/ssl_losses/contrastive.py", line 205, in forward
negatives, _ = self.sample_negatives(targets_masked_only, targets_masked_only.size(0)) # T'xBxC # T'
File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/losses/ssl_losses/contrastive.py", line 139, in sample_negatives
neg_idxs = torch.multinomial(torch.ones((num, high), device=y.device), self.num_negatives)
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
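For context, the failure comes from torch.multinomial being asked for more samples than it can draw without replacement: the loss tries to pick num_negatives distinct negatives from only the masked timesteps of the utterance, and a short utterance has fewer masked timesteps than num_negatives. Here is a minimal sketch with made-up shapes (not the actual tensors from the loss) that reproduces the same RuntimeError:

```python
import torch

num_negatives = 100   # illustrative value only; assumed to exceed the masked length
high = 60             # pretend a short utterance yields only 60 masked timesteps
num = high            # one row of candidate negatives per masked timestep

# Same pattern as sample_negatives: a uniform distribution over `high` candidates,
# sampled without replacement. num_negatives > high triggers:
# RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
neg_idxs = torch.multinomial(torch.ones((num, high)), num_negatives)
```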
If you haven't read the docs or don't understand the code/loss function, it is not obvious that you need to adjust min_duration or num_negatives.
Steps/Code to reproduce bug
Run any of the examples in speech_pretraining with train/validation_ds.min_duration < 8.0
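As a rough back-of-envelope of why the threshold lands around 8 s, here is a small sketch. All constants below are assumptions for illustration (featurizer hop, subsampling factor, masked fraction, and num_negatives), not values verified against the shipped configs:

```python
# The contrastive loss needs at least num_negatives masked timesteps per utterance,
# since negatives are drawn without replacement from masked frames only.
window_stride_s = 0.01   # assumed featurizer hop (10 ms)
subsampling = 4          # assumed Conformer subsampling factor
mask_fraction = 0.5      # assumed fraction of encoded frames that get masked
num_negatives = 100      # assumed contrastive loss setting

frames_needed = num_negatives / mask_fraction                  # 200 encoded frames
min_duration_s = frames_needed * subsampling * window_stride_s # 8.0 s
print(min_duration_s)
```

Under these assumptions, utterances shorter than ~8 s can leave fewer masked frames than num_negatives and hit the RuntimeError above.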
Expected behavior
While this configuration is expected to fail, I would like to propose that we catch this error and replace it with something more... constructive. I can open a PR, but wanted to see if the maintainers are open to the suggestion.
My suggestion is to re-raise as a NeMoBaseException (or a new exception, e.g. LossMisconfigurationError?) with a message suggesting the user either adjust num_negatives or min_duration.
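Something along these lines; the helper name, exception type, and wording are purely illustrative (I use a plain ValueError as a placeholder for whatever NeMo exception the maintainers prefer), and the check would sit in ContrastiveLoss.sample_negatives just before the torch.multinomial call:

```python
def _check_enough_negatives(high: int, num_negatives: int) -> None:
    # `high` is the number of masked timesteps available as negative candidates.
    if high < num_negatives:
        raise ValueError(  # placeholder for NeMoBaseException / LossMisconfigurationError
            f"Cannot sample {num_negatives} negatives without replacement from only "
            f"{high} masked timesteps. Either reduce the loss's num_negatives or "
            f"increase train_ds/validation_ds.min_duration so that utterances are "
            f"long enough to provide at least {num_negatives} masked frames."
        )
```

That way the user sees what to change in the config instead of a bare multinomial error.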
Edit: I realized I opened this as a bug, but it should've been a feature suggestion.
@piraka9011 I think this would be a useful thing to catch, please make a PR for your suggestion
This issue is stale because it has been open for 60 days with no activity.
This issue was closed because it has been inactive for 7 days since being marked as stale.
Hi, I'm facing the same issue with the default Conformer config (changed min_duration=1.6 and max_duration=30.0).
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
outputs = self.optimizer_loop.run(optimizers, kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 200, in advance
result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 239, in _run_optimization
closure()
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 147, in __call__
self._result = self.closure(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 133, in closure
step_output = self._step_fn()
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 406, in _training_step
training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1480, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/ddp.py", line 352, in training_step
return self.model(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1040, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0])
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/overrides/base.py", line 98, in forward
output = self._forward_module.training_step(*inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/nemo/utils/model_utils.py", line 362, in wrap_training_step
output_dict = wrapped(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/nemo/collections/asr/models/ssl_models.py", line 495, in training_step
loss_value, loss_val_dict = self.decoder_loss_step(
File "/usr/local/lib/python3.10/dist-packages/nemo/collections/asr/models/ssl_models.py", line 462, in decoder_loss_step
current_loss_value = current_loss(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/nemo/core/classes/common.py", line 1087, in __call__
outputs = wrapped(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/nemo/collections/asr/losses/ssl_losses/contrastive.py", line 205, in forward
negatives, _ = self.sample_negatives(targets_masked_only, targets_masked_only.size(0)) # T'xBxC # T'
File "/usr/local/lib/python3.10/dist-packages/nemo/collections/asr/losses/ssl_losses/contrastive.py", line 139, in sample_negatives
neg_idxs = torch.multinomial(torch.ones((num, high), device=y.device), self.num_negatives)
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
Any leads would be appreciated. Thank you!