Add composability for ShardedEBC and FusedShardedEBC
Summary: The current sharded EmbeddingBagCollection/FusedEmbeddingBagCollection modules are only usable through the DistributedModelParallel (DMP) wrapper, which overrides common torch.nn.Module APIs such as named_parameters() and state_dict(). This makes these modules inflexible, and sometimes unusable without DMP. We address this by registering torch.nn.Module state on top of empty modules and letting PyTorch's native calls generate state_dict()/named_parameters(), as proposed in https://github.com/pytorch/torchrec/issues/528.
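A minimal sketch of the "state on empty modules" idea, assuming plain dense tensors stand in for ShardedTensor so it runs in one process. The class name `ComposableEBCSketch` is hypothetical, and the `embedding_bags.<table>.weight` key layout is assumed to mirror the unsharded module. Nothing overrides state_dict() or named_parameters(); nn.Module's own traversal produces the keys.

```python
from typing import Dict

import torch
from torch import nn


class ComposableEBCSketch(nn.Module):
    """Hypothetical sketch: no state_dict()/named_parameters() overrides.

    Each table's weight is registered on an otherwise empty nn.Module, so
    nn.Module's own traversal yields "embedding_bags.<table>.weight" keys.
    Plain tensors stand in for ShardedTensor so this runs in one process.
    """

    def __init__(self, tables: Dict[str, torch.Tensor]) -> None:
        super().__init__()
        self.embedding_bags = nn.ModuleDict()
        for name, weight in tables.items():
            holder = nn.Module()  # empty module that only carries state
            holder.register_parameter("weight", nn.Parameter(weight))
            self.embedding_bags[name] = holder

    def forward(
        self, table: str, indices: torch.Tensor, offsets: torch.Tensor
    ) -> torch.Tensor:
        weight = self.embedding_bags[table].weight
        return nn.functional.embedding_bag(indices, weight, offsets, mode="sum")


if __name__ == "__main__":
    sketch = ComposableEBCSketch({"t0": torch.randn(4, 2), "t1": torch.randn(3, 2)})
    # An unsharded reference built from nn.EmbeddingBag has the same keys,
    # so its state_dict() loads into the sketch with no custom logic.
    reference = nn.Module()
    reference.embedding_bags = nn.ModuleDict(
        {"t0": nn.EmbeddingBag(4, 2, mode="sum"), "t1": nn.EmbeddingBag(3, 2, mode="sum")}
    )
    sketch.load_state_dict(reference.state_dict())
    print(list(sketch.state_dict().keys()))
```

Because the keys come from ordinary module registration, any wrapper or checkpointing utility that relies on the standard nn.Module APIs sees the same structure it would see for the unsharded module.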
In addition, the named_parameters() semantics of the current sharded modules are not ideal: on a given rank, named_parameters() returns only the parameters for which that rank holds shards. One implication of this is that the keys of named_parameters() differ from rank to rank. We solve this by creating the ShardedTensor parameters in an SPMD fashion using ShardedTensor.init_from_local_shards.
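To make the per-rank key difference concrete, here is a minimal pure-Python model (no torch.distributed) of the two registration schemes. The plan format, `local_only_keys`, `spmd_registration`, and the key layout are illustrative assumptions. What it captures is the SPMD requirement: init_from_local_shards must be called on all ranks, so each rank walks every table in the same order, and a rank without a shard for a table contributes zero local shards.

```python
from typing import Dict, List, Tuple

# Hypothetical sharding plan: table name -> ranks that own a shard of it.
Plan = Dict[str, List[int]]


def _key(table: str) -> str:
    return f"embedding_bags.{table}.weight"


def local_only_keys(plan: Plan, rank: int) -> List[str]:
    # Previous behaviour: a rank registers only the tables it has shards for,
    # so the resulting named_parameters() keys depend on the rank.
    return [_key(t) for t, ranks in plan.items() if rank in ranks]


def spmd_registration(plan: Plan, rank: int) -> List[Tuple[str, int]]:
    # SPMD behaviour: every rank visits every table in the same order (as a
    # collective constructor requires) and registers a parameter for it,
    # contributing however many local shards it owns, possibly zero.
    return [(_key(t), ranks.count(rank)) for t, ranks in plan.items()]


if __name__ == "__main__":
    plan: Plan = {"t0": [0], "t1": [1], "t2": [0, 1]}
    for rank in (0, 1):
        spmd_keys = [key for key, _ in spmd_registration(plan, rank)]
        print(rank, local_only_keys(plan, rank), spmd_keys)
```

With the local-only scheme, ranks 0 and 1 report different key lists; with the SPMD scheme both report all three tables, and only the local shard counts differ.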
Relatedly, the current EmbeddingBagCollection's named_parameters() DOES NOT include parameters that are updated via optimizer fusion. However, named_parameters() should, by definition, return all parameters. The torch.nn.Module API calls should look nearly identical between the unsharded and sharded versions.
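A minimal sketch of the intended semantics, assuming plain SGD as the fused optimizer: the weight is updated during backward yet remains a registered nn.Parameter, so named_parameters() still lists it. The class name and the grad-hook mechanism are illustrative assumptions; a fused embedding kernel would typically perform the update inside its own backward pass rather than in a Python hook.

```python
import torch
from torch import nn


class FusedSGDEmbeddingBagSketch(nn.Module):
    """Hypothetical module whose weight is updated during backward, emulating
    optimizer fusion, while remaining a registered nn.Parameter so that
    named_parameters() and state_dict() still report it."""

    def __init__(self, num_embeddings: int, dim: int, lr: float = 0.1) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_embeddings, dim))
        self.lr = lr
        # Stand-in for a fused kernel: apply SGD as soon as the gradient is
        # produced, then pass autograd a zero gradient so nothing is left for
        # an external optimizer to apply.
        self.weight.register_hook(self._fused_sgd_step)

    def _fused_sgd_step(self, grad: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            self.weight.sub_(self.lr * grad)
        return torch.zeros_like(grad)

    def forward(self, indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        return nn.functional.embedding_bag(indices, self.weight, offsets, mode="sum")


if __name__ == "__main__":
    module = FusedSGDEmbeddingBagSketch(num_embeddings=4, dim=2)
    print([name for name, _ in module.named_parameters()])  # -> ['weight']
    module(torch.tensor([0, 1, 2]), torch.tensor([0, 2])).sum().backward()
```

Whether a parameter's update happens in backward or in a separate optimizer step is an implementation detail of the kernel; it does not change whether the tensor is a parameter of the module.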
Another point of difference is that the parameters of tables sharded as data_parallel rely on DistributedModelParallel's DistributedDataParallel wrapper to register reduction hooks on their gradients. Here we register these hooks explicitly, so that the modules are truly composable with other DDP wrappers. One implication of the previous approach is that if an FSDP wrapper is passed to DistributedModelParallel, tables registered as data_parallel will actually be FSDP-parallel.
Because the many "small" changes listed above together pose a significant risk to the backwards compatibility of DistributedModelParallel, we are developing these new sharded modules, with modular, native composability, in a separate code path as described in https://docs.google.com/document/d/15r5XxMTeVC90kA0u8t006-KLLeh9MZmL9DqX_F4VLxU/edit#. Since we are creating a new API with these modules, we will also enforce the separation of sharding from compute_kernel/fusion at the user API level. The fused EBC module's changes are made in modular_fused_embeddingbag.py.
Differential Revision: D38190302
This pull request was exported from Phabricator. Differential Revision: D38190302
This pull request has been reverted by b6b3466dfc05426948ec14f34129dafd57bd53d9.