
Remove CPU overheads of torch.cuda.get_device_properties() by caching it

Open xrennvidia opened this pull request 9 months ago • 3 comments

Description

In TE-PyTorch, added a pybind of sm_arch to cache the device compute capability, removing the CPU overhead of repeated calls to torch.cuda.get_device_properties().
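
For context, here is a rough micro-benchmark (illustrative only, not part of the PR; absolute timings vary by system) showing the overhead being removed: torch.cuda.get_device_properties() goes through the Python/CUDA driver bindings on every call, while a cached value is a plain attribute lookup.

import time
import torch

def avg_call_time(fn, n=10_000):
    # Average wall-clock time per call over n invocations
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - start) / n

device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device)  # fetched once, then reused

print("uncached:", avg_call_time(lambda: torch.cuda.get_device_properties(device)))
print("cached:  ", avg_call_time(lambda: (props.major, props.minor)))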

Also fixed the batch_p2p_comm check by making it aware of the device compute capability and the context-parallel (CP) size.
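
The exact condition is not spelled out here; purely as an illustrative sketch (the function name, thresholds, and gating logic below are hypothetical, not the actual TE change), a check that takes both the compute capability and the CP size into account could look like:

from typing import Tuple

import torch

def get_device_compute_capability() -> Tuple[int, int]:
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return (props.major, props.minor)

def use_batch_p2p_comm(requested: bool, cp_size: int) -> bool:
    # Hypothetical gating: batch the point-to-point sends/recvs only when
    # the caller asked for it, there is more than one concurrent P2P
    # exchange per rank, and the architecture is assumed to handle
    # batched P2P well.
    sm_major, _ = get_device_compute_capability()
    return requested and cp_size > 2 and sm_major >= 8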

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or new content)
  • [x] Bug fix (non-breaking change which fixes an issue)
  • [ ] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [ ] Infra/Build change
  • [x] Code refactoring

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [ ] I have commented my code, particularly in hard-to-understand areas
  • [ ] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [x] New and existing unit tests pass locally with my changes

xrennvidia commented Apr 26 '25 01:04

/te-ci pytorch L1

xrennvidia commented Apr 26 '25 01:04

/te-ci pytorch L1

xrennvidia commented Apr 26 '25 02:04

/te-ci pytorch L1

xrennvidia commented Apr 26 '25 09:04

Couldn't we accomplish something similar by wrapping get_device_compute_capability with functools.lru_cache? If we want it to respect torch.cuda.current_device, we could make a helper function:

import functools
from typing import Tuple

import torch

@functools.lru_cache(maxsize=None)
def _get_device_compute_capability(device: int) -> Tuple[int, int]:
    # torch.cuda.current_device() returns an int index, so cache on that
    props = torch.cuda.get_device_properties(device)
    return (props.major, props.minor)

def get_device_compute_capability() -> Tuple[int, int]:
    """CUDA compute capability of the current GPU"""
    return _get_device_compute_capability(torch.cuda.current_device())

This wouldn't require changes outside of transformer_engine/pytorch/utils.py.
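
An illustrative usage check (assuming a CUDA device is available and the helper above is defined): repeated calls return the cached tuple without touching the driver again.

major, minor = get_device_compute_capability()
print(f"sm_{major}{minor}")  # e.g. sm_90 on Hopper
get_device_compute_capability()  # second call is served from the cache
print(_get_device_compute_capability.cache_info())  # hits grow on reuse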

This could work, but I think leveraging TE's existing functionality is a cleaner fix. It also keeps TE-PyTorch consistent with TE-JAX, where the same thing has already been done here.

xrennvidia commented Apr 28 '25 18:04

/te-ci pytorch L1

xrennvidia commented Apr 28 '25 20:04

/te-ci pytorch L1

xrennvidia commented Apr 29 '25 01:04