lightning-thunder icon indicating copy to clipboard operation
lightning-thunder copied to clipboard

Support batch_norm backward and nvfuser (PR2278)

Open kiya00 opened this issue 1 year ago • 1 comments

Support batch norm:

  • Add backward support

  • Add nvFuser support

  • Type promotion of input/weight/bias is handled in thunder

  • Running stats are left as-is; nvfuserex/torchex take care of updating them in place

kiya00 avatar Mar 21 '24 09:03 kiya00

nvfuser backend
 # Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(*args):
  # Augmented forward trace as emitted for the nvFuser executor: the dtype
  # casts and the batch norm are executed as one fused region, nvFusion0
  # (defined outside this snippet). The commented prims under the call list
  # what the fusion contains.
  #
  # Returns a pair:
  #   * a dict with the f16 output ('output'/'flat_output') and the original
  #     arguments ('flat_args');
  #   * the values saved for backward_fn: a tuple of tensors and a tuple of
  #     non-tensor constants.
  #
  # NOTE(review): argument roles are not spelled out here. Going by the
  # prims.batch_norm call below, and assuming it follows the
  # aten.native_batch_norm argument order, t0 is the input, t3 the weight and
  # t1/t2 the running mean/variance (bias is None) -- confirm.
  # args: "Collection"
  t0, \
  t1, \
  t2, \
  t3, \
  = args
  del args
  [t7, t8, t9] = nvFusion0(t0, t1, t2, t3)
    # t4 = prims.convert_element_type(t0, dtypes.float32)  # t4: "cuda:0 f32[3, 3, 3]"
    # t5 = prims.convert_element_type(t3, dtypes.float32)  # t5: "cuda:0 f32[3]"
    # (t6, t7, t8) = prims.batch_norm(t4, t5, None, t1, t2, True, 0.2, 1e-05)
    # t9 = prims.convert_element_type(t6, dtypes.float16)  # t9: "cuda:0 f16[3, 3, 3]"
  # Only the original (uncast) t0 and t3 plus the two extra batch-norm outputs
  # t7, t8 are saved; the f32 casts are recomputed inside the backward fusion.
  return {'output': t9, 'flat_args': [t0, t1, t2, t3], 'flat_output': (t9,)}, ((t0, t3, t7, t8), (True, True, True, False, 1e-05))
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # Backward trace as emitted for the nvFuser executor: the whole input-gradient
  # computation runs as one fused region, nvFusion0 (defined outside this
  # snippet). The commented prims under the call list what the fusion contains.
  #
  # saved_for_backward is a pair (C0, C1): C0 holds four tensors, C1 holds five
  # non-tensor values. cotangents holds the single incoming gradient t10.
  # Returns a 4-tuple in which only the first entry (t61, f16[3, 3, 3]) is a
  # gradient; the remaining three are None.
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, \
  C1, \
  = saved_for_backward
  # clear_collection is defined outside this snippet; presumably it empties the
  # container so the caller's reference no longer keeps the contents alive --
  # TODO confirm. The same unpack / clear / del pattern repeats below.
  clear_collection(saved_for_backward)
  del saved_for_backward
  t10, \
  = cotangents
  clear_collection(cotangents)
  del cotangents
  t0, \
  t3, \
  t7, \
  t8, \
  = C0
  clear_collection(C0)
  del C0
  # The augmented forward shown with this trace saves
  # (True, True, True, False, 1e-05) in this slot.
  b5, \
  b7, \
  b8, \
  b9, \
  f6, \
  = C1
  clear_collection(C1)
  del C1
  # b5/b7/b8/b9/f6 are passed to the fusion, but none of the prims listed below
  # reference them.
  [t61] = nvFusion0(b5, b7, b8, b9, f6, t0, t10, t3, t7, t8)
    # t4 = prims.convert_element_type(t0, dtypes.float32)  # t4: "cuda:0 f32[3, 3, 3]"
    # t5 = prims.convert_element_type(t3, dtypes.float32)  # t5: "cuda:0 f32[3]"
    # t31 = prims.convert_element_type(t10, dtypes.float32)  # t31: "cuda:0 f32[3, 3, 3]"
    # t34 = prims.convert_element_type(t7, dtypes.float32)  # t34: "cuda:0 f32[3]"
    # t35 = prims.convert_element_type(t8, dtypes.float32)  # t35: "cuda:0 f32[3]"
    # t36 = prims.reshape(t34, (1, 3, 1))  # t36: "cuda:0 f32[1, 3, 1]"
    # t37 = prims.sum(t31, (0, 2))  # t37: "cuda:0 f32[3]"
    # t38 = prims.broadcast_in_dim(t36, (3, 3, 3), (0, 1, 2))  # t38: "cuda:0 f32[3, 3, 3]"
    # t39 = prims.sub(t4, t38)  # t39: "cuda:0 f32[3, 3, 3]"
    # t40 = prims.mul(t31, t39)  # t40: "cuda:0 f32[3, 3, 3]"
    # t41 = prims.sum(t40, (0, 2))  # t41: "cuda:0 f32[3]"
    # t42 = prims.mul(t37, 0.1111111111111111)  # t42: "cuda:0 f32[3]"
    # t43 = prims.reshape(t42, (1, 3, 1))  # t43: "cuda:0 f32[1, 3, 1]"
    # t44 = prims.mul(t41, 0.1111111111111111)  # t44: "cuda:0 f32[3]"
    # t45 = prims.mul(t35, t35)  # t45: "cuda:0 f32[3]"
    # t46 = prims.mul(t44, t45)  # t46: "cuda:0 f32[3]"
    # t47 = prims.reshape(t46, (1, 3, 1))  # t47: "cuda:0 f32[1, 3, 1]"
    # t48 = prims.mul(t35, t5)  # t48: "cuda:0 f32[3]"
    # t49 = prims.reshape(t48, (1, 3, 1))  # t49: "cuda:0 f32[1, 3, 1]"
    # t52 = prims.broadcast_in_dim(t47, (3, 3, 3), (0, 1, 2))  # t52: "cuda:0 f32[3, 3, 3]"
    # t53 = prims.mul(t39, t52)  # t53: "cuda:0 f32[3, 3, 3]"
    # t54 = prims.sub(t31, t53)  # t54: "cuda:0 f32[3, 3, 3]"
    # t55 = prims.broadcast_in_dim(t43, (3, 3, 3), (0, 1, 2))  # t55: "cuda:0 f32[3, 3, 3]"
    # t56 = prims.sub(t54, t55)  # t56: "cuda:0 f32[3, 3, 3]"
    # t57 = prims.broadcast_in_dim(t49, (3, 3, 3), (0, 1, 2))  # t57: "cuda:0 f32[3, 3, 3]"
    # t58 = prims.mul(t56, t57)  # t58: "cuda:0 f32[3, 3, 3]"
    # t61 = prims.convert_element_type(t58, dtypes.float16)  # t61: "cuda:0 f16[3, 3, 3]"
  # NOTE(review): 0.1111111111111111 is 1/9, which matches the 9 elements
  # reduced per channel when summing an f32[3, 3, 3] tensor over dims (0, 2).
  # Drop every fusion input as soon as the fusion has run.
  del b5, b7, b8, b9, f6, t0, t10, t3, t7, t8
  return (t61, None, None, None)
