[Feature] Add bfloat16 (bf16) support
Description
Add bfloat16 support for CUDA >= 11.0.
- Add bf16 specializations for supported functions.
- Change the float type dispatcher to dispatch on actual data types instead of bit widths (fp16 and bf16 are both 16-bit, so bit width alone is ambiguous).
- Make PyTorch custom autograd modules work with both bf16 and fp16 (see the sketch after this list).
- Enable bf16 tests for segment_mm and gather_mm.
Note: -DUSE_FP16 is removed. FP16 is enabled by default while BF16 is enabled for CUDA >= 11.0, which follows the PyTorch convention.
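As an illustration of the custom-autograd item above, here is a minimal sketch (a hypothetical MulExample op, not DGL's actual code): by not hard-coding cast_inputs=torch.float16 in custom_fwd, the same Function can run with either fp16 or bf16 tensors.
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class MulExample(torch.autograd.Function):
    # Hypothetical op for illustration only.
    @staticmethod
    @custom_fwd  # no cast_inputs=torch.float16, so bf16 inputs are not forced to fp16
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        return grad_out * b, grad_out * a
MulExample.apply works under torch.cuda.amp.autocast(dtype=torch.bfloat16) as well as with plain bf16 or fp16 tensors.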
Checklist
Please feel free to remove inapplicable items for your PR.
- [x] The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature])
- [x] Changes are complete (i.e. I finished coding on this PR)
- [x] All changes have test coverage
- [x] Code is well-documented
- [x] To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change
- [x] Related issue is referred to in this PR
Examples
- Message-passing with bf16 is supported for CUDA >= 11.0 (even on older GPUs with SM < 80). For example, the following code works well on V100.
>>> import torch
>>> import dgl
>>> import dgl.function as fn
>>> dev = torch.device('cuda')
>>> g = dgl.rand_graph(30, 100).to(dev) # Create a graph on GPU w/ 30 nodes and 100 edges.
>>> g.ndata['h'] = torch.rand(30, 16).to(dev, torch.bfloat16) # Create bf16 node features.
>>> g.edata['w'] = torch.rand(100, 1).to(dev, torch.bfloat16) # Create bf16 edge features.
>>> # Use DGL's built-in functions for message passing on bf16 features.
>>> g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'x'))
>>> g.ndata['x'].dtype
torch.bfloat16
>>> g.apply_edges(fn.u_dot_v('h', 'x', 'hx'))
>>> g.edata['hx'].dtype
torch.bfloat16
- AMP only works on GPUs with SM >= 80 and CUDA >= 11.0.
This also follows the PyTorch convention. See https://github.com/dmlc/dgl/issues/4333#issuecomment-1253277674.
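For reference, a minimal bf16 AMP sketch (assuming dgl.nn.GraphConv and random data; requires SM >= 80 and CUDA >= 11.0) looks like the following. No GradScaler is needed because bf16 keeps fp32's dynamic range.
import torch
import dgl
from dgl.nn import GraphConv

dev = torch.device('cuda')
g = dgl.add_self_loop(dgl.rand_graph(30, 100)).to(dev)  # self-loops avoid 0-in-degree errors
feat = torch.rand(30, 16, device=dev)                   # fp32 features; autocast downcasts inside
model = GraphConv(16, 8).to(dev)                        # fp32 parameters
opt = torch.optim.Adam(model.parameters())
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    loss = model(g, feat).sum()                         # matmuls run in bf16
loss.backward()                                         # no GradScaler needed for bf16
opt.step()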
To trigger regression tests:
- @dgl-bot run [instance-type] [which tests] [compare-with-branch]; for example: @dgl-bot run g4dn.4xlarge all dmlc/master or @dgl-bot run c5.9xlarge kernel,api dmlc/master
Commit ID: 213b27ce5e5aec17b70af3c5ce56e3d7abfbda0f
Build ID: 1
Status: ❌ CI test failed in Stage [Torch CPU (Win64) Unit test].
Report path: link
Full logs path: link
Commit ID: 37ed78035f406940ed88a4b60cadd2b7cfde5fea
Build ID: 2
Status: ❌ CI test failed in Stage [Torch CPU (Win64) Unit test].
Report path: link
Full logs path: link
Commit ID: e05448d0bc6e393a2a75845dbfc8242334702189
Build ID: 3
Status: ❌ CI test failed in Stage [Torch CPU (Win64) Unit test].
Report path: link
Full logs path: link
Commit ID: eb43e9ea5d0254954d7b19fde1e02a0e86707359
Build ID: 4
Status: ✅ CI test succeeded
Report path: link
Full logs path: link
Thanks for the great work! Even with several comments, this PR still holds a high standard of code quality :) One more suggestion for the future: avoid such a big PR if possible; separating the refactoring from the new code makes it easier to review.
Have you done any regression tests?
Not yet. It seems I don't have the authorization to trigger regression tests, and I failed to run them locally.
How does the performance compare to fp16?
For the AMP example in https://docs.dgl.ai/en/0.9.x/guide/mixed_precision.html, bf16 performs similarly to fp16; both are slightly faster than fp32. An advantage of bf16 is that it supports pure-bf16 training without AMP, so the example in https://github.com/dmlc/dgl/pull/4262 can converge with bf16 but not with fp16.
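For reference, a minimal sketch of pure-bf16 training without AMP (a hypothetical GraphConv example, not the code of the linked PR): cast both the parameters and the features to bfloat16 and train without autocast or GradScaler.
import torch
import dgl
from dgl.nn import GraphConv

dev = torch.device('cuda')
g = dgl.add_self_loop(dgl.rand_graph(30, 100)).to(dev)
feat = torch.rand(30, 16).to(dev, torch.bfloat16)       # bf16 features
model = GraphConv(16, 8).to(dev).to(torch.bfloat16)     # bf16 parameters
opt = torch.optim.Adam(model.parameters())
model(g, feat).sum().backward()                         # forward and backward run in bf16
opt.step()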
Commit ID: 797c6e6945df04c938435cbe2db4777366b9132d
Build ID: 5
Status: ✅ CI test succeeded
Report path: link
Full logs path: link
Commit ID: e048c4188b3432f1f695753c780c00982ac41621
Build ID: 6
Status: ✅ CI test succeeded
Report path: link
Full logs path: link
Commit ID: 03c24fb5d00395ce0740073138b20fa50bb1f71b
Build ID: 8
Status: ✅ CI test succeeded
Report path: link
Full logs path: link