
sdpa_ex - Incorrect device in trace vs. actual computation

Open · kshitij12345 opened this issue 1 year ago · 4 comments

The sdpa_ex implementation of torch.nn.functional.scaled_dot_product_attention reports all output tensor proxies in the trace as being on cuda, but at runtime some of the outputs are on cpu.

Repro

import torch
import thunder

# Disable the cuDNN SDPA backend and enable the memory-efficient backend.
torch.backends.cuda.enable_cudnn_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(True)

def fn(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, is_causal=True, scale=0.08838834764831843)

q = torch.randn(8, 32, 1024, 128, device='cuda', requires_grad=True)
k = torch.randn(8, 32, 1024, 128, device='cuda', requires_grad=True)
t = torch.randn(8, 1024, 12288, device='cuda', requires_grad=False)
# v is a non-contiguous strided view of t.
v = torch.as_strided(t, (8, 32, 1024, 128), (12582912, 384, 12288, 1))

jfn = thunder.jit(fn)
o = jfn(q, k, v)

extrace = thunder.last_traces(jfn)[-1]
print(extrace)

# Devices that the trace claims for the saved tensors.
saved_tensors_trace = extrace.bound_symbols[4].args[1][0]
print("TRACE OUTPUT DEVICES", list(p.device.device_str() for p in saved_tensors_trace))

# Run the computation directly and inspect the actual devices of the saved tensors.
cache_entry, inps, pro_to_epi = jfn._lc_cd._get_computation_and_inputs(q, k, v)
output = cache_entry.computation_fn(*inps)
data_for_autograd, (saved_tensors, saved_other) = output
print("ACTUAL SAVED DEVICES", list(str(o.device) for o in saved_tensors))

Output

TRACE OUTPUT DEVICES ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0']
ACTUAL SAVED DEVICES ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cpu', 'cpu', 'cuda:0']

Trace

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(q, k, v):
  # q: "cuda:0 f32[8, 32, 1024, 128]"
  # k: "cuda:0 f32[8, 32, 1024, 128]"
  # v: "cuda:0 f32[8, 32, 1024, 128]"
  (t0, t1, t2, t3) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(q, k, v, None, 0.0, True, 0.08838834764831843)
  return {'output': t0, 'flat_args': [q, k, v], 'flat_output': (t0,)}, ((k, q, t0, t1, t2, t3, v), ())

cc @carmocca

kshitij12345 avatar Aug 09 '24 22:08 kshitij12345

@rdspring1 are you free to take a look at this?

mruberry avatar Aug 19 '24 15:08 mruberry

Per the triage meeting, assigned to @IvanYashchuk to find the right owner.

nvMelissa avatar Aug 26 '24 18:08 nvMelissa

@kshitij12345, how did you discover this problem? Is this behavior blocking anything?

The device of the seed and offset tensors (t2 and t3 in the trace above) depends on whether PyTorch is in CUDA stream-capturing mode; stream capture is used for CUDA Graphs. Under stream capture these tensors are placed on the CUDA device, which is what Thunder's meta function assumes; in the usual mode they are placed on the CPU, which produces the discrepancy shown in the code example above. Link to the PyTorch code that decides which device to use: https://github.com/pytorch/pytorch/blame/32f3af72b7760f883d9cc1a09b0599da3652d80c/aten/src/ATen/native/transformers/cuda/attention.cu#L1077-L1088
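
For reference, a minimal sketch of that selection logic using the public torch.cuda.is_current_stream_capturing() check; the helper name below is just for illustration, not a PyTorch or Thunder API:

import torch

def philox_state_device() -> torch.device:
    # Under CUDA Graph stream capture the seed/offset must live on the GPU so
    # they can be updated inside the captured graph; otherwise PyTorch keeps
    # them on the CPU.
    if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
        return torch.device("cuda")
    return torch.device("cpu")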

Thunder could detect whether the CUDA stream is in capturing mode and store that in Thunder's cache or compile data so that meta functions can query it. @t-vi, what do you think about this?
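
A hypothetical sketch of that idea; the CompileDataSketch class and the is_stream_capturing attribute are illustrative assumptions, not existing Thunder APIs:

import torch

class CompileDataSketch:
    # Illustrative stand-in for Thunder's compile data, not a real Thunder class.
    def __init__(self) -> None:
        # Recorded once at compile time so meta functions can agree with the
        # device PyTorch will actually use at runtime.
        self.is_stream_capturing = (
            torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()
        )

def sdpa_seed_offset_device(compile_data: CompileDataSketch) -> str:
    # The sdpa_ex meta function would consult this flag instead of always
    # assuming "cuda" for the philox seed/offset outputs.
    return "cuda" if compile_data.is_stream_capturing else "cpu"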

IvanYashchuk avatar Sep 07 '24 06:09 IvanYashchuk

how did you discover this problem? Is this behavior blocking anything?

I found this while working on a CPU offloading tutorial, which expected these tensors to be on the GPU as returned by Thunder's meta function. This is not blocking anything.
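
For context, a defensive version of the offloading step (a sketch, assuming saved_tensors as obtained in the repro above) would check each tensor's actual device rather than trusting the trace:

# Only offload tensors that are actually on the GPU, instead of relying on
# the devices recorded in the trace.
offloaded = [t.to("cpu") if t.is_cuda else t for t in saved_tensors]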

Thanks for digging into this.

kshitij12345 avatar Sep 07 '24 09:09 kshitij12345