[BUG] Unable to extract session embeddings from a session-based transformer model #163

Open
rnyak opened this issue Aug 17, 2023 · 3 comments · May be fixed by #164
Labels: bug (Something isn't working), P0


@rnyak

rnyak commented Aug 17, 2023

Bug description

I am trying to extract session embeddings, but neither of the following options works.

Option 1:

I tried the following calls, but neither works:

model_transformer.query_embeddings(train, index='session_id')

or

model_transformer.query_embeddings(train, batch_size=1024, index='session_id')

Option 2:

I am able to generate session embeddings for a single batch, but iterating over the loader batch by batch crashes.

This works:
model_transformer.query_encoder(batch[0])

But iterating over the loader batch by batch does not:

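import gc

all_sess_embeddings = []
for batch, _ in iter(loader):
    # Encode one batch of session features into session (query) embeddings.
    embds = model_transformer.query_encoder(batch).numpy()
    # Try to release the batch before the next one is fetched.
    del batch
    gc.collect()
    all_sess_embeddings.append(embds)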
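If a loop like the one above completes, it yields one embedding array per batch, whereas `query_embeddings(train, index='session_id')` is presumably meant to return embeddings indexed by session id. A minimal pure-Python sketch of that last step follows, assuming the session ids of each batch are available in the same order as the encoder's output rows. `index_by_session` and the toy data are hypothetical, not Merlin APIs, and the duplicate-id check assumes one row per session.

```python
from typing import Dict, List, Sequence


def index_by_session(
    batch_ids: Sequence[Sequence[int]],
    batch_embeddings: Sequence[Sequence[Sequence[float]]],
) -> Dict[int, List[float]]:
    """Merge per-batch embeddings into one dict keyed by session id.

    Hypothetical helper: batch_ids[i] holds the session ids of batch i and
    batch_embeddings[i] holds that batch's embedding rows, in the same order.
    """
    if len(batch_ids) != len(batch_embeddings):
        raise ValueError("got a different number of id batches and embedding batches")
    index: Dict[int, List[float]] = {}
    for ids, rows in zip(batch_ids, batch_embeddings):
        if len(ids) != len(rows):
            raise ValueError("a batch has a different number of ids and embedding rows")
        for session_id, row in zip(ids, rows):
            # Assumes one row per session, so a repeated id signals a problem.
            if session_id in index:
                raise ValueError(f"duplicate session id: {session_id}")
            index[session_id] = list(row)
    return index


# Toy data: two batches of 2-dimensional embeddings.
ids = [[10, 11], [12]]
embeddings = [[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6]]]
print(index_by_session(ids, embeddings))
```

With the NumPy arrays produced by `.numpy()` in the loop, `np.concatenate(all_sess_embeddings, axis=0)` would instead stack the per-batch arrays into a single matrix without attaching ids.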

Steps/Code to reproduce bug

The code to reproduce the issue is in this gist:

https://gist.github.com/rnyak/d70822084c26ba6972615512e8a78bb2

Expected behavior

We should be able to extract session embeddings from the query_model of the transformer model without errors.

Environment details

  • Merlin version:
  • Platform:
  • Python version:
  • PyTorch version (GPU?):
  • TensorFlow version (GPU?): Using the TensorFlow 23.06 image with the latest branches pulled.
@rnyak added the bug (Something isn't working) and P0 labels Aug 17, 2023
@rnyak added this to the Merlin 23.07 milestone Aug 17, 2023
@rnyak
Author

rnyak commented Sep 18, 2023

The issue is still open; it has not been resolved yet.

@rnyak
Author

rnyak commented Oct 10, 2023

A workaround is something like the following:

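import gc

import tensorflow as tf

# Copy every feature tensor of every batch into a fresh tf.constant up front.
batches = [{k: tf.constant(v.numpy()) for k, v in batch[0].items()} for batch in loader]
all_sess_embeddings = []
for batch in batches:
    embds = model_transformer.query_encoder(batch).numpy()
    del batch  # `batches` still references this batch; only the loop name is dropped
    gc.collect()
    all_sess_embeddings.append(embds)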
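The workaround can be factored into a small helper that keeps its two phases: copy every feature of every batch out of the loader, then run the encoder on the copies. This is a sketch under stated assumptions, not a fix; `encode_detached`, `copy_value`, and `encode` are hypothetical names, and each loader item is assumed to be a (features, targets) pair, as the workaround's `batch[0]` suggests. In this issue's setup one would presumably pass `copy_value=lambda v: tf.constant(v.numpy())` and `encode=lambda f: model_transformer.query_encoder(f).numpy()`; the toy run below uses plain Python lists so it depends on neither TensorFlow nor Merlin.

```python
from typing import Any, Callable, Dict, Iterable, List, Sequence

Features = Dict[str, Any]


def encode_detached(
    loader: Iterable[Sequence[Any]],
    copy_value: Callable[[Any], Any],
    encode: Callable[[Features], Any],
) -> List[Any]:
    """Copy every batch's features out of the loader, then encode the copies.

    Each loader item is assumed to be a (features, targets) pair. Like the
    workaround, all copied batches are held in memory before encoding starts.
    """
    # Phase 1: drain the loader, copying each feature value of item[0].
    batches = [
        {name: copy_value(value) for name, value in item[0].items()}
        for item in loader
    ]
    # Phase 2: encode the detached copies one batch at a time.
    return [encode(features) for features in batches]


# Toy stand-ins: a "loader" of (features, targets) pairs and a fake encoder
# that sums each batch's session ids.
toy_loader = [({"session_id": [1, 2]}, None), ({"session_id": [5]}, None)]
print(encode_detached(toy_loader, copy_value=list, encode=lambda f: sum(f["session_id"])))
```

Collecting all copies before encoding mirrors the snippet above; whether copying and encoding one batch at a time would also avoid the crash is not established in this thread.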

@rnyak
Author

rnyak commented Oct 31, 2023

This should be fixed in the dataloader.

@rnyak transferred this issue from NVIDIA-Merlin/models Oct 31, 2023
@rnyak assigned jperez999 and oliverholworthy and unassigned marcromeyn Oct 31, 2023
@jperez999 linked a pull request Nov 1, 2023 that will close this issue

4 participants