[QST] How to get candidate features after using pre-trained embedding tables?
❓ Questions & Help
Details
I am following the tutorial here to include pre-computed embeddings when training a Two Tower Retrieval model. Specifically, I am using this method so that the embedding table is not included as part of the model:
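import merlin.models.tf as mm
from merlin.dataloader.ops.embeddings import EmbeddingOperator

loader = mm.Loader(
    train,
    batch_size=1024,
    transforms=[
        # Looks up each movieId in the pre-trained table at load time
        EmbeddingOperator(
            pretrained_movie_embs,
            lookup_key="movieId",
            embedding_name="pretrained_movie_embeddings",
        ),
    ],
)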
I am trying to combine this approach with the Top-K evaluation from the Retrieval Model tutorial here:
from merlin.models.utils.dataset import unique_rows_by_features
from merlin.schema import Tags

# Top-K evaluation
candidate_features = unique_rows_by_features(train, Tags.ITEM, Tags.ITEM_ID)
candidate_features.head()
topk = 20
topk_model = model.to_top_k_encoder(candidate_features, k=topk, batch_size=128)
# we can set the `metrics` param in `compile()` if we want
topk_model.compile(run_eagerly=False)
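For context, a top-k encoder conceptually scores a query embedding against an embedding for every candidate item and keeps the k best; this is why the candidate side needs every feature its tower consumes, including the pre-trained vectors. The sketch below assumes dot-product scoring and already computed candidate vectors; top_k_by_dot_product is a hypothetical helper, not what to_top_k_encoder does internally.

```python
import heapq
from typing import Dict, List, Tuple


def top_k_by_dot_product(
    query: List[float], candidates: Dict[int, List[float]], k: int
) -> List[Tuple[int, float]]:
    """Return the k (item_id, score) pairs with the highest dot-product score."""
    scores = (
        (item_id, sum(q * c for q, c in zip(query, vec)))
        for item_id, vec in candidates.items()
    )
    # nlargest keeps the k best pairs, ordered by descending score
    return heapq.nlargest(k, scores, key=lambda pair: pair[1])


# Toy candidate embeddings (made-up numbers), one vector per item id
candidate_embs = {1: [1.0, 0.0], 2: [0.0, 1.0], 3: [0.75, 0.75]}
print(top_k_by_dot_product([1.0, 0.5], candidate_embs, k=2))  # [(3, 1.125), (1, 1.0)]
```

A candidate that is missing from the candidate set simply can never be retrieved, so the candidate rows must carry the same features the item tower was trained on.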
The problem is that loader.output_schema differs from loader.dataset.schema. The utility function unique_rows_by_features takes a dataset as its first argument, but passing loader.dataset does not work, because that dataset does not yet contain the embedding vectors; they are only added by the loader's transforms.
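The mismatch can be made concrete with toy data: deduplicating item rows on the raw table, roughly what unique_rows_by_features(train, Tags.ITEM, Tags.ITEM_ID) does, yields candidates without the embedding column, and that column only appears after applying the same id-to-vector lookup the loader performs. The sketch below is plain Python; unique_item_rows and attach_embeddings are hypothetical helpers, and whether this pattern can be reproduced with Merlin's own objects is not confirmed here.

```python
from typing import Dict, List


def unique_item_rows(
    rows: List[Dict], id_key: str, item_columns: List[str]
) -> List[Dict]:
    """Keep only the item columns of the first row seen for each item id."""
    seen = set()
    unique = []
    for row in rows:
        if row[id_key] not in seen:
            seen.add(row[id_key])
            unique.append({col: row[col] for col in item_columns})
    return unique


def attach_embeddings(
    rows: List[Dict], table: Dict[int, List[float]], key: str, name: str
) -> List[Dict]:
    """Apply the same id -> vector lookup the training loader applies."""
    return [{**row, name: table[row[key]]} for row in rows]


# Toy interaction rows: several users, repeated movies (made-up data)
train_rows = [
    {"userId": 1, "movieId": 10, "genre": "drama"},
    {"userId": 2, "movieId": 20, "genre": "comedy"},
    {"userId": 3, "movieId": 10, "genre": "drama"},
]
pretrained_movie_embs = {10: [0.1, 0.2], 20: [0.3, 0.4]}

# 1) dedupe item rows on the raw data, 2) add embeddings like the loader does
items = unique_item_rows(train_rows, "movieId", ["movieId", "genre"])
candidate_features = attach_embeddings(
    items, pretrained_movie_embs, "movieId", "pretrained_movie_embeddings"
)
for row in candidate_features:
    print(row["movieId"], row["pretrained_movie_embeddings"])
```

Doing the deduplication before the lookup keeps the lookup cheap, since each item's vector is fetched once rather than once per interaction.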
My question is: when pre-trained embeddings are supplied through the loader as described above, how should one obtain the candidate_features required by the candidate tower?
Thank you in advance for taking the time to answer!
@hkristof03 did you find a solution for this? I am running into the same problem.
@jhnealand In my case the embedding table was small, so I made it part of the model as described here in the first example; with that, the standard evaluation works. I haven't tried to solve the case where the embedding table is not part of the model.
I have managed to make the model agnostic to whether the input is a Dataset or a Loader.
I also got rid of the unique_rows_by_features call.
Work with an object that can be either a Dataset or a Loader, and fetch its schema with a simple function like:
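from typing import Any
from merlin.io import Dataset
from merlin.models.tf import Loader

def get_schema_from_dataset_or_loader(X: Any):
    # A raw Dataset exposes the schema of the stored columns
    if isinstance(X, Dataset):
        return X.schema
    # A Loader exposes the schema after its transforms (e.g. added embeddings)
    if isinstance(X, Loader):
        return X.output_schema
    msg = f"Cannot get a schema from object of type {type(X).__name__}"
    raise AttributeError(msg)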
If you then consistently pass such objects around, dealing with schemas should be no problem.
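To try the dispatch without a Merlin install, the sketch below swaps in minimal stand-in classes: Dataset and Loader here are hypothetical and only carry the attributes the helper reads (schema and output_schema), with schemas simplified to lists of column names.

```python
from typing import Any, List


class Dataset:
    """Hypothetical stand-in for merlin.io.Dataset: raw columns only."""

    def __init__(self, schema: List[str]):
        self.schema = schema


class Loader:
    """Hypothetical stand-in for a Loader whose transforms add columns."""

    def __init__(self, dataset: Dataset, added_columns: List[str]):
        self.dataset = dataset
        self.output_schema = dataset.schema + added_columns


def get_schema_from_dataset_or_loader(X: Any) -> List[str]:
    """Return the schema the model will actually see for X."""
    if isinstance(X, Dataset):
        return X.schema
    if isinstance(X, Loader):
        return X.output_schema
    raise AttributeError(f"Cannot get a schema from object of type {type(X).__name__}")


raw = Dataset(["movieId", "genre"])
loader = Loader(raw, ["pretrained_movie_embeddings"])
print(get_schema_from_dataset_or_loader(raw))     # ['movieId', 'genre']
print(get_schema_from_dataset_or_loader(loader))  # ['movieId', 'genre', 'pretrained_movie_embeddings']
```

Because the Loader branch returns output_schema, downstream code sees the embedding column whenever it is handed a Loader, and only the raw columns when handed a Dataset.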