lightning-thunder

Reducing the weight of NumberProxy used in mincut in rematerialization

Open • kiya00 opened this pull request 1 year ago • 3 comments

Before submitting
  • [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • [ ] Did you read the contributor guideline, Pull Request section?
  • [ ] Did you make sure to update the docs?
  • [ ] Did you write any new necessary tests?

What does this PR do?

Fixes part of #114.

Background: when dropout is decomposed into uniform_philox, rematerialization does not pass the expected seed/offset to the backward pass. Instead, it saves the dropout mask t1, as the trace below shows.

Trace before this PR (based on branch uniform_rng):

@torch.no_grad()
@no_autocast
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t6 = get_rng_state_prim_impl(None, devices.Device("cuda:0"))  # t6: "cpu ui8[16]"
  (i7, i8) = unpack_rng_state_prim_impl(t6)
  [t1, t5] = nvFusion0(a, i7, i8)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i7, offset=i8)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.lt(t0, 0.5)  # t1: "cuda:0 b8[2, 2]"
    # t2 = prims.convert_element_type(t1, dtypes.float32)  # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2)  # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0)  # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.mul(a, t4)  # t5: "cuda:0 f32[2, 2]"
  t10 = update_rng_state_prim_impl(i7, i8)  # t10: "cpu ui8[16]"
  set_rng_state_prim_impl(t10, devices.Device("cuda:0"))
  return {'output': t5, 'flat_args': [a], 'flat_output': (t5,)}, ((a, t1), (2.0,))
# Constructed by Update Call Context (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  t6, = cotangents
  a, t1, = C0
  f7, = C1
  [t39] = nvFusion0(a, f7, t1, t6)
    # t2 = prims.convert_element_type(t1, dtypes.float32)  # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2)  # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0)  # t4: "cuda:0 f32[2, 2]"
    # t33 = prims.mul(t4, t6)  # t33: "cuda:0 f32[2, 2]"
    # t34 = prims.mul(a, t6)  # t34: "cuda:0 f32[2, 2]"
    # t35 = prims.mul(f7, t34)  # t35: "cuda:0 f32[2, 2]"
    # t36 = prims.mul(t2, t35)  # t36: "cuda:0 f32[2, 2]"
    # t39 = prims.add(t33, t36)  # t39: "cuda:0 f32[2, 2]"
  return (t39,)

The min-cut found is (a_in, a_out), (t1_in, t1_out), with weight 2 + 1 = 3. The node weights are:

a, 2.0
i7, 0.5
i8, 0.5
t5, 4.0
t0, 4.0
t1, 1.0
t2, 4.0
t3, 4.0
t4, 4.0
t5, 4.0
t33, 4.0
t34, 4.0
t35, 4.0
t36, 4.0
t39, 4.0
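Every weight in the table is consistent with a simple rule: a tensor weighs its dtype's element size in bytes (f32[2, 2] gives 4.0 and b8[2, 2] gives 1.0, independent of shape), a number weighs 1.0, and proxies that are arguments of the forward fusion (a, i7 and i8 are the arguments of nvFusion0) count half. The sketch below reproduces the relevant rows under that assumption. The rule is inferred from the numbers above rather than copied from thunder, and `node_weight` and `DTYPE_BYTES` are hypothetical names.

```python
# Hypothetical helper: the rule is inferred from the weight table above,
# not taken from thunder's rematerialization code.
DTYPE_BYTES = {"f32": 4, "b8": 1}  # bytes per element


def node_weight(kind, dtype=None, is_forward_input=False):
    """Weight of the (x_in, x_out) edge for one proxy in the min-cut graph."""
    if kind == "tensor":
        weight = float(DTYPE_BYTES[dtype])  # per-element size; shape is ignored
    elif kind == "number":
        weight = 1.0
    else:
        raise ValueError(f"unknown proxy kind: {kind!r}")
    # Arguments of the forward fusion (a, i7, i8) appear at half weight.
    return weight / 2 if is_forward_input else weight


weights = {
    "a": node_weight("tensor", "f32", is_forward_input=True),  # 2.0
    "i7": node_weight("number", is_forward_input=True),         # 0.5
    "i8": node_weight("number", is_forward_input=True),         # 0.5
    "t0": node_weight("tensor", "f32"),                         # 4.0
    "t1": node_weight("tensor", "b8"),                          # 1.0
}

# The two competing cuts have exactly the same weight.
assert weights["a"] + weights["t1"] == weights["a"] + weights["i7"] + weights["i8"] == 3.0
```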

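We expect the min-cut (a_in, a_out), (i7_in, i7_out), (i8_in, i8_out) instead, which has the same weight: 2 + 0.5 + 0.5 = 3. Because the two cuts tie, nothing in the weights makes the algorithm prefer saving the seed and offset over saving t1.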

The proposed fix is to multiply the weight of NumberProxy by a small factor (e.g. 0.1), so that cutting through the scalar seed and offset becomes strictly cheaper than saving a tensor.
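The tie and the effect of the factor can be checked by brute force on a small model. The sketch below is a simplified, hypothetical model of the graph behind the "before" trace: each proxy lists the proxies it is computed from, a, i7 and i8 are the sources, and the backward result t39 is the sink. Splitting each proxy into x_in/x_out turns a minimum-weight vertex cut into the edge cut described above, so the sketch enumerates vertex cuts directly and reports every cut of minimal weight. `DEPS`, `node_weights`, `reaches_sink` and `min_cuts` are illustrative names, not thunder APIs, and applying the 0.1 factor on top of the existing 0.5 weights is an assumption about where the factor enters.

```python
from fractions import Fraction
from itertools import combinations

# Hypothetical, simplified model of the "before" graph:
# proxy -> proxies it is computed from, listed in topological order.
DEPS = {
    "t0": ["i7", "i8"],     # uniform_philox(seed=i7, offset=i8)
    "t1": ["t0"],           # lt(t0, 0.5)
    "t2": ["t1"],
    "t3": ["a", "t2"],
    "t4": ["t3"],
    "t33": ["t4"],
    "t34": ["a"],
    "t35": ["t34"],
    "t36": ["t2", "t35"],
    "t39": ["t33", "t36"],  # gradient returned by the backward
}
SOURCES = ("a", "i7", "i8")
SINK = "t39"


def node_weights(number_factor=Fraction(1)):
    """Weights from the table above; number_factor scales the NumberProxy nodes."""
    w = {
        "a": Fraction(2),
        "i7": Fraction(1, 2) * number_factor,
        "i8": Fraction(1, 2) * number_factor,
        "t1": Fraction(1),
    }
    for name in ("t0", "t2", "t3", "t4", "t33", "t34", "t35", "t36", "t39"):
        w[name] = Fraction(4)
    return w


def reaches_sink(cut):
    """True if some source-to-sink path avoids every proxy in `cut`."""
    reached = {s for s in SOURCES if s not in cut}
    for node, inputs in DEPS.items():
        if node not in cut and any(i in reached for i in inputs):
            reached.add(node)
    return SINK in reached


def min_cuts(weights):
    """Try every vertex cut; return the minimum weight and all cuts achieving it."""
    best, cuts = None, []
    names = list(weights)
    for r in range(1, len(names) + 1):
        for combo in combinations(names, r):
            cut = set(combo)
            if reaches_sink(cut):
                continue
            total = sum(weights[n] for n in cut)
            if best is None or total < best:
                best, cuts = total, [cut]
            elif total == best:
                cuts.append(cut)
    return best, cuts


# Current weights: {a, t1} and {a, i7, i8} tie at weight 3.
best_before, cuts_before = min_cuts(node_weights())
# NumberProxy weights scaled by 0.1: {a, i7, i8} (weight 2.1) is the unique minimum.
best_after, cuts_after = min_cuts(node_weights(Fraction(1, 10)))
```

Exact `Fraction` arithmetic keeps the tie exact instead of depending on float rounding. Brute force is only practical because the model has 13 nodes; which of two tied cuts a real max-flow solver returns is solver-dependent, which is why the preference has to come from the weights themselves.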

cc: @IvanYashchuk

kiya00, May 16 '24 13:05

The PR description has only the "before" trace. Could you please update it to show what the trace looks like with this PR?

Trace after this PR (based on branch uniform_rng):

@torch.no_grad()
@no_autocast
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t6 = get_rng_state_prim_impl(None, devices.Device("cuda:0"))  # t6: "cpu ui8[16]"
  (i7, i8) = unpack_rng_state_prim_impl(t6)
  del t6
  [t5] = nvFusion0(a, i7, i8)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i7, offset=i8)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.lt(t0, 0.5)  # t1: "cuda:0 b8[2, 2]"
    # t2 = prims.convert_element_type(t1, dtypes.float32)  # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2)  # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0)  # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.mul(a, t4)  # t5: "cuda:0 f32[2, 2]"
  t10 = update_rng_state_prim_impl(i7, i8)  # t10: "cpu ui8[16]"
  del i7, i8
  set_rng_state_prim_impl(t10, devices.Device("cuda:0"))
  del t10
  return {'output': t5, 'flat_args': [a], 'flat_output': (t5,)}, ((a,), (2.0, 0, 0))
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t6, = cotangents
  clear_collection(cotangents)
  del cotangents
  a, = C0
  clear_collection(C0)
  del C0
  f7, i7, i8, = C1
  clear_collection(C1)
  del C1
  [t39] = nvFusion0(a, f7, i7, i8, t6)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i7, offset=i8)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.lt(t0, 0.5)  # t1: "cuda:0 b8[2, 2]"
    # t2 = prims.convert_element_type(t1, dtypes.float32)  # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2)  # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0)  # t4: "cuda:0 f32[2, 2]"
    # t33 = prims.mul(t4, t6)  # t33: "cuda:0 f32[2, 2]"
    # t34 = prims.mul(a, t6)  # t34: "cuda:0 f32[2, 2]"
    # t35 = prims.mul(f7, t34)  # t35: "cuda:0 f32[2, 2]"
    # t36 = prims.mul(t2, t35)  # t36: "cuda:0 f32[2, 2]"
    # t39 = prims.add(t33, t36)  # t39: "cuda:0 f32[2, 2]"
  del a, f7, i7, i8, t6
  return (t39,)
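In the "after" trace, the backward no longer receives the boolean mask t1: it receives the two integers i7 and i8 and re-runs uniform_philox with them. That is valid because a counter-based generator such as Philox is a pure function of (seed, offset). The sketch below illustrates this property with a hash-based stand-in; it is not Philox, and `counter_uniform` and `dropout_mask` are hypothetical helpers.

```python
import hashlib
import struct


def counter_uniform(shape, seed, offset):
    """Stand-in for a counter-based RNG (NOT Philox): element i of the result is a
    pure function of (seed, offset, i), so there is no hidden generator state."""
    rows, cols = shape
    out = []
    for r in range(rows):
        row = []
        for c in range(cols):
            key = struct.pack("<QQQ", seed, offset, r * cols + c)
            bits = int.from_bytes(hashlib.sha256(key).digest()[:8], "little") >> 11
            row.append(bits / (1 << 53))  # 53 random bits -> float in [0, 1)
        out.append(row)
    return out


def dropout_mask(shape, seed, offset, threshold=0.5):
    """Mirrors t1 = prims.lt(t0, 0.5) with t0 drawn from (seed, offset)."""
    return [[u < threshold for u in row] for row in counter_uniform(shape, seed, offset)]


# Forward: draw the mask from the (seed, offset) pair unpacked from the RNG state.
seed, offset = 1234, 0
forward_mask = dropout_mask((2, 2), seed, offset)

# Backward: only the two integers were saved; regenerating gives the same mask.
backward_mask = dropout_mask((2, 2), seed, offset)
assert backward_mask == forward_mask
```

The trade-off is recomputing the random numbers in the backward in exchange for not keeping a mask tensor alive between forward and backward.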

The absence of tests isn't blocking the merge, but could you come up with a short test for this change? The rematerialization tests currently live in tests/test_nvfuser_remat.py.

Sure, let me see if I can construct a graph that reproduces this case.

By the way, the final trace for Llama-2-7b-hf with fsdp_zero2_none_bucket is the same before and after this PR. @IvanYashchuk

kiya00, May 17 '24 11:05

Quoting kiya00: "By the way, the final trace is the same on Llama-2-7b-hf fsdp_zero2_none_bucket before and after this PR"

Thank you for checking! This is very good and what I wanted to see.

IvanYashchuk, May 17 '24 14:05

Hi @mruberry, could you take a look and review this?

kiya00, May 21 '24 14:05

Hi @nikitaved @t-vi, could you take a look and help get this merged?

kiya00, May 24 '24 07:05