
[BUG] Unable to extract session embeddings from a session-based transformer model

Open rnyak opened this issue 2 years ago • 3 comments

Bug description

I am trying to extract session embeddings, but neither of the following options works.

Option 1:

I tried these calls, but neither works:

model_transformer.query_embeddings(train, index='session_id')

or 

model_transformer.query_embeddings(train, batch_size = 1024,  index='session_id')

Option 2:

I am able to generate session embeddings for a single batch, but it crashes when I iterate over the loader batch by batch.

This works: model_transformer.query_encoder(batch[0])

But iterating over the loader batch by batch does not:

all_sess_embeddings = []
for batch, _ in iter(loader):
    embds = model_transformer.query_encoder(batch).numpy()
    del batch
    gc.collect()
    all_sess_embeddings.append(embds)
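One plausible explanation (an assumption, not confirmed from the dataloader source) is that the loader reuses or releases the tensors it yields once the next batch is requested, so any reference held across iterations goes stale. The toy NumPy loader below reproduces that aliasing pattern and shows why copying each batch, as the workaround further down does with tf.constant(v.numpy()), detaches the result from the loader's buffer:

```python
import numpy as np

def reusing_loader(data, batch_size):
    """Toy loader that, like some dataloaders, reuses one output buffer."""
    buf = np.empty(batch_size, dtype=data.dtype)
    for start in range(0, len(data), batch_size):
        buf[:] = data[start:start + batch_size]
        yield buf  # every iteration yields the SAME array object

data = np.arange(8, dtype=np.float32)

# Naive collection keeps references to the shared buffer: every entry ends
# up aliasing the last batch written.
naive = [batch for batch in reusing_loader(data, 4)]

# Copying each batch detaches it from the loader's buffer, so each entry
# keeps its own values.
copied = [batch.copy() for batch in reusing_loader(data, 4)]
```

Here `naive[0]` and `naive[1]` are the same array object holding the last batch, while `copied` preserves both batches intact.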

Steps/Code to reproduce bug

Please go to this link to download the gist for the code to repro the issue:

https://gist.github.com/rnyak/d70822084c26ba6972615512e8a78bb2

Expected behavior

We should be able to extract session embeddings from the transformer model's query encoder without any issues.

Environment details

  • Merlin version:
  • Platform:
  • Python version:
  • PyTorch version (GPU?):
  • Tensorflow version (GPU?): Using tensorflow 23.06 image with the latest branches pulled.

rnyak avatar Aug 17 '23 20:08 rnyak

This issue is still open; it has not been resolved yet.

rnyak avatar Sep 18 '23 15:09 rnyak

A workaround would be to first copy every batch out of the loader with tf.constant, then encode the copies:

batches = [{k:tf.constant(v.numpy()) for k, v in batch[0].items()} for batch in loader]
all_sess_embeddings = []
for batch in batches:
    embds = model_transformer.query_encoder(batch).numpy()
    del batch
    gc.collect()
    all_sess_embeddings.append(embds)
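After the loop, all_sess_embeddings is a list of per-batch arrays. A minimal sketch of combining them into a single (n_sessions, embedding_dim) matrix, using hypothetical shapes in place of the real query_encoder output:

```python
import numpy as np

# Hypothetical per-batch embeddings standing in for query_encoder output;
# each array has shape (batch_size, embedding_dim).
all_sess_embeddings = [
    np.ones((4, 8), dtype=np.float32),
    np.ones((3, 8), dtype=np.float32),
]

# Concatenate along the batch axis to get one row per session.
embeddings = np.concatenate(all_sess_embeddings, axis=0)
```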

rnyak avatar Oct 10 '23 17:10 rnyak

This should be fixed in the dataloader.

rnyak avatar Oct 31 '23 14:10 rnyak