Introduce distributed embeddings
Part of https://github.com/NVIDIA-Merlin/Merlin/issues/733.
Goals :soccer:
There is a package called distributed-embeddings, a library for building large embedding-based (e.g. recommender) models in TensorFlow. It's an alternative approach to SOK.
This PR introduces DistributedEmbedding for multi-GPU embedding table support.
Implementation Details :construction:
- distributed-embeddings by default will round-robin the entire embedding tables across the GPUs, e.g., the first embedding table on GPU 1, the second one on GPU 2, etc.
- In theory the tables can be sharded by using column_slice, but this has not been tested thoroughly from the Models side.
- Most of the logic is for inferring the embedding size from the schema using the cardinality in int_domain (similarly to the existing EmbeddingTable), determining shapes, and translating a dictionary input into an ordered list input (because distributed-embeddings doesn't support dictionaries yet).
- From the user perspective, they can replace mm.Embeddings with mm.DistributedEmbeddings in their models when they wish to use multi-GPU embedding tables. (See the unit test for DLRM.)
- Depends on upstream fix: https://github.com/NVIDIA-Merlin/distributed-embeddings/pull/6
- Added a GitHub Actions workflow for running unit tests that depend on Horovod.
- distributed-embeddings is for now installed via a script that clones the GitHub repo and installs from source, because there is no PyPI package.
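
For reviewers who want a mental model of the logic described in the list above, here is a minimal pure-Python sketch. It is not the code in this PR and does not call the distributed-embeddings API. All function names are hypothetical, the fourth-root sizing heuristic is an assumption chosen for illustration, and GPU ranks are 0-based here (so "GPU 1" in the description corresponds to rank 0).

```python
import math
from typing import Dict, List, Sequence


def infer_embedding_dim(cardinality: int, multiplier: float = 2.0) -> int:
    """Hypothetical heuristic: dim = ceil(multiplier * cardinality ** 0.25).

    The cardinality would come from the column's int_domain in the schema.
    """
    return int(math.ceil(multiplier * cardinality ** 0.25))


def round_robin_placement(table_names: Sequence[str], num_gpus: int) -> Dict[str, int]:
    """Assign each whole table to a GPU rank in round-robin order (no sharding)."""
    return {name: i % num_gpus for i, name in enumerate(table_names)}


def dict_to_ordered_list(inputs: Dict[str, list], table_names: Sequence[str]) -> List[list]:
    """Convert a feature dict into a list ordered the same way as the tables."""
    missing = [name for name in table_names if name not in inputs]
    if missing:
        raise KeyError(f"Missing inputs for tables: {missing}")
    return [inputs[name] for name in table_names]


# Illustrative schema: feature name -> cardinality (made-up values).
cardinalities = {"user_id": 1000, "item_id": 500, "category": 20}
table_names = list(cardinalities)

dims = {name: infer_embedding_dim(card) for name, card in cardinalities.items()}
placement = round_robin_placement(table_names, num_gpus=2)

# A batch arrives as a dict in arbitrary key order; reorder it to match the tables.
batch = {"category": [3, 1], "user_id": [42, 7], "item_id": [5, 9]}
ordered_inputs = dict_to_ordered_list(batch, table_names)

print(dims)
print(placement)
print(ordered_inputs)
```

The table order is fixed once (here from the schema's dict order) and reused for both placement and input reordering, so the i-th input always feeds the i-th table.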
Testing Details :mag:
Unit tests: tests/unit/tf/horovod/test_embedding.py
Performance tests: TBD
The distributed-embeddings example uses a custom train step function: https://github.com/NVIDIA-Merlin/distributed-embeddings/blob/main/examples/dlrm/main.py#L201-L215
As I understand it, distributed-embeddings does NOT work with the Keras model.fit function: https://github.com/NVIDIA-Merlin/models/pull/974/files#diff-1e42e5c4771f01c26b3c78c545eb341590a4406b2c5af8da0491ab4b7ea51464R80
I think we need the distributed-embeddings team to review the PR.
The PR needs to be updated based on the dataloader changes. There is also a new version of DE. We need to add an integration test as well, to be sure that the functionality is working.
@FDecaYed Hello, do you have any updates on this PR? Thanks.
@rnyak Sorry, this fell off my list. As of now, DE has already added support for model.fit, so some of the problems should be gone. I'm willing to jump in and help if needed.
On the other hand, I'm not familiar with Merlin Models or the dataloader change you mentioned. @edknv do you know what it is, and could you help bring the code up to date?