[BUG] `zero_quantized_nontrainable_weights=True` with PEFT + DeepSpeed and BF16 mixed-precision training leads to `float != c10::BFloat16` error
Describe the bug
Setting `zero_quantized_nontrainable_weights=True` when using PEFT + DeepSpeed with BF16 mixed-precision training leads to a `float != c10::BFloat16` error.
To Reproduce
Steps to reproduce the behavior:
- DeepSpeed Config: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/ds_config_z3_lora.json
- Accelerate Config: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/deepspeed_zeropp_lora_config.yaml
- Code: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/train.py
- Launch Command: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/run_peft_deepspeed_zeropp.sh
- Infrastructure: 8 80GB GPUs.
- Output logs with error:
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/fsx/sourab/transformers/src/transformers/models/mistral/modeling_mistral.py", line 356, in forward
    query_states = self.q_proj(hidden_states)
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/fsx/sourab/peft/src/peft/tuners/lora/layer.py", line 309, in forward
    result = self.base_layer(x, *args, **kwargs)
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py", line 109, in zero3_linear_wrap
    return LinearFunctionForZeroStage3.apply(input, weight)
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py", line 57, in forward
    output = input.matmul(weight.t())
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
Expected behavior
When using PEFT LoRA with DeepSpeed together with the `zero_quantized_nontrainable_weights` feature, the non-trainable weights should be quantized, yielding substantial memory savings. This would enable fine-tuning even larger models or using larger batch sizes.
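For reference, the relevant fragment of a ZeRO stage-3 config that enables this feature together with BF16 looks roughly as follows (a minimal sketch with illustrative values, not the full ds_config_z3_lora.json linked above; only the settings that interact here are shown):

{
  "bf16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 3,
    "zero_quantized_nontrainable_weights": true
  }
}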
ds_report output
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
[WARNING] using untested triton version (2.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch']
torch version .................... 2.1.2+cu121
deepspeed install path ........... ['/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.12.5, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.1
deepspeed wheel compiled w. ...... torch 2.1, cuda 12.1
shared memory (/dev/shm) size .... 999.99 GB
System info (please complete the following information):
- OS: Ubuntu 20.04.6 LTS
- GPU count and types: one machine with 8x H100s
- Python version: 3.10.13
Launcher context: Accelerate launcher, which internally uses the DeepSpeed launcher.
Would love to see this fixed for training MoEs on DeepSpeed with quantization + BF16.
Same issue here, training with BF16 + PEFT and ZeRO++ (stage 3):
Stack trace:
Traceback (most recent call last):
File "/nethome/yheng6/Co-Adaptation/src/co_adapt/reward_ensemble/stablelm_zephyr_3b/train.py", line 455, in <module>
main()
File "/nethome/yheng6/Co-Adaptation/src/co_adapt/reward_ensemble/stablelm_zephyr_3b/train.py", line 400, in main
train_result = trainer.train(resume_from_checkpoint=checkpoint)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
return inner_training_loop(
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 3138, in training_step
loss = self.compute_loss(model, inputs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1081, in compute_loss
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1022, in get_batch_loss_metrics
) = self.concatenated_forward(model, batch)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 985, in concatenated_forward
all_logits = model(
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1814, in forward
loss = self.module(*inputs, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/peft/peft_model.py", line 1129, in forward
return self.base_model(
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
return self.model.forward(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stablelm/modeling_stablelm.py", line 1158, in forward
outputs = self.model(
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stablelm/modeling_stablelm.py", line 1026, in forward
layer_outputs = self._gradient_checkpointing_func(
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 458, in checkpoint
ret = function(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stablelm/modeling_stablelm.py", line 759, in forward
self_attn_output, self_attn_weights, present_key_value = self.self_attn(
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stablelm/modeling_stablelm.py", line 535, in forward
query_states = self.q_proj(hidden_states)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 509, in forward
result = result + lora_B(lora_A(dropout(x))) * scaling
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py", line 109, in zero3_linear_wrap
return LinearFunctionForZeroStage3.apply(input, weight)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
return fwd(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py", line 57, in forward
output = input.matmul(weight.t())
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != c10::Half
ZeRO config:
{
  "fp16": {
    "enabled": false
  },
  "bf16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "none",
      "nvme_path": "None"
    },
    "offload_param": {
      "device": "none",
      "nvme_path": "None"
    },
    "stage3_gather_16bit_weights_on_model_save": true,
    "reduce_bucket_size": "auto",
    "zero_quantized_weights": true,
    "zero_hpz_partition_size": 2,
    "zero_quantized_gradients": true,
    "contiguous_gradients": true,
    "overlap_comm": true
  },
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": "inf"
}
Accelerate config:
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  # deepspeed_multinode_launcher: standard
  # offload_optimizer_device: none
  # offload_param_device: none
  zero3_init_flag: true
  # zero3_save_16bit_model: true
  # zero_stage: 3
  deepspeed_config_file: ./zero_configs/zero3++.json
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
#mixed_precision: bf16
num_machines: 1
#num_processes: 8
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Running the Trainer on 2x A5000 GPUs.
Hi, currently the ZeRO++ feature does not support BF16 quantization; I suppose that is the root cause of this issue.
To fix it, you can:
- either use FP16 as the dtype,
- or set `"zero_quantized_weights": false` and `"zero_quantized_gradients": false` (see the sketch below).
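For example, the second option would change the ZeRO config posted above roughly as follows (a sketch; all other settings stay unchanged):

{
  "zero_optimization": {
    "stage": 3,
    "zero_quantized_weights": false,
    "zero_quantized_gradients": false
  }
}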
@GuanhuaWang But because of other training stability issues, like this one related to initializing Llama in FP16, training Llama with ZeRO++ becomes quite troublesome. Should we maybe reopen this issue and look into supporting BF16 in ZeRO++?