In-place operations in Triton kernels might result in incorrect gradient calculations
🐛 Describe the bug
#254 #262 (comments)
PyTorch’s autograd system records operations on tensors to construct a computational graph, which is used for computing gradients. When an in-place operation is performed on a tensor, the autograd system needs to ensure that the computational graph reflects the modified values.
https://pytorch.org/docs/stable/autograd.html#in-place-correctness-checks
Each tensor in PyTorch has an internal version counter that is incremented every time an in-place operation is performed.
https://github.com/pytorch/pytorch/blob/190e09d8b6a13f789b143f0fbd1325f924550967/c10/core/TensorImpl.h#L382
Because we never explicitly call PyTorch's in-place operations, the version counter doesn't change when a Triton kernel writes to a tensor in place. As a result, PyTorch's "in-place correctness checks" mechanism cannot detect the modification, and no error is shown to the user.
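To make the version counter concrete, here is a minimal CPU-only sketch. It assumes the private attribute `Tensor._version` is available for inspection, and it uses a write through `.data` as a stand-in for a Triton kernel store (an assumption for illustration; `.data` does not share the tensor's version counter, so the write is invisible to autograd in the same way).

```python
import torch

x = torch.ones(4)
version_initial = x._version            # no in-place op has happened yet

x.add_(1)                               # regular PyTorch in-place op: tracked
version_after_torch_op = x._version

# Stand-in for a Triton kernel store: the values change,
# but the version counter of `x` is not bumped.
x.data.mul_(2)
version_after_untracked_write = x._version

print(version_initial, version_after_torch_op, version_after_untracked_write)
print(x)
```

The values of `x` change twice, but only the `add_` call is recorded by the version counter.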
Reproduce
import torch
import torch.nn.functional as F
from liger_kernel.transformers.functional import liger_cross_entropy
def run_inplace_experiment(logits_p, logits_q, cross_entropy_fn):
    _p = logits_p.clone().detach().requires_grad_(True)
    _p.retain_grad()
    softmax = torch.nn.Softmax(dim=-1)
    p = softmax(_p)
    p.retain_grad()
    loss = cross_entropy_fn(p, logits_q)
    loss.backward(retain_graph=True)
    print(f"Cross Entropy Loss: {loss.item()}")
    print(f"Input _p: {_p}")
    print(f"Input logits_q: {logits_q}")
    print(f"Gradients of p (batch item 0): {p.grad[0]}")
    print(f"Gradients of _p (batch item 0): {_p.grad[0]}")

torch.manual_seed(0)
logits_p = torch.randn(8, 8, requires_grad=True, device="cuda")
logits_q = torch.randint(0, 8, (8,), device="cuda", dtype=torch.long)
run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=F.cross_entropy)
print()
print("LIGER:")
run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=liger_cross_entropy)
❯ python3 inplace_bug.py
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438, 0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
[-1.0745, -0.3631, -1.6711, 2.2655, 0.3117, -0.1842, 1.2866, 1.1820],
[-0.1271, 1.2169, 1.4353, 1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
[ 0.0697, -0.0074, 1.8969, 0.6878, -0.0779, -0.8373, 1.3506, -0.2879],
[-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504, 0.5435, 1.5150],
[ 0.0141, 0.4532, 1.6349, 0.7124, -0.1806, 1.0252, -1.4622, -0.7554],
[-0.1836, 0.3824, 0.3918, -0.0830, 0.8971, -1.1123, 0.1116, 0.4863],
[-0.5499, -0.3231, -0.5469, 0.9049, 0.2837, 0.1210, 0.4730, -1.0823]],
device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149, 0.0157, 0.0140, 0.0174, -0.1086, 0.0154, 0.0153, 0.0159],
device='cuda:0')
Gradients of _p (batch item 0): tensor([ 0.0017, 0.0029, 0.0003, 0.0055, -0.0182, 0.0024, 0.0023, 0.0032],
device='cuda:0')
LIGER:
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438, 0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
[-1.0745, -0.3631, -1.6711, 2.2655, 0.3117, -0.1842, 1.2866, 1.1820],
[-0.1271, 1.2169, 1.4353, 1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
[ 0.0697, -0.0074, 1.8969, 0.6878, -0.0779, -0.8373, 1.3506, -0.2879],
[-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504, 0.5435, 1.5150],
[ 0.0141, 0.4532, 1.6349, 0.7124, -0.1806, 1.0252, -1.4622, -0.7554],
[-0.1836, 0.3824, 0.3918, -0.0830, 0.8971, -1.1123, 0.1116, 0.4863],
[-0.5499, -0.3231, -0.5469, 0.9049, 0.2837, 0.1210, 0.4730, -1.0823]],
device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149, 0.0157, 0.0140, 0.0174, -0.1086, 0.0154, 0.0153, 0.0159],
device='cuda:0')
Gradients of _p (batch item 0): tensor([2.1320e-05, 3.4830e-05, 6.8024e-06, 6.7467e-05, 1.3247e-02, 2.9687e-05,
2.8429e-05, 3.8656e-05], device='cuda:0')
Solution
One trivial solution is to perform a no-op in-place operation, such as .add_(0) or .mul_(1), to explicitly declare that we have changed the tensor values in place. PyTorch will then raise the expected error.
With this approach, I suggest adding an inplace=True/False parameter to the functions that perform in-place operations, so users can set it to False (using extra tensors) when they hit the error.
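A minimal sketch of the no-op idea, under the same assumption that a write through `.data` stands in for an untracked Triton store. The `torch.no_grad()` block mimics running inside an `autograd.Function` forward; the no-op `add_(0)` leaves the values alone but bumps the version counter, so the backward pass raises instead of silently using corrupted data.

```python
import torch

a = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
b = a.exp()                 # exp saves its output b for the backward pass
loss = b.sum()

b.data.fill_(0.5)           # stand-in for an untracked in-place kernel write
with torch.no_grad():
    b.add_(0)               # no-op on the values, but bumps b._version

try:
    loss.backward()
    raised = False
except RuntimeError as err:
    raised = True
    print(f"autograd detected the in-place change: {err}")
```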
Versions
Environment Report:
Operating System: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.10.12
PyTorch version: 2.4.1+cu121
CUDA version: 12.1
Triton version: 3.0.0
Transformers version: 4.45.0
should we adopt the second solution since the first one introduces quite a lot of overhead? also, can you elaborate under which case will this behavior happen?
@ByronHsu
also, can you elaborate under which case will this behavior happen?
Consider the following forward graph:
graph TD
A[input] -->|a| B[exp]
B -->|b| C[liger_ce]
C -->|loss| output
To calculate the gradient of the exp layer, which is exp(input), we can either:
1. save the input tensor a in the forward pass, then recompute exp(a) in the backward pass
2. save the output tensor b in the forward pass, so no further computation is needed in the backward pass (assume torch marks it as version 0)
Normally we take the option with the least computation/memory, which is option 2 in this case.
graph TD
A[input] -->|a| B["exp <br> saved tensors: b (v0)"]
B -->|b| C[liger_ce]
C -->|loss| output
After a complete forward pass from input a to loss, now we call loss.backward().
graph TD
A[input] <-->|dx * grad_ce = b' * grad_ce| B["exp <br> saved tensors: b' (v0)<br>(changed by liger_ce)"]
B <-->|grad_ce| C[liger_ce]
C <-->|loss| output
Notice that in the forward pass liger_ce stored its gradients in b, its own input tensor, so the tensor b saved by the exp layer has been changed as well. Since the saved tensor is corrupted, the exp layer can't produce the correct gradients.
Replacing exp with any layer that saves its output tensor, and liger_ce with any layer that performs in-place operations on its input, results in the same behavior.
tl;dr The saved tensors are corrupted by inplace operations.
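The corruption can be reproduced on CPU without Triton. This is only a sketch: the write through `.data` is an assumed stand-in for a kernel that stores gradients into its input buffer, and the fill value 0.5 is arbitrary.

```python
import torch

a = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
b = a.exp()                          # exp saves its output b (version 0)
correct_grad = b.detach().clone()    # d(sum(exp(a)))/da = exp(a)
loss = b.sum()

# Stand-in for a kernel overwriting its input b with gradients.
# The version counter of b stays at 0.
b.data.fill_(0.5)

loss.backward()                      # no error is raised
corrupted_grad = a.grad.clone()

print("correct  :", correct_grad)
print("corrupted:", corrupted_grad)
```

The backward pass of exp multiplies the incoming gradient by the saved output, so it silently returns the overwritten values instead of exp(a).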
Why no error?
No error is raised because the Triton kernel doesn't bump the version counter when it performs an in-place operation, so the saved tensor is still at v0 when gradients are computed in the backward pass.
If we perform the in-place operation outside the kernel by calling a torch function, the version is updated correctly.
graph TD
A[input] <-->|"dx * grad_output <br>= b' * grad_output"| B["exp <br> saved tensors: b' (v1)<br>(changed by inplace op)"]
B <-->|grad_output| C["torch's inplace op"]
C <-->|something| something
Thus, the error can be detected.
When designing a kernel, we can keep a separate pointer for the gradients, and add a boolean argument to the autograd.Function so users can decide whether gradients are stored in place or not.
If False, we can allocate new memory and pass it to the kernel. E.g. X_ptr and dX_ptr as below:
https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/ops/jsd.py#L64-L77
If True, we can just pass the existing tensor that the gradients should be stored into in place. E.g. X_ptr and dX_ptr as below:
https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/ops/fused_linear_jsd.py#L75-L88
The examples above show that we can design a kernel that looks "out-of-place" but can still store its results "in place".
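As a rough illustration of that design, the sketch below uses hypothetical names (`fake_kernel`, `run_op`) and a plain PyTorch function in place of a Triton kernel. The kernel always takes separate input and gradient buffers; the caller decides whether the two alias. Unlike a real Triton store, the `copy_` used here is tracked by autograd, so this only shows the buffer selection, not the version-counter problem.

```python
import torch

def fake_kernel(x, dx):
    """Hypothetical stand-in for a Triton kernel: reads x, writes gradients to dx."""
    # 2 * x is materialized before the copy, so it is safe for dx to alias x.
    dx.copy_(2 * x)                  # gradient of sum(x ** 2)

def run_op(x, inplace):
    # inplace=True reuses the input buffer, inplace=False allocates a new one.
    dx = x if inplace else torch.empty_like(x)
    fake_kernel(x, dx)
    return dx

x_keep = torch.tensor([1.0, 2.0, 3.0])
dx_out = run_op(x_keep, inplace=False)   # x_keep is left untouched

x_reuse = torch.tensor([1.0, 2.0, 3.0])
dx_in = run_op(x_reuse, inplace=True)    # x_reuse now holds the gradients
```

With `inplace=False` the extra allocation costs memory, but any tensor that autograd saved earlier stays intact.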
One trivial solution is performing a no-op like inplace operation, such as .add_(0) and .mul_(1), to explicitly declare we have changed the tensor values in-place, then the errors will be raised.
Since the trivial solution introduces quite a lot of overhead, we can do it only in the first pass, as an in-place correctness checker.
A possible implementation could be like this:
import torch
import triton

@triton.jit
def _kernel(
    x_ptr,  # input tensor
    y_ptr,  # output tensor
    dx_ptr,  # gradients of input
    ...
):
    ...  # do something

def forward(_input, inplace: bool, ...):
    ...  # do something
    if inplace:
        dx = _input  # store gradients directly in the input buffer
        if first_pass:  # I haven't come up with a good way to detect first pass or not
            _input.add_(0)  # no-op that bumps the version counter
    else:
        dx = torch.zeros_like(_input)  # separate buffer, input stays intact
    _kernel[(...)](
        x_ptr=_input,
        y_ptr=output,
        dx_ptr=dx,
        ...
    )
    return output
cc @ByronHsu @lancerts