Llama Model throwing "RuntimeError: expected scalar type BFloat16 but found Float" when using torch.compile and AMP together
System Info
- transformers 4.41.0
- torch 2.3.0
- GPU: NVIDIA GeForce RTX 4090, CUDA version 12.3
Who can help?
No response
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
import torch
from transformers import LlamaConfig, LlamaForCausalLM, AdamW, AutoModelForCausalLM, GPT2Config
from torch.cuda.amp import autocast, GradScaler
# Configure the model
config = LlamaConfig(
    num_attention_heads=6,
    num_hidden_layers=6,
    hidden_size=384,
    intermediate_size=1536,  # Typically 4 * hidden_size
    vocab_size=30522,  # Standard vocabulary size
    max_position_embeddings=1024,
)
# config = GPT2Config(
#     n_embd=384,
#     n_head=6,
#     n_layer=6,
#     n_positions=1024,
#     n_ctx=1024,
#     n_vocab=30522,
# )
# Initialize the model
model = AutoModelForCausalLM.from_config(config, attn_implementation="eager").to('cuda')
# Compile the model (Torch 2.0 and above)
model = torch.compile(model)
# Create dummy data
batch_size = 8
sequence_length = 1024
dummy_input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_length)).to('cuda')
dummy_labels = torch.randint(0, config.vocab_size, (batch_size, sequence_length)).to('cuda')
# Set up the optimizer
optimizer = AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()
# Set the model to training mode
model.train()
# Training loop
num_epochs = 10000
for epoch in range(num_epochs):
    with autocast(dtype=torch.bfloat16, enabled=True):
        # Forward pass
        outputs = model(input_ids=dummy_input_ids, labels=dummy_labels)
        loss = outputs.loss

    # Backward pass
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")
print("Training complete.")
Expected behavior
Running the code snippet above gives me the following error
{
"name": "RuntimeError",
"message": "expected scalar type BFloat16 but found Float",
"stack": "---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[2], line 54
51 loss = outputs.loss
53 # Backward pass
---> 54 scaler.scale(loss).backward()
55 scaler.step(optimizer)
56 scaler.update()
File ~/anaconda3/lib/python3.11/site-packages/torch/_tensor.py:525, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
515 if has_torch_function_unary(self):
516 return handle_torch_function(
517 Tensor.backward,
518 (self,),
(...)
523 inputs=inputs,
524 )
--> 525 torch.autograd.backward(
526 self, gradient, retain_graph, create_graph, inputs=inputs
527 )
File ~/anaconda3/lib/python3.11/site-packages/torch/autograd/__init__.py:267, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
262 retain_graph = create_graph
264 # The reason we repeat the same comment below is that
265 # some Python versions print out the first line of a multi-line function
266 # calls in the traceback and some print out the last line
--> 267 _engine_run_backward(
268 tensors,
269 grad_tensors_,
270 retain_graph,
271 create_graph,
272 inputs,
273 allow_unreachable=True,
274 accumulate_grad=True,
275 )
File ~/anaconda3/lib/python3.11/site-packages/torch/autograd/graph.py:744, in _engine_run_backward(t_outputs, *args, **kwargs)
742 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
743 try:
--> 744 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
745 t_outputs, *args, **kwargs
746 ) # Calls into the C++ engine to run the backward pass
747 finally:
748 if attach_logging_hooks:
File ~/anaconda3/lib/python3.11/site-packages/torch/autograd/function.py:301, in BackwardCFunction.apply(self, *args)
295 raise RuntimeError(
296 \"Implementing both 'backward' and 'vjp' for a custom \"
297 \"Function is not allowed. You should only implement one \"
298 \"of them.\"
299 )
300 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 301 return user_fn(self, *args)
File ~/anaconda3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:882, in aot_dispatch_autograd.<locals>.CompiledFunction.backward(ctx, *flat_args)
880 out = CompiledFunctionBackward.apply(*all_args)
881 else:
--> 882 out = call_compiled_backward()
884 # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here.
885 if CompiledFunction.maybe_subclass_metadata is not None:
File ~/anaconda3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:831, in aot_dispatch_autograd.<locals>.CompiledFunction.backward.<locals>.call_compiled_backward()
824 with tracing(saved_context), context(), track_graph_compiling(
825 aot_config, \"backward\"
826 ):
827 CompiledFunction.compiled_bw = aot_config.bw_compiler(
828 bw_module, placeholder_list
829 )
--> 831 out = call_func_at_runtime_with_args(
832 CompiledFunction.compiled_bw,
833 all_args,
834 steal_args=True,
835 disable_amp=disable_amp,
836 )
838 out = functionalized_rng_runtime_epilogue(
839 CompiledFunction.metadata, out
840 )
841 return tuple(out)
File ~/anaconda3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py:113, in call_func_at_runtime_with_args(f, args, steal_args, disable_amp)
111 with context():
112 if hasattr(f, \"_boxed_call\"):
--> 113 out = normalize_as_list(f(args))
114 else:
115 # TODO: Please remove soon
116 # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
117 warnings.warn(
118 \"Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. \"
119 \"Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. \"
120 \"See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\"
121 )
File ~/anaconda3/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:451, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
449 prior = set_eval_frame(callback)
450 try:
--> 451 return fn(*args, **kwargs)
452 finally:
453 set_eval_frame(prior)
File ~/anaconda3/lib/python3.11/site-packages/torch/_dynamo/external_utils.py:36, in wrap_inline.<locals>.inner(*args, **kwargs)
34 @functools.wraps(fn)
35 def inner(*args, **kwargs):
---> 36 return fn(*args, **kwargs)
File ~/anaconda3/lib/python3.11/site-packages/torch/_inductor/codecache.py:906, in CompiledFxGraph.__call__(self, inputs)
905 def __call__(self, inputs: List[Any]) -> Any:
--> 906 return self.get_current_callable()(inputs)
File ~/anaconda3/lib/python3.11/site-packages/torch/_inductor/compile_fx.py:784, in align_inputs_from_check_idxs.<locals>.run(new_inputs)
782 def run(new_inputs):
783 copy_misaligned_inputs(new_inputs, inputs_to_check)
--> 784 return model(new_inputs)
File ~/anaconda3/lib/python3.11/site-packages/torch/_inductor/codecache.py:934, in _run_from_cache(compiled_graph, inputs)
926 assert compiled_graph.artifact_path
927 compiled_graph.compiled_artifact = PyCodeCache.load_by_key_path(
928 compiled_graph.cache_key,
929 compiled_graph.artifact_path,
930 compiled_graph.cache_linemap,
931 compiled_graph.constants,
932 ).call
--> 934 return compiled_graph.compiled_artifact(inputs)
File /tmp/torchinductor_zcai75/wq/cwqm67koqia7gthn65wgmhppfzrfyheocl4px7fecurpkfigigfs.py:1751, in call(args)
1749 buf39 = reinterpret_tensor(buf34, (48, 64, 1024), (65536, 1024, 1), 0); del buf34 # reuse
1750 # Source Nodes: [], Original ATen: [aten.bmm]
-> 1751 extern_kernels.bmm(permute_103, reinterpret_tensor(buf38, (48, 1024, 1024), (1048576, 1024, 1), 0), out=buf39)
1752 del permute_103
1753 buf41 = empty_strided_cuda((8, 6, 1024, 64), (393216, 65536, 64, 1), torch.bfloat16)
RuntimeError: expected scalar type BFloat16 but found Float"
}
This problem does not seem to happen for a GPT2 model: if I initialize GPT2Config instead of LlamaConfig (the commented-out code in the script), there is no such error.
cc @ArthurZucker
Hi @JackCai1206, I ran your script but didn't encounter the error you mentioned for LlamaConfig; it ran smoothly for both configs. Can you check your PyTorch/CUDA compatibility? I am on CUDA 12.2 with PyTorch 2.3 (PyTorch version (GPU?): 2.3.0+cu121 (True); Cuda compilation tools, release 12.2, V12.2.140).
When I run nvidia-smi I get | NVIDIA-SMI 545.23.08   Driver Version: 545.23.08   CUDA Version: 12.3 |
and I have installed torch 2.3.0 without a "cu" suffix, which I assume is compatible with CUDA 12?
@JackCai1206 There are two main CUDA APIs: the runtime API and the driver API. The CUDA version you posted from nvidia-smi is the driver API version, whereas PyTorch reports the runtime API version, which comes from the CUDA toolkit that is installed automatically with pip3 install torch.
Just for confirmation, can you check the output of pip list | grep torch and torch.version.cuda? If they show no CUDA dependencies and None respectively, then PyTorch has to be reinstalled with CUDA dependencies.
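For reference, a minimal way to run both checks from Python (purely illustrative; nothing here is specific to this setup):

import torch

# Prints None if the installed wheel was built without CUDA support
print(torch.version.cuda)
# Prints False if PyTorch cannot actually use a GPU at runtime
print(torch.cuda.is_available())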
Hi, thanks for the explanation! This is the output of pip list:
torch 2.3.0
torchaudio 2.3.0
torchvision 0.18.0
and the torch CUDA version:
>>> import torch
>>> torch.version.cuda
'12.1'
also cc @gante
@JackCai1206 Oh! I see. What I found could be the reason for the error is this line in modeling_llama, as your model has (rotary_emb): LlamaRotaryEmbedding(): it forces float32 there, because bfloat16 loses precision on long contexts.
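Roughly, the pattern being referred to looks like this (a simplified sketch for illustration only, not the exact code in modeling_llama; the class name and tensor shapes here are assumptions):

import torch

class RotaryEmbeddingSketch(torch.nn.Module):
    """Simplified illustration of the pattern: the rotary frequencies are
    computed in float32 with autocast disabled, regardless of the autocast
    dtype used for the rest of the forward pass."""

    def __init__(self, dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # Disable autocast and force float32 for the frequency computation
        with torch.autocast(device_type=x.device.type, enabled=False):
            freqs = position_ids[:, :, None].float() * self.inv_freq[None, None, :]
            emb = torch.cat((freqs, freqs), dim=-1)
            cos, sin = emb.cos(), emb.sin()
        # Cast back to the activation dtype before returning
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

Under torch.compile, a float32 tensor produced in such an autocast-disabled region appears to be what ends up mismatching the bfloat16 tensors in the compiled backward graph.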
If you want to use autocast, an alternative to try would be to use the Trainer class of transformers and enable autocast through the bf16=True argument in TrainingArguments.
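For illustration, a minimal sketch of that suggestion, assuming a train_dataset is available (the dataset name and the TrainingArguments values below are placeholders, not part of the original report):

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="llama-bf16-test",   # placeholder output directory
    per_device_train_batch_size=8,
    learning_rate=1e-4,
    num_train_epochs=1,
    bf16=True,                      # enables bfloat16 autocast inside Trainer
    torch_compile=True,             # lets Trainer apply torch.compile itself
    logging_steps=10,
)

trainer = Trainer(
    model=model,                    # the LlamaForCausalLM created above
    args=training_args,
    train_dataset=train_dataset,    # placeholder: your own dataset
)
trainer.train()

With torch_compile=True you would typically pass the uncompiled model and let Trainer handle compilation itself.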
Sounds good. Yeah, I think a warning message there could be useful.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.