lightning-thunder
Support batch_norm backward and nvfuser (PR2278)
Support batch norm:

- Add backward support
- Add nvFuser support
- Type promotion of input/weight/bias is handled in thunder
- Running stats are left as-is; the in-place update is delegated to nvfuserex/torchex

A repro sketch and the generated traces for both executors follow.
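For context, here is a minimal repro sketch of how traces like the ones below can be obtained. It is not part of the PR: it assumes the current thunder.jit / thunder.last_traces entry points (the API at the time of this PR may have differed), and the shapes/dtypes are picked to match the example traces (fp16 input and weight, fp32 running stats, momentum 0.2).

import torch
import thunder

def fn(x, running_mean, running_var, weight):
    return torch.nn.functional.batch_norm(
        x, running_mean, running_var, weight=weight, bias=None,
        training=True, momentum=0.2, eps=1e-5)

x = torch.randn(3, 3, 3, device="cuda", dtype=torch.float16, requires_grad=True)
weight = torch.randn(3, device="cuda", dtype=torch.float16)
running_mean = torch.zeros(3, device="cuda")
running_var = torch.ones(3, device="cuda")

jfn = thunder.jit(fn)  # the nvFuser executor is typically picked up by default on CUDA
out = jfn(x, running_mean, running_var, weight)
out.sum().backward()

print(thunder.last_traces(jfn)[-1])           # forward execution trace
print(thunder.last_backward_traces(jfn)[-1])  # backward execution trace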
nvfuser backend
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast()
def augmented_forward_fn(*args):
# args: "Collection"
t0, \
t1, \
t2, \
t3, \
= args
del args
[t7, t8, t9] = nvFusion0(t0, t1, t2, t3)
# t4 = prims.convert_element_type(t0, dtypes.float32) # t4: "cuda:0 f32[3, 3, 3]"
# t5 = prims.convert_element_type(t3, dtypes.float32) # t5: "cuda:0 f32[3]"
# (t6, t7, t8) = prims.batch_norm(t4, t5, None, t1, t2, True, 0.2, 1e-05)
# t9 = prims.convert_element_type(t6, dtypes.float16) # t9: "cuda:0 f16[3, 3, 3]"
return {'output': t9, 'flat_args': [t0, t1, t2, t3], 'flat_output': (t9,)}, ((t0, t3, t7, t8), (True, True, True, False, 1e-05))
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, \
C1, \
= saved_for_backward
clear_collection(saved_for_backward)
del saved_for_backward
t10, \
= cotangents
clear_collection(cotangents)
del cotangents
t0, \
t3, \
t7, \
t8, \
= C0
clear_collection(C0)
del C0
b5, \
b7, \
b8, \
b9, \
f6, \
= C1
clear_collection(C1)
del C1
[t61] = nvFusion0(b5, b7, b8, b9, f6, t0, t10, t3, t7, t8)
# t4 = prims.convert_element_type(t0, dtypes.float32) # t4: "cuda:0 f32[3, 3, 3]"
# t5 = prims.convert_element_type(t3, dtypes.float32) # t5: "cuda:0 f32[3]"
# t31 = prims.convert_element_type(t10, dtypes.float32) # t31: "cuda:0 f32[3, 3, 3]"
# t34 = prims.convert_element_type(t7, dtypes.float32) # t34: "cuda:0 f32[3]"
# t35 = prims.convert_element_type(t8, dtypes.float32) # t35: "cuda:0 f32[3]"
# t36 = prims.reshape(t34, (1, 3, 1)) # t36: "cuda:0 f32[1, 3, 1]"
# t37 = prims.sum(t31, (0, 2)) # t37: "cuda:0 f32[3]"
# t38 = prims.broadcast_in_dim(t36, (3, 3, 3), (0, 1, 2)) # t38: "cuda:0 f32[3, 3, 3]"
# t39 = prims.sub(t4, t38) # t39: "cuda:0 f32[3, 3, 3]"
# t40 = prims.mul(t31, t39) # t40: "cuda:0 f32[3, 3, 3]"
# t41 = prims.sum(t40, (0, 2)) # t41: "cuda:0 f32[3]"
# t42 = prims.mul(t37, 0.1111111111111111) # t42: "cuda:0 f32[3]"
# t43 = prims.reshape(t42, (1, 3, 1)) # t43: "cuda:0 f32[1, 3, 1]"
# t44 = prims.mul(t41, 0.1111111111111111) # t44: "cuda:0 f32[3]"
# t45 = prims.mul(t35, t35) # t45: "cuda:0 f32[3]"
# t46 = prims.mul(t44, t45) # t46: "cuda:0 f32[3]"
# t47 = prims.reshape(t46, (1, 3, 1)) # t47: "cuda:0 f32[1, 3, 1]"
# t48 = prims.mul(t35, t5) # t48: "cuda:0 f32[3]"
# t49 = prims.reshape(t48, (1, 3, 1)) # t49: "cuda:0 f32[1, 3, 1]"
# t52 = prims.broadcast_in_dim(t47, (3, 3, 3), (0, 1, 2)) # t52: "cuda:0 f32[3, 3, 3]"
# t53 = prims.mul(t39, t52) # t53: "cuda:0 f32[3, 3, 3]"
# t54 = prims.sub(t31, t53) # t54: "cuda:0 f32[3, 3, 3]"
# t55 = prims.broadcast_in_dim(t43, (3, 3, 3), (0, 1, 2)) # t55: "cuda:0 f32[3, 3, 3]"
# t56 = prims.sub(t54, t55) # t56: "cuda:0 f32[3, 3, 3]"
# t57 = prims.broadcast_in_dim(t49, (3, 3, 3), (0, 1, 2)) # t57: "cuda:0 f32[3, 3, 3]"
# t58 = prims.mul(t56, t57) # t58: "cuda:0 f32[3, 3, 3]"
# t61 = prims.convert_element_type(t58, dtypes.float16) # t61: "cuda:0 f16[3, 3, 3]"
del b5, b7, b8, b9, f6, t0, t10, t3, t7, t8
return (t61, None, None, None)
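For reference, the chain of prims in this backward is the standard batch-norm input gradient, dx = w * rstd * (dy - mean(dy) - (x - mean) * rstd^2 * mean(dy * (x - mean))), with the reductions taken over the non-channel dims (N = 9 here, hence the 0.1111... factors). Below is a standalone sketch, not part of the PR, that redoes the decomposition with plain torch ops and checks it against ATen's reference backward; the tensor names in the comments refer to the trace above.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(3, 3, 3, device=device)
w = torch.randn(3, device=device)
dy = torch.randn_like(x)

dims, n = (0, 2), 9                                   # reduce over all but the channel dim
mean = x.mean(dims)                                   # save_mean (t7's role in the trace)
rstd = (x.var(dims, unbiased=False) + 1e-5).rsqrt()   # save_invstd (t8's role)

xmu = x - mean.reshape(1, 3, 1)                       # t39
g_mean = dy.sum(dims) / n                             # t42: mean(dy)
proj = (dy * xmu).sum(dims) / n * rstd * rstd         # t46: mean(dy * (x - mean)) * rstd^2
dx = (dy - xmu * proj.reshape(1, 3, 1) - g_mean.reshape(1, 3, 1)) \
     * (rstd * w).reshape(1, 3, 1)                    # t58

ref_dx, _, _ = torch.ops.aten.native_batch_norm_backward(
    dy, x, w, torch.zeros(3, device=device), torch.ones(3, device=device),
    mean, rstd, True, 1e-5, [True, True, True])
torch.testing.assert_close(dx, ref_dx)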
torch backend
# Constructed by Delete Last Used (took 0 milliseconds)
import thunder.core.dtypes as dtypes
from torch import Tensor
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast()
def augmented_forward_fn(*args):
# args: "Collection"
t0, \
t1, \
t2, \
t3, \
= args
del args
t4 = Tensor.to(t0, torch.float32, copy=True) # t4: "cuda:0 f32[3, 3, 3]"
# t4 = ltorch.to(t0, torch.float32, None, device=None, dtype=None, copy=True) # t4: "cuda:0 f32[3, 3, 3]"
# t4 = prims.convert_element_type(t0, dtypes.float32) # t4: "cuda:0 f32[3, 3, 3]"
t5 = Tensor.to(t3, torch.float32, copy=True) # t5: "cuda:0 f32[3]"
# t5 = ltorch.to(t3, torch.float32, None, device=None, dtype=None, copy=True) # t5: "cuda:0 f32[3]"
# t5 = prims.convert_element_type(t3, dtypes.float32) # t5: "cuda:0 f32[3]"
(t6, t7, t8) = torch.torch.ops.aten.native_batch_norm(t4, t5, None, t1, t2, True, 0.2, 1e-05)
# (t6, t7, t8) = prims.batch_norm(t4, t5, None, t1, t2, True, 0.2, 1e-05)
t9 = Tensor.to(t6, torch.float16, copy=True) # t9: "cuda:0 f16[3, 3, 3]"
# t9 = ltorch.to(t6, torch.float16, None, device=None, dtype=None, copy=True) # t9: "cuda:0 f16[3, 3, 3]"
# t9 = prims.convert_element_type(t6, dtypes.float16) # t9: "cuda:0 f16[3, 3, 3]"
del t6
return {'output': t9, 'flat_args': [t0, t1, t2, t3], 'flat_output': (t9,)}, ((t1, t2, t4, t5, t7, t8), (True, True, True, False, 1e-05))
# Constructed by Delete Last Used (took 0 milliseconds)
from torch import Tensor
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, \
C1, \
= saved_for_backward
clear_collection(saved_for_backward)
del saved_for_backward
t10, \
= cotangents
clear_collection(cotangents)
del cotangents
t1, \
t2, \
t4, \
t5, \
t7, \
t8, \
= C0
clear_collection(C0)
del C0
b5, \
b7, \
b8, \
b9, \
f6, \
= C1
clear_collection(C1)
del C1
t31 = Tensor.to(t10, torch.float32, copy=True) # t31: "cuda:0 f32[3, 3, 3]"
# t31 = ltorch.to(t10, torch.float32, None, device=None, dtype=None, copy=True) # t31: "cuda:0 f32[3, 3, 3]"
# t31 = prims.convert_element_type(t10, dtypes.float32) # t31: "cuda:0 f32[3, 3, 3]"
del t10
(t58, _, _) = torch.torch.ops.aten.native_batch_norm_backward(t31, t4, t5, t1, t2, t7, t8, b5, f6, [b7, b8, b9])
# (t58, t90, _) = ltorch.batch_norm_backward(t31, t4, t5, t1, t2, t7, t8, b5, f6, [b7, b8, b9])
# t63 = ltorch.to(t1, dtypes.float32, None, device=None, dtype=None, copy=False) # t63: "cuda:0 f32[3]"
# t63 = prims.convert_element_type(t1, dtypes.float32) # t63: "cuda:0 f32[3]"
# t64 = ltorch.to(t2, dtypes.float32, None, device=None, dtype=None, copy=False) # t64: "cuda:0 f32[3]"
# t64 = prims.convert_element_type(t2, dtypes.float32) # t64: "cuda:0 f32[3]"
# t65 = ltorch.to(t7, dtypes.float32, None, device=None, dtype=None, copy=False) # t65: "cuda:0 f32[3]"
# t65 = prims.convert_element_type(t7, dtypes.float32) # t65: "cuda:0 f32[3]"
# t66 = ltorch.to(t8, dtypes.float32, None, device=None, dtype=None, copy=False) # t66: "cuda:0 f32[3]"
# t66 = prims.convert_element_type(t8, dtypes.float32) # t66: "cuda:0 f32[3]"
# t67 = ltorch.reshape(t65, [1, 3, 1]) # t67: "cuda:0 f32[1, 3, 1]"
# t67 = prims.reshape(t65, (1, 3, 1)) # t67: "cuda:0 f32[1, 3, 1]"
# t68 = ltorch.sum(t31, [0, 2], False, dtype=None) # t68: "cuda:0 f32[3]"
# t68 = prims.sum(t31, (0, 2)) # t68: "cuda:0 f32[3]"
# t70 = ltorch.sub(t4, t67, alpha=None) # t70: "cuda:0 f32[3, 3, 3]"
# t69 = prims.broadcast_in_dim(t67, (3, 3, 3), (0, 1, 2)) # t69: "cuda:0 f32[3, 3, 3]"
# t70 = prims.sub(t4, t69) # t70: "cuda:0 f32[3, 3, 3]"
# t71 = ltorch.mul(t31, t70) # t71: "cuda:0 f32[3, 3, 3]"
# t71 = prims.mul(t31, t70) # t71: "cuda:0 f32[3, 3, 3]"
# t72 = ltorch.sum(t71, [0, 2], False, dtype=None) # t72: "cuda:0 f32[3]"
# t72 = prims.sum(t71, (0, 2)) # t72: "cuda:0 f32[3]"
# t73 = ltorch.mul(t68, 0.1111111111111111) # t73: "cuda:0 f32[3]"
# t73 = prims.mul(t68, 0.1111111111111111) # t73: "cuda:0 f32[3]"
# t74 = ltorch.reshape(t73, [1, 3, 1]) # t74: "cuda:0 f32[1, 3, 1]"
# t74 = prims.reshape(t73, (1, 3, 1)) # t74: "cuda:0 f32[1, 3, 1]"
# t75 = ltorch.mul(t72, 0.1111111111111111) # t75: "cuda:0 f32[3]"
# t75 = prims.mul(t72, 0.1111111111111111) # t75: "cuda:0 f32[3]"
# t76 = ltorch.mul(t66, t66) # t76: "cuda:0 f32[3]"
# t76 = prims.mul(t66, t66) # t76: "cuda:0 f32[3]"
# t77 = ltorch.mul(t75, t76) # t77: "cuda:0 f32[3]"
# t77 = prims.mul(t75, t76) # t77: "cuda:0 f32[3]"
# t78 = ltorch.reshape(t77, [1, 3, 1]) # t78: "cuda:0 f32[1, 3, 1]"
# t78 = prims.reshape(t77, (1, 3, 1)) # t78: "cuda:0 f32[1, 3, 1]"
# t79 = ltorch.mul(t66, t5) # t79: "cuda:0 f32[3]"
# t79 = prims.mul(t66, t5) # t79: "cuda:0 f32[3]"
# t80 = ltorch.reshape(t79, [1, 3, 1]) # t80: "cuda:0 f32[1, 3, 1]"
# t80 = prims.reshape(t79, (1, 3, 1)) # t80: "cuda:0 f32[1, 3, 1]"
# t82 = ltorch.sub(t4, t67, alpha=None) # t82: "cuda:0 f32[3, 3, 3]"
# t81 = prims.broadcast_in_dim(t67, (3, 3, 3), (0, 1, 2)) # t81: "cuda:0 f32[3, 3, 3]"
# t82 = prims.sub(t4, t81) # t82: "cuda:0 f32[3, 3, 3]"
# t84 = ltorch.mul(t82, t78) # t84: "cuda:0 f32[3, 3, 3]"
# t83 = prims.broadcast_in_dim(t78, (3, 3, 3), (0, 1, 2)) # t83: "cuda:0 f32[3, 3, 3]"
# t84 = prims.mul(t82, t83) # t84: "cuda:0 f32[3, 3, 3]"
# t85 = ltorch.sub(t31, t84, alpha=None) # t85: "cuda:0 f32[3, 3, 3]"
# t85 = prims.sub(t31, t84) # t85: "cuda:0 f32[3, 3, 3]"
# t87 = ltorch.sub(t85, t74, alpha=None) # t87: "cuda:0 f32[3, 3, 3]"
# t86 = prims.broadcast_in_dim(t74, (3, 3, 3), (0, 1, 2)) # t86: "cuda:0 f32[3, 3, 3]"
# t87 = prims.sub(t85, t86) # t87: "cuda:0 f32[3, 3, 3]"
# t58 = ltorch.mul(t87, t80) # t58: "cuda:0 f32[3, 3, 3]"
# t88 = prims.broadcast_in_dim(t80, (3, 3, 3), (0, 1, 2)) # t88: "cuda:0 f32[3, 3, 3]"
# t58 = prims.mul(t87, t88) # t58: "cuda:0 f32[3, 3, 3]"
# t90 = ltorch.mul(t72, t66) # t90: "cuda:0 f32[3]"
# t90 = prims.mul(t72, t66) # t90: "cuda:0 f32[3]"
del t31, t4, t5, t1, t2, t7, t8, b5, f6, b7, b8, b9
t61 = Tensor.to(t58, torch.float16, copy=True) # t61: "cuda:0 f16[3, 3, 3]"
# t61 = ltorch.to(t58, torch.float16, None, device=None, dtype=None, copy=True) # t61: "cuda:0 f16[3, 3, 3]"
# t61 = prims.convert_element_type(t58, dtypes.float16) # t61: "cuda:0 f16[3, 3, 3]"
del t58
return (t61, None, None, None)
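On the running-stats point: since the in-place update is delegated to the executor (aten.native_batch_norm in the torch trace above), the semantics are expected to match eager PyTorch, i.e. new_stat = (1 - momentum) * old_stat + momentum * batch_stat, with the unbiased variance for running_var. A quick eager-mode sanity check of that expectation (plain PyTorch, no thunder involved):

import torch

x = torch.randn(3, 3, 3)
running_mean = torch.zeros(3)
running_var = torch.ones(3)

batch_mean = x.mean(dim=(0, 2))
batch_var = x.var(dim=(0, 2), unbiased=True)   # running stats use the unbiased estimate

torch.nn.functional.batch_norm(x, running_mean, running_var,
                               training=True, momentum=0.2, eps=1e-5)

torch.testing.assert_close(running_mean, 0.2 * batch_mean)
torch.testing.assert_close(running_var, 0.8 * torch.ones(3) + 0.2 * batch_var)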
Hi @IvanYashchuk @mruberry, I think it's ready to merge.
Automerge is enabled, but an approval from codeowners is required to actually merge. @t-vi, could you take a look and merge, please?