lightning-thunder

uniform_like: outputs differ for the same input tensor depending on whether `requires_grad` is True or False

Open · kiya00 opened this issue 1 year ago · 2 comments

🐛 Bug

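The same jitted function returns different values for the same input tensor depending on whether `requires_grad` is True or False, even though the CUDA generator is reseeded with the same value before each run.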

Note: if the last line of `func` is changed to `return f + d`, the outputs match as expected. The torch executor (torchex) does not have this problem.

```python
import torch
import thunder

def func(a):
    # Three independent uniform draws with the same shape as `a`
    b = thunder.torch.uniform_like(a, device=a.device, dtype=a.dtype)
    e = a * b
    c = thunder.torch.uniform_like(a, device=a.device, dtype=a.dtype)
    f = e + c
    d = thunder.torch.uniform_like(a, device=a.device, dtype=a.dtype)
    return f * d      # outputs differ depending on whether `a` requires grad
    # return f + d    # outputs match, as expected

a = torch.randn(2, 2, device='cuda')
# Same values as `a`, but with requires_grad=True
a1 = a.detach().clone().requires_grad_()

cuda_generator = torch.cuda.default_generators[0]

# Run 1: input without requires_grad
cuda_generator.manual_seed(20)
# print(cuda_generator.get_state())
jfunc = thunder.jit(func, executors_list=[thunder.nvfuser_executor])
out = jfunc(a)

# Run 2: same seed, input with requires_grad=True
cuda_generator.manual_seed(20)
# print(cuda_generator.get_state())
jfunc = thunder.jit(func, executors_list=[thunder.nvfuser_executor])
out1 = jfunc(a1)

# Fails: out and out1 differ
torch.testing.assert_close(out, out1)
```

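Traces for the two runs. The first (`computation`) comes from the input without `requires_grad`; the second (`augmented_forward_fn`) comes from the input with `requires_grad=True`, where the fusion additionally returns `t0` and `t4`, which are saved for the backward pass: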

```python
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cuda:0 f32[2, 2]"
  [t5] = nvFusion0(a)
    # b = prims.uniform((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # b: "cuda:0 f32[2, 2]"
    # result = prims.mul(a, b)  # result: "cuda:0 f32[2, 2]"
    # c = prims.uniform((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # c: "cuda:0 f32[2, 2]"
    # f = prims.add(result, c)  # f: "cuda:0 f32[2, 2]"
    # d = prims.uniform((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # d: "cuda:0 f32[2, 2]"
    # t5 = prims.mul(f, d)  # t5: "cuda:0 f32[2, 2]"
  del a
  return t5
```

```python
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  [t0, t4, t5] = nvFusion0(a)
    # t0 = prims.uniform((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.mul(a, t0)  # t1: "cuda:0 f32[2, 2]"
    # t2 = prims.uniform((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.add(t1, t2)  # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.uniform((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.mul(t3, t4)  # t5: "cuda:0 f32[2, 2]"
  return {'output': t5, 'flat_args': [a], 'flat_output': (t5,)}, ((t0, t4), ())
```

cc @apaz-cli

kiya00 · May 06 '24 19:05

Triage review: we should review this as part of a larger discussion on reproducible randomness.

Trying to produce the same random numbers as PyTorch is probably a non-goal. Trying to produce the same random numbers regardless of executor might be a goal.

mruberry · May 13 '24 19:05
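One way to make random results independent of the executor (or of how a single executor schedules its work) is to key each random op by a stable identifier instead of drawing from a shared generator whose state is consumed in execution order. The sketch below is illustrative only: it uses NumPy's `default_rng` seeded with `[seed, op_index]` as a stand-in for a counter-based generator such as Philox, and `uniform_for_op` and `run` are hypothetical helpers, not thunder or nvFuser APIs.

```python
import numpy as np

SEED = 20  # same seed value the repro passes to manual_seed

def uniform_for_op(seed, op_index, shape):
    # Hypothetical scheme: each RNG op's stream depends only on the global
    # seed and the op's stable position in the original program, never on
    # how many numbers other ops consumed before it.
    rng = np.random.default_rng([seed, op_index])
    return rng.random(shape)

def run(order):
    # `order` is the sequence in which an executor happens to evaluate the
    # three uniform ops (b=0, c=1, d=2 in program order).
    return {i: uniform_for_op(SEED, i, (2, 2)) for i in order}

executor_a = run([0, 1, 2])  # program order
executor_b = run([2, 0, 1])  # a different schedule, e.g. after segmentation

# Each op gets the same values regardless of evaluation order.
for i in range(3):
    assert np.array_equal(executor_a[i], executor_b[i])
```

The trade-off is that the identifier must be assigned before any executor-specific transformation (fusion, segmentation, reordering), so that every executor sees the same key for the same logical op.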

Both of these results come from nvFuser; the issue is that segmentation and runtime ordering can depend on which tensors are marked as outputs. So the question seems to be: if the only difference between two nvFuser graphs is which tensors are marked as outputs, should we guarantee that each tensor receives the same random numbers?

csarofeen · May 15 '24 23:05
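The ordering effect described in the last comment can be illustrated without nvFuser. This toy model assumes that the RNG ops draw from one shared, stateful generator in whatever order the runtime executes them; it does not claim to reflect nvFuser's actual RNG implementation, and `run_stateful` is a hypothetical helper. Scheduling `d` earlier (as a different segmentation might) changes which values each tensor receives, even though the seed is identical.

```python
import numpy as np

def run_stateful(order, seed=20):
    # One shared generator: each op's values depend on how many numbers
    # were drawn before it, i.e. on execution order.
    rng = np.random.default_rng(seed)
    out = {}
    for name in order:
        out[name] = rng.random((2, 2))
    return out

in_order = run_stateful(["b", "c", "d"])   # program order
reordered = run_stateful(["b", "d", "c"])  # d scheduled before c

# b is drawn first in both schedules, so it matches...
assert np.array_equal(in_order["b"], reordered["b"])
# ...but c and d swap positions in the stream, so their values differ.
assert not np.array_equal(in_order["d"], reordered["d"])
```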