[Feature] Add bfloat16 (bf16) support

Open yaox12 opened this issue 3 years ago • 11 comments

Description

Add bfloat16 support for CUDA >= 11.0.

  1. Add bf16 specializations for supported functions.
  2. Change the float type dispatcher from bits to real data types.
  3. Make PyTorch custom autograd modules work with both bf16 and fp16 (see the sketch after the note below).
  4. Enable bf16 tests for segment_mm and gather_mm.

Note: the -DUSE_FP16 build flag is removed. FP16 is now enabled by default, while BF16 is enabled for CUDA >= 11.0, following the PyTorch convention.
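
To illustrate item 3, here is a minimal sketch (illustrative only, not code from this PR) of a dtype-agnostic custom autograd function: it returns gradients in whichever reduced-precision dtype the inputs use, fp16 or bf16, instead of hard-coding torch.half.

import torch

class ScaledSum(torch.autograd.Function):
    """Toy op: y = (x * w).sum(-1). Accepts fp32, fp16, and bf16 inputs."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return (x * w).sum(-1)

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        g = grad_out.unsqueeze(-1)
        # Return gradients in the inputs' own dtype (torch.float16 or
        # torch.bfloat16) rather than assuming torch.half.
        return (g * w).to(x.dtype), (g * x).to(w.dtype)

x = torch.rand(4, 8, dtype=torch.bfloat16, requires_grad=True)
w = torch.rand(4, 8, dtype=torch.bfloat16, requires_grad=True)
ScaledSum.apply(x, w).sum().backward()
assert x.grad.dtype == torch.bfloat16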

Checklist

Please feel free to remove inapplicable items for your PR.

  • [x] The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature])
  • [x] Changes are complete (i.e. I finished coding on this PR)
  • [x] All changes have test coverage
  • [x] Code is well-documented
  • [x] To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change
  • [x] Related issue is referred in this PR

Examples

  1. Message passing with bf16 is supported for CUDA >= 11.0 (even on older GPUs with SM < 80). For example, the following code works on a V100.
>>> import torch
>>> import dgl
>>> import dgl.function as fn
>>> dev = torch.device('cuda')
>>> g = dgl.rand_graph(30, 100).to(dev)  # Create a graph on GPU w/ 30 nodes and 100 edges.
>>> g.ndata['h'] = torch.rand(30, 16).to(dev, torch.bfloat16)  # Create bf16 node features.
>>> g.edata['w'] = torch.rand(100, 1).to(dev, torch.bfloat16)  # Create bf16 edge features.
>>> # Use DGL's built-in functions for message passing on bf16 features.
>>> g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'x'))
>>> g.ndata['x'].dtype
torch.bfloat16
>>> g.apply_edges(fn.u_dot_v('h', 'x', 'hx'))
>>> g.edata['hx'].dtype
torch.bfloat16
  2. AMP only works on GPUs with SM >= 80 and CUDA >= 11.0.

This also follows the PyTorch convention. See https://github.com/dmlc/dgl/issues/4333#issuecomment-1253277674.
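
As a reference point, here is a minimal bf16 autocast sketch in the spirit of the DGL mixed-precision guide (assumptions: a GPU with SM >= 80 and a PyTorch version where torch.autocast accepts dtype=torch.bfloat16; which ops actually run in bf16 is decided by the autocast integration):

import torch
import dgl
import dgl.function as fn

dev = torch.device('cuda')
g = dgl.rand_graph(30, 100).to(dev)
g.ndata['h'] = torch.rand(30, 16, device=dev)   # plain fp32 features
g.edata['w'] = torch.rand(100, 1, device=dev)

# Unlike fp16 AMP, bf16 autocast needs no GradScaler: bf16 has the same
# exponent range as fp32, so loss scaling is unnecessary.
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'x'))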

yaox12 avatar Sep 27 '22 02:09 yaox12

To trigger regression tests:

  • @dgl-bot run [instance-type] [which tests] [compare-with-branch], for example: @dgl-bot run g4dn.4xlarge all dmlc/master or @dgl-bot run c5.9xlarge kernel,api dmlc/master

dgl-bot avatar Sep 27 '22 02:09 dgl-bot

Commit ID: 213b27ce5e5aec17b70af3c5ce56e3d7abfbda0f

Build ID: 1

Status: ❌ CI test failed in Stage [Torch CPU (Win64) Unit test].

Report path: link

Full logs path: link

dgl-bot avatar Sep 27 '22 03:09 dgl-bot

Commit ID: 37ed78035f406940ed88a4b60cadd2b7cfde5fea

Build ID: 2

Status: ❌ CI test failed in Stage [Torch CPU (Win64) Unit test].

Report path: link

Full logs path: link

dgl-bot avatar Sep 27 '22 05:09 dgl-bot

Commit ID: e05448d0bc6e393a2a75845dbfc8242334702189

Build ID: 3

Status: ❌ CI test failed in Stage [Torch CPU (Win64) Unit test].

Report path: link

Full logs path: link

dgl-bot avatar Sep 27 '22 07:09 dgl-bot

Commit ID: eb43e9ea5d0254954d7b19fde1e02a0e86707359

Build ID: 4

Status: ✅ CI test succeeded

Report path: link

Full logs path: link

dgl-bot avatar Sep 27 '22 10:09 dgl-bot

Thanks for the great work. Even with several comments, this PR still holds a high standard of code quality :) One suggestion for the future: avoid such a big PR where possible; separate the refactoring from the new code so that it is easier to review.

frozenbugs avatar Sep 30 '22 04:09 frozenbugs

Have you done any regression tests?

Not yet. It seems I don't have permission to trigger regression tests, and I failed to run them locally.

How does the performance compare to fp16?

For the AMP example in https://docs.dgl.ai/en/0.9.x/guide/mixed_precision.html, bf16 performs similarly to fp16, and both are slightly faster than fp32. An advantage of bf16 is that it supports pure-bf16 training without AMP; thus the example in https://github.com/dmlc/dgl/pull/4262 converges with bf16 but not with fp16.
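
For concreteness, here is a hedged sketch of what pure-bf16 training looks like: cast the model and the features to bf16 and train without autocast or GradScaler (the toy GraphConv model and random data are illustrative, not taken from the linked example).

import torch
import torch.nn.functional as F
import dgl
from dgl.nn import GraphConv

dev = torch.device('cuda')
g = dgl.add_self_loop(dgl.rand_graph(30, 100)).to(dev)
feat = torch.rand(30, 16, device=dev, dtype=torch.bfloat16)
labels = torch.randint(0, 4, (30,), device=dev)

model = GraphConv(16, 4).to(dev).to(torch.bfloat16)  # bf16 weights
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

for _ in range(5):
    logits = model(g, feat)
    # Compute the loss in fp32 for stability; forward, backward, and the
    # weights themselves all stay in bf16.
    loss = F.cross_entropy(logits.float(), labels)
    opt.zero_grad()
    loss.backward()
    opt.step()

Because bf16 keeps fp32's exponent range, gradients rarely overflow, which is why this can converge where pure-fp16 training does not.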

yaox12 avatar Oct 10 '22 09:10 yaox12

Commit ID: 797c6e6945df04c938435cbe2db4777366b9132d

Build ID: 5

Status: ✅ CI test succeeded

Report path: link

Full logs path: link

dgl-bot avatar Oct 10 '22 11:10 dgl-bot

Commit ID: e048c4188b3432f1f695753c780c00982ac41621

Build ID: 6

Status: ✅ CI test succeeded

Report path: link

Full logs path: link

dgl-bot avatar Oct 10 '22 11:10 dgl-bot

Commit ID: 03c24fb5d00395ce0740073138b20fa50bb1f71b

Build ID: 8

Status: ✅ CI test succeeded

Report path: link

Full logs path: link

dgl-bot avatar Oct 11 '22 08:10 dgl-bot

Commit ID: 5b64970958cfdc236925b8df5bcb0849387722a7

Build ID: 9

Status: ✅ CI test succeeded

Report path: link

Full logs path: link

dgl-bot avatar Oct 14 '22 10:10 dgl-bot

Commit ID: 34d910cf39ade63255e6f2a8529de90f592a332a

Build ID: 10

Status: ✅ CI test succeeded

Report path: link

Full logs path: link

dgl-bot avatar Nov 06 '22 11:11 dgl-bot