Remove CPU overheads of torch.cuda.get_device_properties() by caching it
Description
In TE-PyTorch, added a pybind binding for sm_arch that caches the device compute capability, removing the CPU overhead of repeated calls to torch.cuda.get_device_properties().
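As a rough illustration only (the extension-module name and the `tex.sm_arch()` / `current_sm_arch()` names below are assumptions, not the PR's actual bindings), the Python side of such a cached lookup might look like:

```python
import torch
import transformer_engine_torch as tex  # assumed name of TE's pybind extension module


def current_sm_arch() -> int:
    """SM architecture of the current GPU (e.g. 90 for Hopper), served from a cache.

    The assumed tex.sm_arch() binding queries cudaGetDeviceProperties() once per
    device and memoizes the result on the C++ side, so hot-path callers avoid the
    CPU overhead of torch.cuda.get_device_properties().
    """
    return tex.sm_arch(torch.cuda.current_device())
```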
Also fixed the batch_p2p_comm check by making it aware of the device compute capability and the context-parallel (cp) size; a sketch of this check follows.
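A minimal sketch of that check, where the exact condition and the helper name are assumptions for illustration rather than the PR's actual logic:

```python
def use_batch_p2p_comm(cp_size: int, sm_arch: int) -> bool:
    """Decide whether the context-parallel ring exchange batches its p2p send/recv pairs.

    Illustrative rule only: batch the p2p calls for the 2-GPU ring except on
    Hopper (sm 90); the condition used in this PR may differ. The key point is
    that the decision consults the cached compute capability instead of calling
    torch.cuda.get_device_properties() on every step.
    """
    return cp_size == 2 and sm_arch != 90
```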
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or new content)
- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [x] Code refactoring
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
/te-ci pytorch L1
Couldn't we accomplish something similar by wrapping `get_device_compute_capability` with `functools.lru_cache`? If we want it to respect `torch.cuda.current_device`, we could make a helper function:

```python
@functools.lru_cache
def _get_device_compute_capability(device: torch.device) -> Tuple[int, int]:
    props = torch.cuda.get_device_properties(device)
    return (props.major, props.minor)

def get_device_compute_capability() -> Tuple[int, int]:
    """CUDA compute capability of current GPU"""
    return _get_device_compute_capability(torch.cuda.current_device())
```

This wouldn't require changes outside of transformer_engine/pytorch/utils.py.
This could work, but I think leveraging TE's existing functionality is a cleaner fix. It is also consistent with TE-JAX, where the same change has already been made here.
/te-ci pytorch L1