Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warnings and fallback for unassigned devices in infer_auto_device_map #3066

Merged
merged 15 commits into from
Nov 20, 2024

Conversation

Nech-C
Copy link
Contributor

@Nech-C Nech-C commented Sep 1, 2024

What does this PR do?

This PR is proposed changes to the infer_auto_device_map function from #3041. It will make the following improvements:

  1. Add warnings when no modules are assigned to a main device due to low max_memory.
  2. Report the minimum memory needed for at least one module assignment with the warnings. For example, according to the current logic, this value will be the (first immediate non-splittable module) + (the largest layer) for the first device.
  3. Add a new parameter fallback_allocation. When set to True, it will attempt an alternative assignment if max_memory is sufficient for some (non-splittable module) + (largest layer) but insufficient for the default assignment attempt. This makes sure at least one module is assigned to the potential execution device and likely won't break other code.

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@muellerzr muellerzr requested a review from SunMarc September 1, 2024 18:44
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

The fallback allocation will be reintroduced once the branching logic is fully refactored. This commit prepares the function infer_auto_device_map for further refactoring.
Implemented fallback allocation to allow modules to be allocated to devices using BFS when regular allocation fails. This enhancement improves the allocation process by ensuring that at least one module is assigned to the device, even under tight memory constraints.
@Nech-C Nech-C marked this pull request as ready for review October 14, 2024 01:40
Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates and the style fix. I'm not very knowledgeable about the whole logic being applied here, so I won't comment on that.

Personally, I find the use of continue in addition to many nested conditionals makes the logic super hard to follow. Usually, I would try to stick to either if + continue or if ... elif ... without continue. Not sure if the code could be simplified here.

One thing I believe we should ensure is that the new logic does not add any unnecessary warnings. Right now, we have some unit tests to ensure that specific warnings are there, but AFAICT we don't have tests to ensure that for other cases, there are no warnings. Maybe it would be good to add tests for the "happy path" and show that there is no warning. Potentially, we can even use existing tests and just add a check there is no warning. WDYT?

src/accelerate/utils/modeling.py Outdated Show resolved Hide resolved
tests/test_modeling_utils.py Outdated Show resolved Hide resolved
test_infer_auto_device_map and test_infer_auto_device_map_with_fallback_allocation now each have a no-warning test case.

Simplified and rewrote code sections that were made unreadable by the linter.
Added complete return type hinting for _init_infer_auto_device_map
@Nech-C
Copy link
Contributor Author

Nech-C commented Oct 14, 2024

Thanks for the updates and the style fix. I'm not very knowledgeable about the whole logic being applied here, so I won't comment on that.

Personally, I find the use of continue in addition to many nested conditionals makes the logic super hard to follow. Usually, I would try to stick to either if + continue or if ... elif ... without continue. Not sure if the code could be simplified here.

One thing I believe we should ensure is that the new logic does not add any unnecessary warnings. Right now, we have some unit tests to ensure that specific warnings are there, but AFAICT we don't have tests to ensure that for other cases, there are no warnings. Maybe it would be good to add tests for the "happy path" and show that there is no warning. Potentially, we can even use existing tests and just add a check there is no warning. WDYT?

Hey @BenjaminBossan, I appreciate your feedback!

Regarding the use of continue and nested conditionals, I've tried simplifying the logic where possible. Now the while loop in infer_auto_device_map uses if + continue for branching. However, there are more continue statements in the code, and many of them come from the original implementation. If you think it's necessary to address those, I will take a look at those and see what I can do.

I completely agree with your point about avoiding unnecessary warnings. I've added checks in both test_infer_auto_device_map_with_fallback_allocation and test_infer_auto_device_map to verify that no unexpected warnings are raised in the 'happy path' cases.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for cleaning the code up and extending the tests.

I agree that the logic was already complex beforehand so it's not just because of this PR. But I think your recent changes helped a little bit to make it easier to understand, even if the overall complexity is still high and I can't say I understand all that's going on.

Copy link
Collaborator

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work! This will be very handy. cc @SunMarc for a final look since it's big model inference :)

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR @Nech-C ! Really appreciate that you are putting a lot of effort into this PR ! I will review it soon but first I have a question: could you explain a bit with an example of what this fallback allocation will do ? From our conversation last time, the biggest issue with infer_auto_device_map is that we are saving memory for the largest layer in case we need to offload it to the cpu. I think that in your case, you are trying to find a module that fits the device memory - largest layer ?

@Nech-C
Copy link
Contributor Author