torch backend
# Constructed by Delete Last Used (took 0 milliseconds)
import thunder.core.dtypes as dtypes
from torch import Tensor
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(*args):
  # Augmented forward trace as emitted for the torch executor: t0 and t3 are
  # cast to float32, aten.native_batch_norm runs on the casts, and its first
  # output is cast back to float16.
  #
  # Returns a pair:
  #   * a dict with the f16 output ('output'/'flat_output') and the original
  #     arguments ('flat_args');
  #   * the values saved for backward_fn: a tuple of tensors and a tuple of
  #     non-tensor constants.
  #
  # NOTE(review): going by the aten.native_batch_norm argument order
  # (input, weight, bias, running_mean, running_var, training, momentum, eps),
  # t0 is the input, t3 the weight, t1/t2 the running mean/variance, bias is
  # None, training is True, momentum 0.2 and eps 1e-05 -- confirm.
  # args: "Collection"
  t0, \
  t1, \
  t2, \
  t3, \
  = args
  del args
  t4 = Tensor.to(t0, torch.float32, copy=True)  # t4: "cuda:0 f32[3, 3, 3]"
    # t4 = ltorch.to(t0, torch.float32, None, device=None, dtype=None, copy=True)  # t4: "cuda:0 f32[3, 3, 3]"
      # t4 = prims.convert_element_type(t0, dtypes.float32)  # t4: "cuda:0 f32[3, 3, 3]"
  t5 = Tensor.to(t3, torch.float32, copy=True)  # t5: "cuda:0 f32[3]"
    # t5 = ltorch.to(t3, torch.float32, None, device=None, dtype=None, copy=True)  # t5: "cuda:0 f32[3]"
      # t5 = prims.convert_element_type(t3, dtypes.float32)  # t5: "cuda:0 f32[3]"
  (t6, t7, t8) = torch.torch.ops.aten.native_batch_norm(t4, t5, None, t1, t2, True, 0.2, 1e-05)
    # (t6, t7, t8) = prims.batch_norm(t4, t5, None, t1, t2, True, 0.2, 1e-05)
  t9 = Tensor.to(t6, torch.float16, copy=True)  # t9: "cuda:0 f16[3, 3, 3]"
    # t9 = ltorch.to(t6, torch.float16, None, device=None, dtype=None, copy=True)  # t9: "cuda:0 f16[3, 3, 3]"
      # t9 = prims.convert_element_type(t6, dtypes.float16)  # t9: "cuda:0 f16[3, 3, 3]"
  # The f32 batch-norm output is not needed once the f16 copy exists.
  del t6
  # The f32 casts t4/t5 are saved for backward here, together with t1, t2 and
  # the two extra batch-norm outputs t7, t8.
  return {'output': t9, 'flat_args': [t0, t1, t2, t3], 'flat_output': (t9,)}, ((t1, t2, t4, t5, t7, t8), (True, True, True, False, 1e-05))
