[BUG] Transformer predict does not use path from parameters #570

Closed
bschifferer opened this issue Dec 8, 2022 · 1 comment · Fixed by #571
Labels
bug Something isn't working status/needs-triage

@bschifferer
Contributor

Bug description

trainer.predict(test_paths) does not use the files passed to the function for evaluation.

predict calls get_test_dataloader(test_dataset):

trainer.predict(
    test_dataset: torch.utils.data.dataset.Dataset,
    ignore_keys: Union[List[str], NoneType] = None,
    metric_key_prefix: str = 'test',
) -> transformers.trainer_utils.PredictionOutput
        ...
        self._memory_tracker.start()
        test_dataloader = self.get_test_dataloader(test_dataset)

get_test_dataloader then initializes the test dataloader from self.test_dataset_or_path, so the test_dataset argument is ignored.
See https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/main/transformers4rec/torch/trainer.py#L204-L209

        test_dataset = test_dataset if test_dataset is not None else self.test_dataset
        assert self.schema is not None, "schema is required to generate Test Dataloader"
        return T4RecDataLoader.parse(self.args.data_loader_engine).from_schema(
            self.schema,
            self.test_dataset_or_path, ### <- This is incorrect?
            self.args.per_device_eval_batch_size,
            max_sequence_length=self.args.max_sequence_length,
            drop_last=self.args.dataloader_drop_last,
            shuffle=False,
            shuffle_buffer_size=self.args.shuffle_buffer_size,
        )
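The failure mode can be reproduced without Transformers4Rec at all. The sketch below is a hypothetical, self-contained model of the quoted get_test_dataloader logic: StubLoaderFactory and BuggyTrainer are illustrative stand-ins (not real Transformers4Rec classes), and the "dataloader" is just the data source that was requested. It shows that whatever path is passed in, the path stored on the trainer wins.

```python
class StubLoaderFactory:
    """Hypothetical stand-in for T4RecDataLoader.parse(...): records what it is asked to load."""

    def __init__(self):
        self.requested = []

    def from_schema(self, schema, paths_or_dataset, batch_size, **kwargs):
        # Record the data source and return it in place of a real dataloader.
        self.requested.append(paths_or_dataset)
        return paths_or_dataset


class BuggyTrainer:
    """Hypothetical minimal trainer reproducing the quoted get_test_dataloader logic."""

    def __init__(self, test_dataset_or_path=None):
        self.schema = "placeholder-schema"  # only checked for None below
        self.test_dataset = None
        self.test_dataset_or_path = test_dataset_or_path
        self.loader = StubLoaderFactory()

    def get_test_dataloader(self, test_dataset=None):
        test_dataset = test_dataset if test_dataset is not None else self.test_dataset
        assert self.schema is not None, "schema is required to generate Test Dataloader"
        # Same pattern as the quoted code: the local test_dataset is computed but never used.
        return self.loader.from_schema(self.schema, self.test_dataset_or_path, batch_size=32)


trainer = BuggyTrainer(test_dataset_or_path="data/test_day_1.parquet")
print(trainer.get_test_dataloader("data/test_day_2.parquet"))  # prints data/test_day_1.parquet
```

The file names are made up for illustration; the point is only that the argument to get_test_dataloader has no effect on what gets loaded.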

A possible workaround:

trainer.test_dataset_or_path = paths
prediction = trainer.predict(paths)
@sararb
Contributor

sararb commented Dec 8, 2022

Thank you for catching this bug!
Regarding your point (self.test_dataset_or_path, ### <- This is incorrect?): yes, this is incorrect. The method should use test_dataset instead of self.test_dataset_or_path.
I opened a quick fix in #571
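A minimal sketch of the corrected logic, using a hypothetical stand-in trainer rather than the actual patch in #571: the loader is built from the resolved test_dataset argument. Falling back to self.test_dataset_or_path when no argument is given is an assumption made for this sketch (the quoted code falls back to self.test_dataset), and _from_schema is an illustrative placeholder for the real T4RecDataLoader call.

```python
class FixedTrainer:
    """Hypothetical minimal trainer showing the corrected get_test_dataloader logic."""

    def __init__(self, test_dataset_or_path=None):
        self.schema = "placeholder-schema"  # only checked for None below
        self.test_dataset_or_path = test_dataset_or_path

    def _from_schema(self, schema, paths_or_dataset, batch_size):
        # Placeholder for T4RecDataLoader.parse(...).from_schema(...);
        # returns the data source instead of building a real dataloader.
        return paths_or_dataset

    def get_test_dataloader(self, test_dataset=None):
        # Prefer the caller's argument; fall back to the stored path (assumption).
        test_dataset = test_dataset if test_dataset is not None else self.test_dataset_or_path
        assert self.schema is not None, "schema is required to generate Test Dataloader"
        # The fix: pass the resolved test_dataset, not self.test_dataset_or_path.
        return self._from_schema(self.schema, test_dataset, batch_size=32)


trainer = FixedTrainer(test_dataset_or_path="data/test_day_1.parquet")
print(trainer.get_test_dataloader("data/test_day_2.parquet"))  # prints data/test_day_2.parquet
print(trainer.get_test_dataloader())  # prints data/test_day_1.parquet
```

With this shape, the workaround of overwriting trainer.test_dataset_or_path before calling predict is no longer needed.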