Nech-C commented Oct 15, 2024

Thanks for the PR @Nech-C ! Really appreciate that you are putting a lot of effort into this PR ! I will review it soon but first I have a question: could you explain a bit with an example of what this fallback allocation will do ? From our conversation last time, the biggest issue with infer_auto_device_map is that we are saving memory for the largest layer in case we need to offload it to the cpu. I think that in your case, you are trying to find a module that fits the device memory - largest layer ?

Hi @SunMarc, sure thing!
Let's consider a model on which we run infer_auto_device_map. The model has three non-splittable modules: A, B, and C, with sizes 2, 1, and 3, respectively. max_memory is defined as {0: 4, "cpu": 6}. Without fallback allocation, the entire model will be allocated to the CPU (main memory) because the combined size of A and C (2+3 = 5) exceeds the memory limit of device 0 (=4), so device 0 gets skipped. When fallback_allocation is set to True, it uses DFS to find a module in the model that satisfies the size constraint (module size + largest layer size < max memory). The resultant device map will look like this: {A: 0, B: "cpu", C: "cpu"}. This way, the faster hardware can be used during the inference without causing OOM.

You are right. My code doesn't directly address the issue that the function may reserve space on a device for a module that won't be loaded onto it during inference when there are multiple execution devices. I have tried to come up with new allocation strategies, but the task is really complex. If possible, I would like to open a separate PR to address this issue when I come up with a reasonable solution.

While this PR doesn't solve the most significant concern, it does alleviate the problem. The constraint for allocating a module to a device is roughly module size + max layer size <= device memory. The aforementioned issue focuses on lowering max layer size, and this PR focuses on lowering module size by looking for a smaller module in the module list. It tries to assign a module to a device that receives no assignment when the regular allocation logic fails. In theory, a device can be used during execution if it has more memory than the largest layer, even with no module assigned to it. Thus, we can achieve the same result without going through such a roundabout approach. However, I believe this would be a breaking change, as we need the returned value device_map to include this information, and it also requires considerable changes in other code, such as the dispatch_model function.

Thanks for your feedback. I'm open to further suggestions or clarifications if needed.

@SunMarc
Copy link
Member

SunMarc commented Oct 16, 2024

Nice explanation @Nech-C ! Thanks for confirming !

The constraint for allocating a module to a device is roughly module size + max layer size <= device memory. The aforementioned issue focuses on lowering max layer size, and this PR focuses on lowering module size by looking for a smaller module in the module list.

I think that a quick solution to the max_layer size issue would be the following algorithm

    1. run the infer_auto_device_map with max layer size = 0
    1. Check if we have offloaded layers.
      a) If not, we have our final device_map
      b) Else, we redo the computation without removing max layer size

We can add your fallback option each time we run infer_auto_device_map if wanted.

This will help fixing this following issue I saw a couple of time:
The model has three non-splittable modules: A, B, and C, with sizes 2, 1, and 3, respectively. max_memory is defined as {0: 4, 1:10, "cpu": 6}. With the current flow, the device_map will be {A: 1, B: 1, C: 1} since A+C = 5 > 4 whereas with the above algorithm, we will have {A: 0, B: 0, C: 1}. Your fallback option could help in the first iteration if A size is 5. In this case, the end device_map would be {A: 1, B: 0, C: 0} instead of {A: 1, B: 1, C: 1}.

Let me know what you think !

Nevertheless, I think it will be nice to first merge this PR before moving the max_layer size fix.

@Nech-C
Copy link
Contributor Author

Nech-C commented Oct 16, 2024

Nice explanation @Nech-C ! Thanks for confirming !

The constraint for allocating a module to a device is roughly module size + max layer size <= device memory. The aforementioned issue focuses on lowering max layer size, and this PR focuses on lowering module size by looking for a smaller module in the module list.

I think that a quick solution to the max_layer size issue would be the following algorithm

    1. run the infer_auto_device_map with max layer size = 0
    1. Check if we have offloaded layers.
      a) If not, we have our final device_map
      b) Else, we redo the computation without removing max layer size

We can add your fallback option each time we run infer_auto_device_map if wanted.

When running the infer_auto_device_map, we can add the fallback option of course.

This will help fixing this following issue I saw a couple of time: The model has three non-splittable modules: A, B, and C, with sizes 2, 1, and 3, respectively. max_memory is defined as {0: 4, 1:10, "cpu": 6}. With the current flow, the device_map will be {A: 1, B: 1, C: 1} since A+C = 5 > 4 whereas with the above algorithm, we will have {A: 0, B: 0, C: 1}. Your fallback option could help in the first iteration if A size is 5. In this case, the end device_map would be {A: 1, B: 0, C: 0} instead of {A: 1, B: 1, C: 1}.