# Constructed by Delete Last Used (took 0 milliseconds)
from torch import Tensor
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # Backward trace as emitted for the torch executor: the incoming gradient is
  # cast to float32, aten.native_batch_norm_backward computes the input
  # gradient, and the result is cast back to float16.
  #
  # saved_for_backward is a pair (C0, C1): C0 holds six tensors, C1 holds five
  # non-tensor values. cotangents holds the single incoming gradient t10.
  # Returns a 4-tuple in which only the first entry (t61, f16[3, 3, 3]) is a
  # gradient; the remaining three are None.
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, \
  C1, \
  = saved_for_backward
  # clear_collection is defined outside this snippet; presumably it empties the
  # container so the caller's reference no longer keeps the contents alive --
  # TODO confirm. The same unpack / clear / del pattern repeats below.
  clear_collection(saved_for_backward)
  del saved_for_backward
  t10, \
  = cotangents
  clear_collection(cotangents)
  del cotangents
  t1, \
  t2, \
  t4, \
  t5, \
  t7, \
  t8, \
  = C0
  clear_collection(C0)
  del C0
  # The augmented forward shown with this trace saves
  # (True, True, True, False, 1e-05) in this slot.
  b5, \
  b7, \
  b8, \
  b9, \
  f6, \
  = C1
  clear_collection(C1)
  del C1
  t31 = Tensor.to(t10, torch.float32, copy=True)  # t31: "cuda:0 f32[3, 3, 3]"
    # t31 = ltorch.to(t10, torch.float32, None, device=None, dtype=None, copy=True)  # t31: "cuda:0 f32[3, 3, 3]"
      # t31 = prims.convert_element_type(t10, dtypes.float32)  # t31: "cuda:0 f32[3, 3, 3]"
  del t10
  # Only the first output is kept; the other two are discarded into `_`.
  # NOTE(review): going by the aten.native_batch_norm_backward argument order
  # (grad_out, input, weight, running_mean, running_var, save_mean,
  # save_invstd, train, eps, output_mask), b5 is the train flag, f6 is eps and
  # [b7, b8, b9] is the output mask -- confirm.
  # The nested comments below are the trace's own decomposition of this call.
  (t58, _, _) = torch.torch.ops.aten.native_batch_norm_backward(t31, t4, t5, t1, t2, t7, t8, b5, f6, [b7, b8, b9])
    # (t58, t90, _) = ltorch.batch_norm_backward(t31, t4, t5, t1, t2, t7, t8, b5, f6, [b7, b8, b9])
      # t63 = ltorch.to(t1, dtypes.float32, None, device=None, dtype=None, copy=False)  # t63: "cuda:0 f32[3]"
        # t63 = prims.convert_element_type(t1, dtypes.float32)  # t63: "cuda:0 f32[3]"
      # t64 = ltorch.to(t2, dtypes.float32, None, device=None, dtype=None, copy=False)  # t64: "cuda:0 f32[3]"
        # t64 = prims.convert_element_type(t2, dtypes.float32)  # t64: "cuda:0 f32[3]"
      # t65 = ltorch.to(t7, dtypes.float32, None, device=None, dtype=None, copy=False)  # t65: "cuda:0 f32[3]"
        # t65 = prims.convert_element_type(t7, dtypes.float32)  # t65: "cuda:0 f32[3]"
      # t66 = ltorch.to(t8, dtypes.float32, None, device=None, dtype=None, copy=False)  # t66: "cuda:0 f32[3]"
        # t66 = prims.convert_element_type(t8, dtypes.float32)  # t66: "cuda:0 f32[3]"
      # t67 = ltorch.reshape(t65, [1, 3, 1])  # t67: "cuda:0 f32[1, 3, 1]"
        # t67 = prims.reshape(t65, (1, 3, 1))  # t67: "cuda:0 f32[1, 3, 1]"
      # t68 = ltorch.sum(t31, [0, 2], False, dtype=None)  # t68: "cuda:0 f32[3]"
        # t68 = prims.sum(t31, (0, 2))  # t68: "cuda:0 f32[3]"
      # t70 = ltorch.sub(t4, t67, alpha=None)  # t70: "cuda:0 f32[3, 3, 3]"
        # t69 = prims.broadcast_in_dim(t67, (3, 3, 3), (0, 1, 2))  # t69: "cuda:0 f32[3, 3, 3]"
        # t70 = prims.sub(t4, t69)  # t70: "cuda:0 f32[3, 3, 3]"
      # t71 = ltorch.mul(t31, t70)  # t71: "cuda:0 f32[3, 3, 3]"
        # t71 = prims.mul(t31, t70)  # t71: "cuda:0 f32[3, 3, 3]"
      # t72 = ltorch.sum(t71, [0, 2], False, dtype=None)  # t72: "cuda:0 f32[3]"
        # t72 = prims.sum(t71, (0, 2))  # t72: "cuda:0 f32[3]"
      # t73 = ltorch.mul(t68, 0.1111111111111111)  # t73: "cuda:0 f32[3]"
        # t73 = prims.mul(t68, 0.1111111111111111)  # t73: "cuda:0 f32[3]"
      # t74 = ltorch.reshape(t73, [1, 3, 1])  # t74: "cuda:0 f32[1, 3, 1]"
        # t74 = prims.reshape(t73, (1, 3, 1))  # t74: "cuda:0 f32[1, 3, 1]"
      # t75 = ltorch.mul(t72, 0.1111111111111111)  # t75: "cuda:0 f32[3]"
        # t75 = prims.mul(t72, 0.1111111111111111)  # t75: "cuda:0 f32[3]"
      # t76 = ltorch.mul(t66, t66)  # t76: "cuda:0 f32[3]"
        # t76 = prims.mul(t66, t66)  # t76: "cuda:0 f32[3]"
      # t77 = ltorch.mul(t75, t76)  # t77: "cuda:0 f32[3]"
        # t77 = prims.mul(t75, t76)  # t77: "cuda:0 f32[3]"
      # t78 = ltorch.reshape(t77, [1, 3, 1])  # t78: "cuda:0 f32[1, 3, 1]"
        # t78 = prims.reshape(t77, (1, 3, 1))  # t78: "cuda:0 f32[1, 3, 1]"
      # t79 = ltorch.mul(t66, t5)  # t79: "cuda:0 f32[3]"
        # t79 = prims.mul(t66, t5)  # t79: "cuda:0 f32[3]"
      # t80 = ltorch.reshape(t79, [1, 3, 1])  # t80: "cuda:0 f32[1, 3, 1]"
        # t80 = prims.reshape(t79, (1, 3, 1))  # t80: "cuda:0 f32[1, 3, 1]"
      # t82 = ltorch.sub(t4, t67, alpha=None)  # t82: "cuda:0 f32[3, 3, 3]"
        # t81 = prims.broadcast_in_dim(t67, (3, 3, 3), (0, 1, 2))  # t81: "cuda:0 f32[3, 3, 3]"
        # t82 = prims.sub(t4, t81)  # t82: "cuda:0 f32[3, 3, 3]"
      # t84 = ltorch.mul(t82, t78)  # t84: "cuda:0 f32[3, 3, 3]"
        # t83 = prims.broadcast_in_dim(t78, (3, 3, 3), (0, 1, 2))  # t83: "cuda:0 f32[3, 3, 3]"
        # t84 = prims.mul(t82, t83)  # t84: "cuda:0 f32[3, 3, 3]"
      # t85 = ltorch.sub(t31, t84, alpha=None)  # t85: "cuda:0 f32[3, 3, 3]"
        # t85 = prims.sub(t31, t84)  # t85: "cuda:0 f32[3, 3, 3]"
      # t87 = ltorch.sub(t85, t74, alpha=None)  # t87: "cuda:0 f32[3, 3, 3]"
        # t86 = prims.broadcast_in_dim(t74, (3, 3, 3), (0, 1, 2))  # t86: "cuda:0 f32[3, 3, 3]"
        # t87 = prims.sub(t85, t86)  # t87: "cuda:0 f32[3, 3, 3]"
      # t58 = ltorch.mul(t87, t80)  # t58: "cuda:0 f32[3, 3, 3]"
        # t88 = prims.broadcast_in_dim(t80, (3, 3, 3), (0, 1, 2))  # t88: "cuda:0 f32[3, 3, 3]"
        # t58 = prims.mul(t87, t88)  # t58: "cuda:0 f32[3, 3, 3]"
      # t90 = ltorch.mul(t72, t66)  # t90: "cuda:0 f32[3]"
        # t90 = prims.mul(t72, t66)  # t90: "cuda:0 f32[3]"
  # NOTE(review): 0.1111111111111111 is 1/9, which matches the 9 elements
  # reduced per channel when summing an f32[3, 3, 3] tensor over dims (0, 2).
  # Drop every input of the backward call as soon as it has run.
  del t31, t4, t5, t1, t2, t7, t8, b5, f6, b7, b8, b9
  t61 = Tensor.to(t58, torch.float16, copy=True)  # t61: "cuda:0 f16[3, 3, 3]"
    # t61 = ltorch.to(t58, torch.float16, None, device=None, dtype=None, copy=True)  # t61: "cuda:0 f16[3, 3, 3]"
      # t61 = prims.convert_element_type(t58, dtypes.float16)  # t61: "cuda:0 f16[3, 3, 3]"
  del t58
  return (t61, None, None, None)

kiya00 avatar Mar 21 '24 09:03 kiya00

Hi @IvanYashchuk @mruberry, I think it's ready to merge.

kiya00 avatar Apr 03 '24 09:04 kiya00

Automerge is enabled but an approval from codeowners is required to actually merge. @t-vi could you take a look and merge please?

IvanYashchuk avatar Apr 03 '24 09:04 IvanYashchuk