sdpa_ex - Incorrect device in trace vs. actual computation
The sdpa_ex implementation of torch.nn.functional.scaled_dot_product_attention reports every output tensor proxy in the trace as being on CUDA, but at runtime some of the outputs are on the CPU.
Repro
import torch
import thunder

# Force the memory-efficient SDPA backend.
torch.backends.cuda.enable_cudnn_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(True)

def fn(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, None, 0.0, is_causal=True, scale=0.08838834764831843
    )

q = torch.randn(8, 32, 1024, 128, device='cuda', requires_grad=True)
k = torch.randn(8, 32, 1024, 128, device='cuda', requires_grad=True)
t = torch.randn(8, 1024, 12288, device='cuda', requires_grad=False)
v = torch.as_strided(t, (8, 32, 1024, 128), (12582912, 384, 12288, 1))

jfn = thunder.jit(fn)
o = jfn(q, k, v)

# Devices of the saved-for-backward tensors as recorded in the execution trace.
extrace = thunder.last_traces(jfn)[-1]
print(extrace)
saved_tensors_trace = extrace.bound_symbols[4].args[1][0]
print("TRACE OUTPUT DEVICES", list(t.device.device_str() for t in saved_tensors_trace))

# Devices of the tensors actually saved at runtime.
cache_entry, inps, pro_to_epi = jfn._lc_cd._get_computation_and_inputs(q, k, v)
output = cache_entry.computation_fn(*inps)
data_for_autograd, (saved_tensors, saved_other) = output
print("ACTUAL SAVED DEVICES", list(str(o.device) for o in saved_tensors))
Output
TRACE OUTPUT DEVICES ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0']
ACTUAL SAVED DEVICES ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cpu', 'cpu', 'cuda:0']
Trace
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(q, k, v):
  # q: "cuda:0 f32[8, 32, 1024, 128]"
  # k: "cuda:0 f32[8, 32, 1024, 128]"
  # v: "cuda:0 f32[8, 32, 1024, 128]"
  (t0, t1, t2, t3) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(q, k, v, None, 0.0, True, 0.08838834764831843)
  return {'output': t0, 'flat_args': [q, k, v], 'flat_output': (t0,)}, ((k, q, t0, t1, t2, t3, v), ())
cc @carmocca
@rdspring1 are you free to take a look at this?
Per triage meeting, assigned to @IvanYashchuk to find the right owner.
@kshitij12345, how did you discover this problem? Is this behavior blocking anything?
The device of the seed and offset tensors depends on whether PyTorch is in CUDA stream-capturing mode (stream capture is used for CUDA Graphs). Under stream capture these tensors are placed on CUDA (as Thunder's meta function assumes), while in the usual mode they are on the CPU, which causes the discrepancy shown in the code example above. Link to the PyTorch code that decides which device to use: https://github.com/pytorch/pytorch/blame/32f3af72b7760f883d9cc1a09b0599da3652d80c/aten/src/ATen/native/transformers/cuda/attention.cu#L1077-L1088
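For illustration, calling the underlying ATen op directly outside of graph capture shows the seed/offset on the CPU (a sketch, assuming a recent PyTorch where _scaled_dot_product_efficient_attention returns the philox seed/offset as tensors):

import torch

q = torch.randn(8, 32, 1024, 128, device='cuda')
k = torch.randn(8, 32, 1024, 128, device='cuda')
v = torch.randn(8, 32, 1024, 128, device='cuda')

# attn_bias=None, compute_log_sumexp=True, dropout_p=0.0, is_causal=True
out, logsumexp, philox_seed, philox_offset = torch.ops.aten._scaled_dot_product_efficient_attention(
    q, k, v, None, True, 0.0, True, scale=0.08838834764831843
)
# Outside CUDA Graph capture the seed/offset tensors are allocated on the CPU,
# while Thunder's meta function reports CUDA.
print(philox_seed.device, philox_offset.device)  # expected: cpu cpu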
Thunder could detect whether the current CUDA stream is in capturing mode and record this in Thunder's cache or compile data so that meta functions can query it. @t-vi, what do you think about this?
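A minimal sketch of such a check, assuming the result would be cached in compile data (the function name and where the value is stored are hypothetical):

import torch

def philox_seed_offset_device() -> str:
    # torch.cuda.is_current_stream_capturing() reports whether the current
    # CUDA stream is being captured for a CUDA Graph.
    if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
        return "cuda"
    return "cpu"

# A meta function could query a cached value of this (e.g. from compile data)
# instead of unconditionally assuming CUDA for the philox seed/offset tensors.
print(philox_seed_offset_device())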
how did you discover this problem? Is this behavior blocking anything?
This was found while I was working on a CPU offloading tutorial, which expected these tensors to be on the GPU as reported by Thunder's meta function. It is not blocking anything.
Thanks for digging into this.