test_dataset = test_dataset if test_dataset is not None else self.test_dataset
assert self.schema is not None, "schema is required to generate Test Dataloader"
return T4RecDataLoader.parse(self.args.data_loader_engine).from_schema(
    self.schema,
    self.test_dataset_or_path, ### <- This is incorrect?
    self.args.per_device_eval_batch_size,
    max_sequence_length=self.args.max_sequence_length,
    drop_last=self.args.dataloader_drop_last,
    shuffle=False,
    shuffle_buffer_size=self.args.shuffle_buffer_size,
)
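To make the effect concrete, here is a minimal sketch using a hypothetical MiniTrainer class that mimics only the dataset-selection logic of the snippet above. It is not the Transformers4Rec Trainer, and the file names are made up. Because the loader is built from self.test_dataset_or_path, whatever predict receives is ignored:

```python
# Hypothetical stand-in: MiniTrainer mimics only the dataset-selection logic
# of the quoted get_test_dataloader; it is not the Transformers4Rec Trainer.


class MiniTrainer:
    def __init__(self, test_dataset_or_path=None):
        # Dataset/path supplied when the trainer was constructed.
        self.test_dataset_or_path = test_dataset_or_path
        self.test_dataset = test_dataset_or_path

    def get_test_dataloader(self, test_dataset=None):
        # The argument is resolved here...
        test_dataset = test_dataset if test_dataset is not None else self.test_dataset
        # ...but, as in the quoted snippet, the constructor value is used instead.
        return self.test_dataset_or_path

    def predict(self, test_dataset):
        # predict forwards its argument to get_test_dataloader.
        return self.get_test_dataloader(test_dataset)


trainer = MiniTrainer(test_dataset_or_path="train_time_test.parquet")
print(trainer.predict("new_test.parquet"))  # prints train_time_test.parquet
```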
Thank you for catching this bug!
Regarding your point about self.test_dataset_or_path (marked "### <- This is incorrect?"): you are right, it is not correct. The method should use test_dataset instead of self.test_dataset_or_path.
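A sketch of the corrected logic, again using a hypothetical MiniTrainer and a stand-in make_loader function in place of T4RecDataLoader.parse(...).from_schema(...). This only illustrates the idea; the actual change in #571 may differ. The resolved local test_dataset is what gets passed to the loader, so an explicit argument wins and the constructor value is only a fallback:

```python
# Hypothetical sketch: make_loader stands in for the real dataloader factory
# and simply records its inputs; it is not the Transformers4Rec API.


def make_loader(schema, dataset, batch_size):
    return {"schema": schema, "dataset": dataset, "batch_size": batch_size}


class MiniTrainer:
    def __init__(self, schema, test_dataset_or_path=None, batch_size=32):
        self.schema = schema
        self.test_dataset = test_dataset_or_path
        self.test_dataset_or_path = test_dataset_or_path
        self.batch_size = batch_size

    def get_test_dataloader(self, test_dataset=None):
        # Prefer the argument; fall back to the dataset set at construction.
        test_dataset = test_dataset if test_dataset is not None else self.test_dataset
        assert self.schema is not None, "schema is required to generate Test Dataloader"
        # Fix: pass the resolved local variable, not self.test_dataset_or_path.
        return make_loader(self.schema, test_dataset, self.batch_size)


trainer = MiniTrainer(schema="schema", test_dataset_or_path="old_test.parquet")
print(trainer.get_test_dataloader("new_test.parquet")["dataset"])  # new_test.parquet
print(trainer.get_test_dataloader()["dataset"])  # old_test.parquet
```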
I opened a quick fix in #571
Bug description
trainer.predict(test_paths) does not use the files passed to the function for evaluation: predict calls get_test_dataloader, which initializes the test dataloader from self.test_dataset_or_path rather than from the dataset passed in.
See https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/main/transformers4rec/torch/trainer.py#L204-L209
A workaround could be:
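One possible workaround, offered as an assumption rather than the reporter's own, is to temporarily point test_dataset_or_path at the new files before calling predict, since that attribute is what the buggy code reads. The sketch below uses a hypothetical MiniTrainer that reproduces only that behavior, plus a hypothetical helper predict_on, and assumes the real trainer allows the attribute to be reassigned:

```python
# Hypothetical stand-in reproducing only the buggy behavior described above.


class MiniTrainer:
    def __init__(self, test_dataset_or_path=None):
        self.test_dataset_or_path = test_dataset_or_path

    def get_test_dataloader(self, test_dataset=None):
        return self.test_dataset_or_path  # ignores the argument (the bug)

    def predict(self, test_dataset):
        return self.get_test_dataloader(test_dataset)


def predict_on(trainer, test_paths):
    # Workaround sketch: the buggy code reads test_dataset_or_path, so set it
    # to the desired files, run predict, then restore the previous value.
    previous = trainer.test_dataset_or_path
    trainer.test_dataset_or_path = test_paths
    try:
        return trainer.predict(test_paths)
    finally:
        trainer.test_dataset_or_path = previous


trainer = MiniTrainer("old_test.parquet")
print(predict_on(trainer, ["new_test.parquet"]))  # ['new_test.parquet']
print(trainer.test_dataset_or_path)  # old_test.parquet
```

The try/finally restores the original value even if predict raises, so later calls behave as before.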