Move to the correct device for v1 state dict

Open acphile opened this issue 1 year ago • 0 comments

This PR moves the attributes of DistributedFusedAdam to the correct device when a v1 state dict is loaded.

After loading a v1 state dict, the tensors in DistributedFusedAdam.state["buckets"] are on the CPU. In NeMo, the optimizer state is supposed to be moved to the target CUDA device by _optimizer_to_device in pytorch_lightning. For DistributedFusedAdam this does not work as intended: with the v1 format, tensors such as DistributedFusedAdam.state["buckets"][0].param_remainders_shard are not moved to the correct device.

This PR aims to fix it.
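
To illustrate the kind of failure described above, here is a minimal sketch (not the actual change in this PR) of a device move that also descends into tensor attributes of plain objects. `FakeBucket` and `move_to_device` are hypothetical names invented for this example, and the assumption is that each bucket is an ordinary Python object whose tensors live in attributes such as `param_remainders_shard`.

```python
import torch


class FakeBucket:
    """Hypothetical stand-in for an optimizer state bucket (not apex's class)."""

    def __init__(self):
        self.params_shard = torch.zeros(4)
        self.param_remainders_shard = torch.zeros(4, dtype=torch.int16)
        self.note = "not a tensor"


def move_to_device(obj, device):
    """Recursively move every tensor reachable from obj to device.

    Handles tensors, dicts, lists, tuples, and plain objects with a
    __dict__. Plain objects are updated in place; containers are rebuilt.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(v, device) for v in obj)
    if hasattr(obj, "__dict__"):
        # Tensors stored as attributes are the case a helper that only
        # walks dicts and sequences would miss.
        for name, value in list(vars(obj).items()):
            setattr(obj, name, move_to_device(value, device))
        return obj
    return obj  # ints, strings, None, etc. are returned unchanged
```

With a structure like `{"buckets": [FakeBucket()]}`, calling `move_to_device(state, torch.device("cuda"))` would move `param_remainders_shard` along with the other tensors, assuming a CUDA device is available.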

acphile, Mar 15 '24 23:03