-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add warnings and fallback for unassigned devices in infer_auto_device_map #3066
Add warnings and fallback for unassigned devices in infer_auto_device_map #3066
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
…map" This reverts commit d607bfb.
The fallback allocation will be reintroduced once the branching logic is fully refactored. This commit prepares the function infer_auto_device_map for further refactoring.
Implemented fallback allocation to allow modules to be allocated to devices using BFS when regular allocation fails. This enhancement improves the allocation process by ensuring that at least one module is assigned to the device, even under tight memory constraints.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates and the style fix. I'm not very knowledgeable about the whole logic being applied here, so I won't comment on that.
Personally, I find the use of continue
in addition to many nested conditionals makes the logic super hard to follow. Usually, I would try to stick to either if + continue
or if ... elif ...
without continue
. Not sure if the code could be simplified here.
One thing I believe we should ensure is that the new logic does not add any unnecessary warnings. Right now, we have some unit tests to ensure that specific warnings are there, but AFAICT we don't have tests to ensure that for other cases, there are no warnings. Maybe it would be good to add tests for the "happy path" and show that there is no warning. Potentially, we can even use existing tests and just add a check there is no warning. WDYT?
test_infer_auto_device_map and test_infer_auto_device_map_with_fallback_allocation now each have a no-warning test case. Simplified and rewrote code sections that were made unreadable by the linter.
Added complete return type hinting for _init_infer_auto_device_map
Hey @BenjaminBossan, I appreciate your feedback! Regarding the use of continue and nested conditionals, I've tried simplifying the logic where possible. Now the I completely agree with your point about avoiding unnecessary warnings. I've added checks in both |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for cleaning the code up and extending the tests.
I agree that the logic was already complex beforehand so it's not just because of this PR. But I think your recent changes helped a little bit to make it easier to understand, even if the overall complexity is still high and I can't say I understand all that's going on.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work! This will be very handy. cc @SunMarc for a final look since it's big model inference :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR @Nech-C ! Really appreciate that you are putting a lot of effort into this PR ! I will review it soon but first I have a question: could you explain a bit with an example of what this fallback allocation will do ? From our conversation last time, the biggest issue with infer_auto_device_map is that we are saving memory for the largest layer in case we need to offload it to the cpu. I think that in your case, you are trying to find a module that fits the device memory - largest layer ?
Hi @SunMarc, sure thing! You are right. My code doesn't directly address the issue that the function may reserve space on a device for a module that won't be loaded onto it during inference when there are multiple execution devices. I have tried to come up with new allocation strategies, but the task is really complex. If possible, I would like to open a separate PR to address this issue when I come up with a reasonable solution. While this PR doesn't solve the most significant concern, it does alleviate the problem. The constraint for allocating a module to a device is roughly module size + max layer size <= device memory. The aforementioned issue focuses on lowering max layer size, and this PR focuses on lowering module size by looking for a smaller module in the module list. It tries to assign a module to a device that receives no assignment when the regular allocation logic fails. In theory, a device can be used during execution if it has more memory than the largest layer, even with no module assigned to it. Thus, we can achieve the same result without going through such a roundabout approach. However, I believe this would be a breaking change, as we need the returned value Thanks for your feedback. I'm open to further suggestions or clarifications if needed. |
Nice explanation @Nech-C ! Thanks for confirming !
I think that a quick solution to the max_layer size issue would be the following algorithm
We can add your fallback option each time we run infer_auto_device_map if wanted. This will help fixing this following issue I saw a couple of time: Let me know what you think ! Nevertheless, I think it will be nice to first merge this PR before moving the max_layer size fix. |
Ohhh, now I get it @SunMarc . Thanks for breaking it down. Working on the code really helped me understand your idea. TBH, I didn't fully understand it when you first mentioned it in the issue 😅. Your algorithm idea sounds solid. I'm on board with merging this PR first, then tackling the max_layer size fix. And how should I proceed with the max_layer fix? Do I just open a new PR referencing the original issue, or do we need a new issue for this? Also, just a heads up, I've got a couple of busy weeks coming up, so I may not be able to start working on this right away. But I'll definitely get to it as soon as I can. Any tweaks you want me to make to this PR before we move on? |
Sounds good ! I'll try to review this today ! |
Hi @SunMarc, just checking in on this PR. I’ve wrapped up my busy weeks and I’m ready to move forward with the max_layer_size fix you mentioned once this gets merged. No rush at all, but I just wanted to see if there’s anything else I should address before we move forward. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks ! I'll merge this next week as I will be away and I prefer to be there in case there is an issue ! Left a minor comment. Also, can you confirm that this function behaves the same as before when we have fallback_allocation = False
, so that it is safe to merge ?
Sure thing! Here is a somewhat lengthy explanation. I'm sorry if this is too detailed, but this will help ensure we're all on the same page about the behavior. So the original logic works like this:
While this works, it introduces duplicate code and deeply nested if statements. Therefore, I tried to simplify the logic so the new logic wouldn't make it even more complicated. Here is how it works now when we have
Here are all the cases regarding the splitting and allocation behavior (T=True, F=False):
I added letters at the end of the lines of the pseudo code where the three behaviors happen: (A) for allocation, (CS) for current module split, and (TS) for tied module split. When no split happens, both implementations will go to the next device. The splitting subroutines remain the same for the new implementation to make sure we get the same behavior. I am pretty confident that it will behave the same. While there could be some edge cases that I miss here, the function will still produce a valid device map. Thanks for reviewing! This comment may be unnecessarily long, but it's the best thing I can come up with😅. Please let me know if anything needs clarification. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for providing such as detailed answer ! We are good to merge then. Could you just fix the tests since we are not logging warning anymore ?
What does this PR do?
This PR is proposed changes to the infer_auto_device_map function from #3041. It will make the following improvements:
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.