mamba
When training MambaLMHeadModel with the Hugging Face Trainer, getting unexpected `attention_mask`
- `attention_mask` is popped from the tokenizer output.
- With the Hugging Face Trainer I was getting the same issue, so I created a custom MambaTrainer like the one here: https://github.com/redotvideo/mamba-chat/blob/main/trainer/mamba_trainer.py (see the sketch below). That resolved the `attention_mask` issue during the training steps.
- But it occurs again at the evaluation stage.
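For context, my custom trainer is roughly along these lines (simplified from the mamba-chat file linked above; the exact code there may differ slightly). It only pulls `input_ids` out of the batch, so the tokenizer's `attention_mask` never reaches `MambaLMHeadModel.forward` during training steps:

```python
import torch
from transformers import Trainer


class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Only input_ids are forwarded to the model, so the attention_mask
        # produced by the tokenizer never reaches MambaLMHeadModel.forward
        # during training steps.
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        # Standard causal LM loss: predict token t+1 from tokens <= t.
        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
        )
        return (lm_loss, lm_logits) if return_outputs else lm_loss
```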
Any pointers on how to resolve this? Thank you!
The following columns in the evaluation set don't have a corresponding argument in `MambaLMHeadModel.forward` and have been ignored: labels, __index_level_0__, conversion, token_type_ids, user_journey. If labels, __index_level_0__, conversion, token_type_ids, user_journey are not expected by `MambaLMHeadModel.forward`, you can safely ignore this message.
***** Running Evaluation *****
Num examples = 5645
Batch size = 32
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-95-3435b262f1ae> in <cell line: 1>()
----> 1 trainer.train()
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2162 hf_hub_utils.enable_progress_bars()
2163 else:
-> 2164 return inner_training_loop(
2165 args=args,
2166 resume_from_checkpoint=resume_from_checkpoint,
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2587 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
2588 self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2589 self._maybe_log_save_evaluate(
2590 tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
2591 )
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
3045 metrics = None
3046 if self.control.should_evaluate:
-> 3047 metrics = self._evaluate(trial, ignore_keys_for_eval)
3048 is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
3049
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
2999
3000 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 3001 metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
3002 self._report_to_hp_search(trial, self.state.global_step, metrics)
3003
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
4049
4050 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 4051 output = eval_loop(
4052 eval_dataloader,
4053 description="Evaluation",
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
4243
4244 # Prediction step
-> 4245 losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
4246 main_input_name = getattr(self.model, "main_input_name", "input_ids")
4247 inputs_decode = (
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
4469 loss = None
4470 with self.compute_loss_context_manager():
-> 4471 outputs = model(**inputs)
4472 if isinstance(outputs, dict):
4473 logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
1737
1738 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None
/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
191 return self.module(*inputs[0], **module_kwargs[0])
192 replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
--> 193 outputs = self.parallel_apply(replicas, inputs, module_kwargs)
194 return self.gather(outputs, self.output_device)
195
/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
210 self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any
211 ) -> List[Any]:
--> 212 return parallel_apply(
213 replicas, inputs, kwargs, self.device_ids[: len(replicas)]
214 )
/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
124 output = results[i]
125 if isinstance(output, ExceptionWrapper):
--> 126 output.reraise()
127 outputs.append(output)
128 return outputs
/usr/local/lib/python3.10/dist-packages/torch/_utils.py in reraise(self)
713 # instantiate since we don't know how to
714 raise RuntimeError(msg) from None
--> 715 raise exception
716
717
TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
output = module(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 279, in forward
hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 194, in forward
hidden_states, residual = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/block.py", line 67, in forward
hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
TypeError: Mamba.forward() got an unexpected keyword argument 'attention_mask'
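From the traceback, evaluation goes through `Trainer.prediction_step`, which calls `model(**inputs)` directly, so the `compute_loss` override alone doesn't help there. One workaround I'm considering (not sure it's the intended fix; the collator name below is just illustrative) is to drop `attention_mask` from the batch before it ever reaches the model, e.g. with a custom data collator:

```python
from transformers import default_data_collator


def mamba_data_collator(features):
    # Build the usual tensor batch, then drop keys that
    # MambaLMHeadModel.forward does not accept, so neither the training
    # loop nor prediction_step ever passes them to the model.
    batch = default_data_collator(features)
    for key in ("attention_mask", "token_type_ids"):
        batch.pop(key, None)
    return batch


# trainer = MambaTrainer(..., data_collator=mamba_data_collator)
```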