
Force use of torch.compile on deterministic roi_align implementation #8436

Merged (8 commits) on May 29, 2024

Conversation

@ezyang (Contributor) commented May 21, 2024

Fixes #8168

Signed-off-by: Edward Z. Yang [email protected]


pytorch-bot bot commented May 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8436

Note: Links to docs will display an error until the docs builds have been completed.

❌ 12 New Failures, 1 Unrelated Failure

As of commit ee25749 with merge base 775dd2d (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ezyang (Contributor Author) commented May 21, 2024

cc @qqaatw. I removed the MPS knob because of how memory-hungry the eager implementation is; I doubt torch.compile works on MPS.

ezyang and others added 5 commits May 21, 2024 08:23
Signed-off-by: Edward Z. Yang <[email protected]>
@NicolasHug (Member) left a comment


Thanks @ezyang, two questions below, but LGTM anyway. Unfortunately the MPS-related tests are all broken (#8433); that's not related to this PR.

def lazy_compile(**compile_kwargs):
"""Lazily wrap a function with torch.compile on the first call

This avoids eagerly importing dynamo.

@NicolasHug (Member) commented:

Am I understanding this correctly?

Suggested change
This avoids eagerly importing dynamo.
This avoids eagerly compiling a function at import time.

@ezyang (Contributor Author) replied:


Nope. Even with torch.compile applied at the top level, the function isn't compiled until the first call. But importing dynamo has undesirable side effects for eager-mode-only users, so it's best not to do it.
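The lazy-wrapping idea discussed above can be sketched as follows. This is not the PR's exact helper: the `compile_fn` parameter and the toy `fake_compile` counter are stand-ins for `torch.compile`, used so the deferred-compilation behavior is observable without importing torch at all.

```python
import functools

def lazy_compile(compile_fn, **compile_kwargs):
    """Return a decorator that compiles `fn` on its first call, not at import.

    `compile_fn` is a stand-in for torch.compile (an assumption of this
    sketch); the real helper in the PR calls torch.compile directly, so
    that dynamo is only imported when the wrapped function is first used.
    """
    def decorator(fn):
        compiled = None  # filled in lazily on the first call

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal compiled
            if compiled is None:
                # The expensive import/compile step happens here,
                # not at module import time.
                compiled = compile_fn(fn, **compile_kwargs)
            return compiled(*args, **kwargs)

        return wrapper
    return decorator

# Usage with a toy "compiler" that just records its invocations:
compile_calls = []

def fake_compile(fn, **kwargs):
    compile_calls.append(kwargs)
    return fn

@lazy_compile(fake_compile, fullgraph=True)
def add(a, b):
    return a + b

assert compile_calls == []      # nothing compiled at decoration time
assert add(2, 3) == 5           # first call triggers the "compile"
assert add(1, 1) == 2
assert len(compile_calls) == 1  # compiled exactly once
```

The key point matching ezyang's comment: decoration alone never calls `compile_fn`, so a user who never invokes the function never pays the import or compile cost.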

@@ -232,7 +250,9 @@ def roi_align(
if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois)
if not torch.jit.is_scripting():
if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
if (
not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
@NicolasHug (Member) commented:


Should we just remove the MPS part here, since you mentioned MPS doesn't even work with torch.compile?

Suggested change
not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda)

@ezyang (Contributor Author) replied:


I opted to keep it, since it was explicitly added by @qqaatw, but I don't feel strongly either way.

Contributor


Sorry for the late reply! I'm OK with whichever way is best for development. From the mentioned issue it seems only relevant to CUDA; is MPS similarly memory-hungry with the deterministic algorithm?
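The dispatch condition being debated in this thread can be sketched in isolation. This is a minimal model, not the PR's code: `FakeTensor`, `use_compiled_fallback`, and the `has_ops`/`deterministic` parameters are stand-ins for a real `torch.Tensor`, the module's `_has_ops()` check, and `torch.are_deterministic_algorithms_enabled()`.

```python
from dataclasses import dataclass

@dataclass
class FakeTensor:
    # Stand-in for a torch.Tensor's device flags (assumption of this sketch).
    is_cuda: bool = False
    is_mps: bool = False

def use_compiled_fallback(t, has_ops, deterministic):
    # Mirrors the condition in the diff: fall back to the torch.compile'd
    # Python implementation when the C++ ops are unavailable, or when
    # deterministic algorithms are requested on a CUDA/MPS tensor, where
    # the native kernel is nondeterministic.
    return (not has_ops) or (deterministic and (t.is_cuda or t.is_mps))

cpu, cuda = FakeTensor(), FakeTensor(is_cuda=True)
assert use_compiled_fallback(cpu, has_ops=False, deterministic=False)    # no native ops
assert not use_compiled_fallback(cpu, has_ops=True, deterministic=True)  # CPU stays native
assert use_compiled_fallback(cuda, has_ops=True, deterministic=True)     # deterministic CUDA
assert not use_compiled_fallback(cuda, has_ops=True, deterministic=False)
```

NicolasHug's suggestion amounts to dropping `t.is_mps` from the disjunction; ezyang's reply keeps it, leaving the CPU path unaffected either way since the deterministic clause only fires for CUDA/MPS tensors.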

@NicolasHug NicolasHug merged commit a5f531a into pytorch:main May 29, 2024
51 of 64 checks passed
facebook-github-bot pushed a commit that referenced this pull request Jun 7, 2024
…entation (#8436)

Summary: Signed-off-by: Edward Z. Yang <[email protected]>

Reviewed By: vmoens

Differential Revision: D58283855

fbshipit-source-id: 914a91877c193b38f29af450a5935dd1ab5b20d7

Co-authored-by: Nicolas Hug <[email protected]>
Development

Successfully merging this pull request may close these issues.

OOM Error with roi_align in PyTorch 2.1.1 but fine in PyTorch 2.0.1
4 participants