Let me know what you think !

Nevertheless, I think it will be nice to first merge this PR before moving the max_layer size fix.

Ohhh, now I get it @SunMarc . Thanks for breaking it down. Working on the code really helped me understand your idea. TBH, I didn't fully understand it when you first mentioned it in the issue 😅. Your algorithm idea sounds solid. I'm on board with merging this PR first, then tackling the max_layer size fix.

And how should I proceed with the max_layer fix? Do I just open a new PR referencing the original issue, or do we need a new issue for this?

Also, just a heads up, I've got a couple of busy weeks coming up, so I may not be able to start working on this right away. But I'll definitely get to it as soon as I can.

Any tweaks you want me to make to this PR before we move on?

@SunMarc
Copy link
Member

SunMarc commented Oct 17, 2024

Also, just a heads up, I've got a couple of busy weeks coming up, so I may not be able to start working on this right away. But I'll definitely get to it as soon as I can.

Any tweaks you want me to make to this PR before we move on?

Sounds good ! I'll try to review this today !

@Nech-C
Copy link
Contributor Author

Nech-C commented Oct 30, 2024

Hi @SunMarc, just checking in on this PR. I’ve wrapped up my busy weeks and I’m ready to move forward with the max_layer_size fix you mentioned once this gets merged. No rush at all, but I just wanted to see if there’s anything else I should address before we move forward.

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks ! I'll merge this next week as I will be away and I prefer to be there in case there is an issue ! Left a minor comment. Also, can you confirm that this function behaves the same as before when we have fallback_allocation = False, so that it is safe to merge ?

src/accelerate/utils/modeling.py Outdated Show resolved Hide resolved
@Nech-C
Copy link
Contributor Author

Nech-C commented Nov 6, 2024

Thanks ! I'll merge this next week as I will be away and I prefer to be there in case there is an issue ! Left a minor comment. Also, can you confirm that this function behaves the same as before when we have fallback_allocation = False, so that it is safe to merge ?

Sure thing! Here is a somewhat lengthy explanation. I'm sorry if this is too detailed, but this will help ensure we're all on the same page about the behavior.

So the original logic works like this:

while the module_list is not empty:
    current_module = module_list.pop()
    if current_module is too big:
         if can't split:
             move on to the next device 
         else:
              split it and go to next iter (CS)
    elif current_module fits and has tie_params:
        if everything fits:
            allocate (A)
        else:
            try splitting the tied_modules (TS)
            if split happened:
               go to next iter
            else:
                move on to the next device
    else: # current module fits with no tied_params
        allocate (A)

While this works, it introduces duplicate code and deeply nested if statements. Therefore, I tried to simplify the logic so the new logic wouldn't make it even more complicated. Here is how it works now when we have fallback_allocation = False:

while the module_list is not empty:
    current_module = module_list.pop()
    if the current_module and its tied_modules fit:
       allocate (A)
    if the current_module fits but not the tied_modules:
        try splitting the tied_modules (TS)
        if split happened:
            go to next iter 
    if the current_module doesn't fit
        try splitting the current_module (CS)
        if split happened:
            go to next iter
    # at this point, no allocation nor split has happened
    move on to the next device 

Here are all the cases regarding the splitting and allocation behavior (T=True, F=False):

module size <= memory tied module size <= memory total size <= memory old implementation new implementation
T T T allocate allocate
T T F split tied modules split tied modules
T F T N/A N/A
T F F split tied modules split tied modules
F T T N/A N/A
F T F split current module split current module
F F T N/A N/A
F F F split current module split current module

I added letters at the end of the lines of the pseudo code where the three behaviors happen: (A) for allocation, (CS) for current module split, and (TS) for tied module split. When no split happens, both implementations will go to the next device. The splitting subroutines remain the same for the new implementation to make sure we get the same behavior. I am pretty confident that it will behave the same. While there could be some edge cases that I miss here, the function will still produce a valid device map.

Thanks for reviewing! This comment may be unnecessarily long, but it's the best thing I can come up with😅. Please let me know if anything needs clarification.

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for providing such as detailed answer ! We are good to merge then. Could you just fix the tests since we are not logging warning anymore ?

tests/test_modeling_utils.py Outdated Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants