
RuntimeError for contrastive loss if audio duration is not long enough for num_negatives


Describe the bug

If the duration of an audio file is under ~8 s (the exact threshold needs further testing), then sample_negatives will throw the following error:

File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/models/ssl_models.py", line 491, in validation_step                                                                                                      
  loss_value, _ = self.decoder_loss_step(spectrograms, spec_masks, encoded, encoded_len, targets, target_lengths)                                                                                                          
File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/models/ssl_models.py", line 450, in decoder_loss_step                                                                                                    
  current_loss_value = current_loss(                                                                                                                                                                                       
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1129, in _call_impl                                                                                                                         
  return forward_call(*input, **kwargs)                                                                                                                                                                                    
File "/opt/conda/lib/python3.8/site-packages/nemo/core/classes/common.py", line 963, in __call__                                                                                                                           
  outputs = wrapped(*args, **kwargs)                                                                                                                                                                                       
File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/losses/ssl_losses/contrastive.py", line 205, in forward                                                                                                  
  negatives, _ = self.sample_negatives(targets_masked_only, targets_masked_only.size(0))  # T'xBxC  # T'                                                                                                                   
File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/losses/ssl_losses/contrastive.py", line 139, in sample_negatives                                                                                         
  neg_idxs = torch.multinomial(torch.ones((num, high), device=y.device), self.num_negatives)                                                                                                                               
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement  

If you haven't read the docs or don't understand the code/loss function, it might not be clear that you need to adjust min_duration or num_negatives.
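For reference, the underlying constraint is just torch.multinomial refusing to draw more distinct samples (without replacement) than the size of the distribution, i.e. more negatives than there are masked time steps. A minimal illustration with made-up sizes:

```python
import torch

# Made-up sizes: `high` stands in for the number of masked time steps of a
# short utterance, which ends up smaller than num_negatives.
num, high, num_negatives = 4, 50, 100

probs = torch.ones((num, high))
# Same call shape as in ContrastiveLoss.sample_negatives; this raises
# "RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement"
neg_idxs = torch.multinomial(probs, num_negatives)
```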

Steps/Code to reproduce bug

Run any of the examples in speech_pretraining with train/validation_ds.min_duration < 8.0

Expected behavior

While this is expected to fail, I would like to propose that we catch this error and replace it with something more constructive. I can open a PR, but wanted to see whether the maintainers are open to the suggestion.

My suggestion is to re-raise as a NeMoBaseException (or a new exception, e.g. LossMisconfigurationError?) with a message suggesting that the user adjust either num_negatives or min_duration.
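A minimal sketch of the proposed check (the exception name, wrapper, and message below are placeholders, not existing NeMo API; the real guard would live inside ContrastiveLoss.sample_negatives):

```python
import torch


class LossMisconfigurationError(Exception):
    """Placeholder name; it could instead be, or subclass, NeMoBaseException."""


def sample_negatives_or_hint(probs: torch.Tensor, num_negatives: int) -> torch.Tensor:
    """Wrap the failing multinomial call and re-raise with an actionable message.

    `probs` plays the role of torch.ones((num, high), device=y.device) in
    ContrastiveLoss.sample_negatives, where `high` is the number of masked time steps.
    """
    try:
        return torch.multinomial(probs, num_negatives)
    except RuntimeError as err:
        raise LossMisconfigurationError(
            f"Could not sample {num_negatives} negatives from {probs.size(-1)} masked "
            "time steps. Increase train_ds/validation_ds.min_duration or decrease "
            "num_negatives in the contrastive loss config."
        ) from err
```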

Edit: I realized I opened this as a bug, but it should've been a feature suggestion.

piraka9011 · Jul 22 '22 18:07

@piraka9011 I think this would be a useful thing to catch; please make a PR for your suggestion.

sam1373 · Jul 25 '22 17:07

This issue is stale because it has been open for 60 days with no activity.

github-actions[bot] · Sep 28 '22 02:09

This issue was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot] · Oct 06 '22 02:10

Hi, I'm facing the same issue with the default conformer config (with min_duration=1.6 and max_duration=30.0).

  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 200, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 239, in _run_optimization
    closure()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 147, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 133, in closure
    step_output = self._step_fn()
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 406, in _training_step
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1480, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/ddp.py", line 352, in training_step
    return self.model(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1040, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/overrides/base.py", line 98, in forward
    output = self._forward_module.training_step(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/nemo/utils/model_utils.py", line 362, in wrap_training_step
    output_dict = wrapped(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/nemo/collections/asr/models/ssl_models.py", line 495, in training_step
    loss_value, loss_val_dict = self.decoder_loss_step(
  File "/usr/local/lib/python3.10/dist-packages/nemo/collections/asr/models/ssl_models.py", line 462, in decoder_loss_step
    current_loss_value = current_loss(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/nemo/core/classes/common.py", line 1087, in __call__
    outputs = wrapped(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/nemo/collections/asr/losses/ssl_losses/contrastive.py", line 205, in forward
    negatives, _ = self.sample_negatives(targets_masked_only, targets_masked_only.size(0))  # T'xBxC  # T'
  File "/usr/local/lib/python3.10/dist-packages/nemo/collections/asr/losses/ssl_losses/contrastive.py", line 139, in sample_negatives
    neg_idxs = torch.multinomial(torch.ones((num, high), device=y.device), self.num_negatives)
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

Any leads would be appreciated. Thank you.

anjul1008 · Mar 17 '23 07:03
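(For anyone hitting the same error: the two knobs suggested earlier in the thread can be adjusted in the training config. The snippet below is only a sketch; the exact YAML paths, in particular the one for num_negatives, are assumptions and may differ between NeMo versions and example configs.)

```python
from omegaconf import OmegaConf

# Load a local copy of an SSL example config (placeholder file name).
cfg = OmegaConf.load("conformer_ssl.yaml")

# Option 1: drop utterances too short to provide num_negatives masked time steps.
cfg.model.train_ds.min_duration = 8.0
cfg.model.validation_ds.min_duration = 8.0

# Option 2: sample fewer negatives so shorter utterances still fit
# (assumed path; check where ContrastiveLoss is configured in your loss_list).
cfg.model.loss_list.contrastive.loss.num_negatives = 40

OmegaConf.save(cfg, "conformer_ssl_patched.yaml")
```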