
When training MambaLMHeadModel with the Hugging Face Trainer, getting an unexpected `attention_mask`

Open srijiths opened this issue 1 year ago • 0 comments

  • `attention_mask` is popped from the tokenizer output.
  • With the Hugging Face Trainer I was hitting the same issue, so I created a custom MambaTrainer like this one: https://github.com/redotvideo/mamba-chat/blob/main/trainer/mamba_trainer.py (roughly along the lines of the sketch below). That resolved the `attention_mask` issue for the training loop; the first training iterations ran fine.
  • But it occurs again at the evaluation stage.
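
For reference, my custom trainer is roughly along these lines (a sketch reconstructed from the linked mamba-chat trainer, not a verbatim copy; the `**kwargs` is only there to absorb newer Trainer keyword arguments such as `num_items_in_batch`):

```python
import torch
from transformers import Trainer


class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # MambaLMHeadModel.forward() only accepts input_ids (plus position_ids /
        # inference_params), so pass just input_ids and ignore anything else the
        # collator put into the batch (attention_mask, token_type_ids, ...).
        input_ids = inputs["input_ids"]
        lm_logits = model(input_ids).logits

        # Standard causal-LM shift: token t predicts token t+1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous().to(shift_logits.device)

        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        return (loss, lm_logits) if return_outputs else loss
```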

Any pointers on how to resolve this? Thank you!
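
In case it helps narrow things down: judging from the traceback below, the evaluation path never goes through `compute_loss` (there is no `labels` key left in the eval inputs, so `prediction_step` calls `model(**inputs)` directly), and that is where `attention_mask` reaches `MambaLMHeadModel.forward()`. One thing I am considering (untested) is stripping the extra keys in a `prediction_step` override as well:

```python
from transformers import Trainer


class MambaTrainer(Trainer):
    # compute_loss override as above ...

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        # Same problem as in training, but on the eval path: keep only input_ids
        # before the base Trainer calls model(**inputs).
        inputs = {"input_ids": inputs["input_ids"]}
        return super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys=ignore_keys
        )
```

(This should avoid the TypeError, but since no labels are passed, the eval loss would still come back as None unless the loss is also computed inside that override.)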

The following columns in the evaluation set don't have a corresponding argument in `MambaLMHeadModel.forward` and have been ignored: labels, __index_level_0__, conversion, token_type_ids, user_journey. If labels, __index_level_0__, conversion, token_type_ids, user_journey are not expected by `MambaLMHeadModel.forward`,  you can safely ignore this message.

***** Running Evaluation *****
  Num examples = 5645
  Batch size = 32
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-95-3435b262f1ae> in <cell line: 1>()
----> 1 trainer.train()

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2162                 hf_hub_utils.enable_progress_bars()
   2163         else:
-> 2164             return inner_training_loop(
   2165                 args=args,
   2166                 resume_from_checkpoint=resume_from_checkpoint,

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2587                         self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
   2588                         self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2589                         self._maybe_log_save_evaluate(
   2590                             tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
   2591                         )

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
   3045         metrics = None
   3046         if self.control.should_evaluate:
-> 3047             metrics = self._evaluate(trial, ignore_keys_for_eval)
   3048             is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
   3049 

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
   2999 
   3000     def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 3001         metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   3002         self._report_to_hp_search(trial, self.state.global_step, metrics)
   3003 

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   4049 
   4050         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 4051         output = eval_loop(
   4052             eval_dataloader,
   4053             description="Evaluation",

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   4243 
   4244             # Prediction step
-> 4245             losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
   4246             main_input_name = getattr(self.model, "main_input_name", "input_ids")
   4247             inputs_decode = (

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
   4469                     loss = None
   4470                     with self.compute_loss_context_manager():
-> 4471                         outputs = model(**inputs)
   4472                     if isinstance(outputs, dict):
   4473                         logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
    191                 return self.module(*inputs[0], **module_kwargs[0])
    192             replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
--> 193             outputs = self.parallel_apply(replicas, inputs, module_kwargs)
    194             return self.gather(outputs, self.output_device)
    195 

/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
    210         self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any
    211     ) -> List[Any]:
--> 212         return parallel_apply(
    213             replicas, inputs, kwargs, self.device_ids[: len(replicas)]
    214         )

/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
    124         output = results[i]
    125         if isinstance(output, ExceptionWrapper):
--> 126             output.reraise()
    127         outputs.append(output)
    128     return outputs

/usr/local/lib/python3.10/dist-packages/torch/_utils.py in reraise(self)
    713             # instantiate since we don't know how to
    714             raise RuntimeError(msg) from None
--> 715         raise exception
    716 
    717 

TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 279, in forward
    hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 194, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/block.py", line 67, in forward
    hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: Mamba.forward() got an unexpected keyword argument 'attention_mask'

srijiths · Feb 15 '25 22:02