Add topk arg to return topk items and scores at inference step #678
Changes from 5 commits
New test file (`@@ -0,0 +1,32 @@`):

```python
import torch

import transformers4rec.torch as tr
from transformers4rec.config import transformer as tconf


def test_torchscript_with_topk(torch_yoochoose_like, yoochoose_schema):
    input_module = tr.TabularSequenceFeatures.from_schema(
        yoochoose_schema,
        max_sequence_length=20,
        d_output=64,
        masking="causal",
    )
    prediction_task = tr.NextItemPredictionTask(weight_tying=True)
    transformer_config = tconf.XLNetConfig.build(
        d_model=64, n_head=8, n_layer=2, total_seq_length=20
    )
    model = transformer_config.to_torch_model(input_module, prediction_task)

    _ = model(torch_yoochoose_like, training=False)

    topk = 10
    model.top_k = topk
    model.eval()

    traced_model = torch.jit.trace(model, torch_yoochoose_like, strict=False)

    assert isinstance(traced_model, torch.jit.TopLevelTracedModule)
    assert torch.allclose(
        model(torch_yoochoose_like)[0], traced_model(torch_yoochoose_like)[0], rtol=1e-02
    )
    assert traced_model(torch_yoochoose_like)[0].shape[1] == topk
```
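The last assertion above checks that the traced model's output width equals `topk`, i.e. only the k best item scores per row survive. As a rough illustration of that selection (a pure-Python sketch of `torch.topk`-style behaviour; `topk_row` is a made-up helper, not library code):

```python
def topk_row(scores, k):
    # Indices of the k largest scores, best first (mirrors torch.topk's ordering)
    ids = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    values = [scores[i] for i in ids]
    return values, ids


# One row of item scores; with k=2 the output width is 2, like shape[1] == topk
values, ids = topk_row([0.05, 0.90, 0.40, 0.75], 2)
# values == [0.9, 0.75], ids == [1, 3]
```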
```diff
@@ -303,7 +303,9 @@ def build(self, body, input_size, device=None, inputs=None, task_block=None, pre
             body, input_size, device=device, inputs=inputs, task_block=task_block, pre=pre
         )
 
-    def forward(self, inputs: torch.Tensor, targets=None, training=False, testing=False, **kwargs):
+    def forward(
+        self, inputs: torch.Tensor, targets=None, training=False, testing=False, top_k=-1, **kwargs
+    ):
         if isinstance(inputs, (tuple, list)):
             inputs = inputs[0]
         x = inputs.float()
```
```diff
@@ -342,7 +344,11 @@ def forward(self, inputs: torch.Tensor, targets=None, training=False, testing=Fa
             # Compute predictions probs
             x, _ = self.pre(x)  # type: ignore
 
-        return x
+        if top_k == -1:
+            return x
+        else:
+            preds_sorted_item_scores, preds_sorted_item_ids = torch.topk(x, k=top_k, dim=-1)
+            return preds_sorted_item_scores, preds_sorted_item_ids
 
     def remove_pad_3d(self, inp_tensor, non_pad_mask):
         # inp_tensor: (n_batch x seqlen x emb_dim)
```

Review thread on the `if top_k == -1:` line:

Comment: Does `top_k` need to be a numeric value? Would it work with `None`?

Reply: I think

Reply: You're right, it is just a convention. Setting it to `None` won't create any issue.
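The sentinel logic in this hunk (return the full score tensor when `top_k == -1`, otherwise a `(scores, ids)` pair from `torch.topk`) can be sketched in plain Python; `forward_scores` below is a hypothetical stand-in for illustration, not the transformers4rec method:

```python
def forward_scores(scores, top_k=-1):
    # -1 is a sentinel meaning "no truncation": return every item's score
    if top_k == -1:
        return scores
    # Otherwise return (top-k scores, top-k item ids), best first, like torch.topk
    ids = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
    return [scores[i] for i in ids], ids


full = forward_scores([0.4, 0.8, 0.1])          # [0.4, 0.8, 0.1]
vals, ids = forward_scores([0.4, 0.8, 0.1], 2)  # ([0.8, 0.4], [1, 0])
```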
Comment: If the `NextItemPredictionTask` is always something that is created by user code, an alternative to passing `top_k` to the `forward` method could be accepting it in the constructor (`__init__`). That way we wouldn't need to have this condition here.

Reply: Good point. Adding it to the `NextItemPredictionTask` constructor might be easier (as I show below), but the reason we did not add it to the constructor is that once the model is trained with a fixed `top_k` argument, there is no turning back: the value can no longer be changed at the inference step. So to keep it flexible, we decided to add it to the `forward` method of `NextItemPredictionTask` and to the constructor of the `Model` class.
@sararb wanna add something else?
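The flexibility argument above — keeping `top_k` settable after training rather than frozen in the task constructor — amounts to reading a mutable attribute at call time, which is exactly how the test uses it (`model.top_k = topk` just before tracing). A toy sketch under that assumption (`TopKHead` is a hypothetical class, not the library's `Model`):

```python
class TopKHead:
    """Toy model head whose top-k truncation can be changed after training."""

    def __init__(self, top_k=-1):
        self.top_k = top_k  # -1 means "return all scores"

    def __call__(self, scores):
        if self.top_k == -1:
            return scores
        ids = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[: self.top_k]
        return [scores[i] for i in ids], ids


head = TopKHead()             # "trained" without any truncation
full = head([0.2, 0.6, 0.1])  # all scores come back
head.top_k = 2                # freely changed at inference time, no retraining
vals, ids = head([0.2, 0.6, 0.1])
```

Because the attribute is read on every call, a model can be exported once with one `top_k` and later re-exported with another, with no change to the trained weights.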