Reducing the weight of NumberProxy used in the rematerialization mincut
Before submitting
- [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?
What does this PR do?
Fixes part of #114.
Background:
When dropout is decomposed into uniform_philox, the rematerialization pass doesn't forward the expected seed/offset to the backward trace.
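For context, the traces below presumably come from a function like the following. This is my reconstruction from the trace, not copied from #114, so the exact reproducer is an assumption; `a * dropout(a, p=0.5)` produces the `uniform_philox -> lt -> convert_element_type -> mul -> mul` chain shown (the `2.0` factor being `1 / (1 - p)`):

```python
# Hypothetical reproducer reconstructed from the traces below (an assumption,
# not taken verbatim from issue #114).
import torch

def func(a):
    # dropout with p=0.5 decomposes to uniform_philox -> lt -> cast -> mul -> mul(2.0)
    return a * torch.nn.functional.dropout(a, p=0.5)

a = torch.randn(2, 2, device="cuda", requires_grad=True)
# Compiled with Thunder on the uniform_rng branch, e.g. via thunder.compile(func);
# the exact entry point is an assumption here.
```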
Trace before this PR (based on branch `uniform_rng`):
```python
@torch.no_grad()
@no_autocast
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t6 = get_rng_state_prim_impl(None, devices.Device("cuda:0")) # t6: "cpu ui8[16]"
  (i7, i8) = unpack_rng_state_prim_impl(t6)
  [t1, t5] = nvFusion0(a, i7, i8)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i7, offset=i8) # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.lt(t0, 0.5) # t1: "cuda:0 b8[2, 2]"
    # t2 = prims.convert_element_type(t1, dtypes.float32) # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2) # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0) # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.mul(a, t4) # t5: "cuda:0 f32[2, 2]"
  t10 = update_rng_state_prim_impl(i7, i8) # t10: "cpu ui8[16]"
  set_rng_state_prim_impl(t10, devices.Device("cuda:0"))
  return {'output': t5, 'flat_args': [a], 'flat_output': (t5,)}, ((a, t1), (2.0,))

# Constructed by Update Call Context (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  t6, = cotangents
  a, t1, = C0
  f7, = C1
  [t39] = nvFusion0(a, f7, t1, t6)
    # t2 = prims.convert_element_type(t1, dtypes.float32) # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2) # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0) # t4: "cuda:0 f32[2, 2]"
    # t33 = prims.mul(t4, t6) # t33: "cuda:0 f32[2, 2]"
    # t34 = prims.mul(a, t6) # t34: "cuda:0 f32[2, 2]"
    # t35 = prims.mul(f7, t34) # t35: "cuda:0 f32[2, 2]"
    # t36 = prims.mul(t2, t35) # t36: "cuda:0 f32[2, 2]"
    # t39 = prims.add(t33, t36) # t39: "cuda:0 f32[2, 2]"
  return (t39,)
```
The mincut chosen is `(a_in, a_out), (t1_in, t1_out)` with weight = 2 + 1 = 3. The per-proxy weights are:
| proxy | weight |
| --- | --- |
| a | 2.0 |
| i7 | 0.5 |
| i8 | 0.5 |
| t5 | 4.0 |
| t0 | 4.0 |
| t1 | 1.0 |
| t2 | 4.0 |
| t3 | 4.0 |
| t4 | 4.0 |
| t5 | 4.0 |
| t33 | 4.0 |
| t34 | 4.0 |
| t35 | 4.0 |
| t36 | 4.0 |
| t39 | 4.0 |
We expect the mincut `(a_in, a_out), (i7_in, i7_out), (i8_in, i8_out)` instead, which actually has the same weight = 0.5 + 0.5 + 2 = 3.
The proposed fix is to apply a factor (e.g. 0.1) that reduces the weight of NumberProxy nodes, so that cuts saving scalars such as the seed/offset are preferred over equally-weighted cuts saving tensors; a small sketch of the effect follows.
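To make the weight argument concrete, here is a small, self-contained sketch using `networkx` (not Thunder's actual rematerialization code) of the usual node-splitting construction: each proxy becomes an `_in`/`_out` pair whose edge capacity is the proxy's weight, dataflow edges are left uncapacitated, and the min-cut picks the proxies to save. The graph shape and the "what the backward reads" set are my simplification of the traces above. With a factor of 1.0 the `(a, t1)` and `(a, i7, i8)` cuts tie at 3, so the solver's tie-breaking decides which one is returned; with a factor of 0.1 on NumberProxies the `(a, i7, i8)` cut becomes strictly cheaper (2.1).

```python
# Illustrative sketch only (networkx), not Thunder's implementation.
import networkx as nx

def min_cut_saved(number_proxy_factor):
    # Per-proxy weights from the table above; i7/i8 (seed/offset) are NumberProxies.
    weights = {"a": 2.0, "i7": 0.5, "i8": 0.5, "t0": 4.0, "t1": 1.0,
               "t2": 4.0, "t3": 4.0, "t4": 4.0, "t5": 4.0}
    number_proxies = {"i7", "i8"}

    g = nx.DiGraph()
    # Split every proxy into <p>_in -> <p>_out; cutting this edge means "save p".
    for name, w in weights.items():
        if name in number_proxies:
            w *= number_proxy_factor
        g.add_edge(f"{name}_in", f"{name}_out", capacity=w)
    # Forward dataflow from the trace; edges without a capacity attribute are
    # treated by networkx as having infinite capacity.
    for u, v in [("i7", "t0"), ("i8", "t0"), ("t0", "t1"), ("t1", "t2"),
                 ("a", "t3"), ("t2", "t3"), ("t3", "t4"), ("a", "t5"), ("t4", "t5")]:
        g.add_edge(f"{u}_out", f"{v}_in")
    # The source feeds the forward inputs; the sink is fed by what the backward
    # reads (a, t2 and t4 in the backward trace above -- a simplification).
    for fwd_input in ("a", "i7", "i8"):
        g.add_edge("source", f"{fwd_input}_in")
    for needed in ("a", "t2", "t4"):
        g.add_edge(f"{needed}_out", "sink")

    value, (reachable, _) = nx.minimum_cut(g, "source", "sink")
    saved = sorted(p for p in weights
                   if f"{p}_in" in reachable and f"{p}_out" not in reachable)
    return value, saved

print(min_cut_saved(1.0))  # cut weight 3.0 -- ties with the (a, t1) cut, either may be chosen
print(min_cut_saved(0.1))  # cut weight 2.1 -- saving (a, i7, i8) is now strictly preferred
```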
cc: @IvanYashchuk
The PR description has only the "before" trace. Could you please update it to show how the trace looks with this PR?
Trace after this PR (based on branch `uniform_rng`):
```python
@torch.no_grad()
@no_autocast
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t6 = get_rng_state_prim_impl(None, devices.Device("cuda:0")) # t6: "cpu ui8[16]"
  (i7, i8) = unpack_rng_state_prim_impl(t6)
  del t6
  [t5] = nvFusion0(a, i7, i8)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i7, offset=i8) # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.lt(t0, 0.5) # t1: "cuda:0 b8[2, 2]"
    # t2 = prims.convert_element_type(t1, dtypes.float32) # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2) # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0) # t4: "cuda:0 f32[2, 2]"
    # t5 = prims.mul(a, t4) # t5: "cuda:0 f32[2, 2]"
  t10 = update_rng_state_prim_impl(i7, i8) # t10: "cpu ui8[16]"
  del i7, i8
  set_rng_state_prim_impl(t10, devices.Device("cuda:0"))
  del t10
  return {'output': t5, 'flat_args': [a], 'flat_output': (t5,)}, ((a,), (2.0, 0, 0))

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t6, = cotangents
  clear_collection(cotangents)
  del cotangents
  a, = C0
  clear_collection(C0)
  del C0
  f7, i7, i8, = C1
  clear_collection(C1)
  del C1
  [t39] = nvFusion0(a, f7, i7, i8, t6)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i7, offset=i8) # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.lt(t0, 0.5) # t1: "cuda:0 b8[2, 2]"
    # t2 = prims.convert_element_type(t1, dtypes.float32) # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2) # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0) # t4: "cuda:0 f32[2, 2]"
    # t33 = prims.mul(t4, t6) # t33: "cuda:0 f32[2, 2]"
    # t34 = prims.mul(a, t6) # t34: "cuda:0 f32[2, 2]"
    # t35 = prims.mul(f7, t34) # t35: "cuda:0 f32[2, 2]"
    # t36 = prims.mul(t2, t35) # t36: "cuda:0 f32[2, 2]"
    # t39 = prims.add(t33, t36) # t39: "cuda:0 f32[2, 2]"
  del a, f7, i7, i8, t6
  return (t39,)
```
I'm not blocking the merge over the missing tests, but maybe you could come up with a short test for this change? The rematerialization tests currently live in tests/test_nvfuser_remat.py.
Sure, let me try to construct a graph corresponding to this case.
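For what it's worth, a rough sketch of what such a test could look like end-to-end (rather than via a hand-constructed graph). The entry points `thunder.compile` / `thunder.last_backward_traces` and the executor setup are assumptions on my side and would need to follow the helpers already used in tests/test_nvfuser_remat.py:

```python
# Rough sketch only; API names (thunder.compile, thunder.last_backward_traces)
# and the way the nvFuser executor is enabled are assumptions and should be
# replaced with the conventions used in tests/test_nvfuser_remat.py.
import torch
import thunder

def func(a):
    return a * torch.nn.functional.dropout(a, p=0.5)

def test_remat_saves_rng_scalars_instead_of_mask():
    a = torch.randn(2, 2, device="cuda", requires_grad=True)
    cfunc = thunder.compile(func)
    cfunc(a).sum().backward()

    backward_trace = thunder.last_backward_traces(cfunc)[-1]
    # With the reduced NumberProxy weight, the backward should recompute the
    # dropout mask from the saved seed/offset, i.e. uniform_philox appears in
    # the backward trace instead of the boolean mask being saved for backward.
    assert "uniform_philox" in str(backward_trace)
```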
By the way, the final trace is the same for Llama-2-7b-hf with fsdp_zero2_none_bucket before and after this PR. @IvanYashchuk
Thank you for checking! This is very good and what I wanted to see.
Hi @mruberry, could you take a look and review?
Hi @nikitaved @t-vi, could you take a look and help get this merged?