Speed up imports and add a CI #2845

Merged: muellerzr merged 14 commits from speedup-imports into main on Jul 1, 2024
Conversation

@muellerzr (Collaborator) commented Jun 11, 2024

What does this PR do?

This PR introduces a CI which utilizes cProfile timings and my tuna-interpreter library to perform time-based import tests.

It was reported by @stas00 that we were taking far too long to do basic things like accelerate launch, and tuna can help visualize why by creating import graphs directing us to what is taking too long:

[tuna import profile graph before the changes]

I wrote a small library called tuna-interpreter that aims to take the best parts of tuna and work it into something parse-able that lets us run CIs off of it.

After using the tool:

[tuna import profile graph after the changes]

We can see a decrease of ~68% in import time.

How it works:

In its current form, we compare against a baseline torch import, since Accelerate relies on torch no matter what. However, the full accelerate import should be no more than ~20% slower than the torch import overall; anything more indicates a slip-up or a timing problem.

An example test looks like the following:

    def test_base_import(self):
        output = run_import_time("import accelerate")
        with open(f"{self.tmpdir}/base_results.log", "w") as f:
            f.write(output)
        data = read_import_profile(f"{self.tmpdir}/base_results.log")
        total_time = calculate_total_time(data)
        pct_more = total_time / self.pytorch_time
        # Base import should never be more than 20% slower than raw torch import
        err_msg = f"Base import is more than 20% slower than raw torch import ({pct_more * 100:.2f}%), please check the attached `tuna` profile:\n"
        sorted_data = sort_nodes_by_total_time(data)
        paths_above_threshold = get_paths_above_threshold(sorted_data, 0.1, max_depth=7)
        err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
        self.assertLess(pct_more, 1.2, err_msg)

Where essentially we:

  1. Measure the import time of a particular Python import, run via -c in a subprocess (see the sketch after this list)
  2. Read the profile this generates
  3. Take all the nodes, sort them by time, and collect all paths above an arbitrary threshold. Tweak the threshold and max_depth to your own needs, as they change from library to library. The key with max_depth is that it should be deep enough to get past your own imports in the slowdown trace and show which external libraries you are really calling.
  4. Write an error message stating that the import exceeded the expected slowdown percentage, listing the modules that were slowing it down.
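As a rough illustration of steps 1 and 2, here is a minimal sketch. It assumes run_import_time wraps python -X importtime in a subprocess; the real helper in accelerate's test utils and the tuna-interpreter parsing functions may be implemented differently, and total_import_time_seconds is a made-up stand-in for read_import_profile/calculate_total_time:

    import subprocess
    import sys


    def run_import_time(import_statement: str) -> str:
        # `-X importtime` writes one line per imported module to stderr, e.g.
        # "import time:       283 |       4521 | torch._C"
        result = subprocess.run(
            [sys.executable, "-X", "importtime", "-c", import_statement],
            capture_output=True,
            text=True,
            check=True,
        )
        return result.stderr


    def total_import_time_seconds(profile: str) -> float:
        # The cumulative column (microseconds) of the last trace line is the
        # total time spent importing the top-level module.
        lines = [l for l in profile.splitlines() if l.startswith("import time:")]
        cumulative_us = int(lines[-1].split("|")[1].strip())
        return cumulative_us / 1_000_000


    torch_time = total_import_time_seconds(run_import_time("import torch"))
    accelerate_time = total_import_time_seconds(run_import_time("import accelerate"))
    print(f"accelerate import is {accelerate_time / torch_time:.2f}x the torch import")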

An example failure is below, where we can clearly see what module chain had the slowdown:

E       AssertionError: 1.3515017627366224 not less than 1.2 : Base import is more than 20% slower than raw torch import (135.15%), please check the attached `tuna` profile:
E       
E       main 0.973s
E       main->accelerate 0.961s
E       main->accelerate->accelerate.accelerator 0.959s
E       main->accelerate->accelerate.accelerator->torch 0.758s
E       main->accelerate->accelerate.accelerator->torch->torch._C 0.355s
E       main->accelerate->accelerate.accelerator->torch->torch._meta_registrations 0.154s
E       main->accelerate->accelerate.accelerator->torch->torch._meta_registrations->torch._decomp 0.109s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing 0.126s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils 0.125s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils->accelerate.utils.megatron_lm 0.107s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils->accelerate.utils.megatron_lm->transformers.modeling_outputs 0.106s

tests/test_imports.py:64: AssertionError

Or:

E       AssertionError: 1.8292819779293377 not less than 1.2 : Base import is more than 20% slower than raw torch import (182.93%), please check the attached `tuna` profile:
E       
E       main 1.324s
E       main->accelerate 1.308s
E       main->accelerate->accelerate.accelerator 1.307s
E       main->accelerate->accelerate.accelerator->torch 0.706s
E       main->accelerate->accelerate.accelerator->torch->torch._C 0.327s
E       main->accelerate->accelerate.accelerator->torch->torch._meta_registrations 0.152s
E       main->accelerate->accelerate.accelerator->torch->torch._meta_registrations->torch._decomp 0.108s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing 0.527s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils 0.526s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils->accelerate.utils.fsdp_utils 0.488s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils->accelerate.utils.fsdp_utils->torch.distributed.fsdp.fully_sharded_data_parallel 0.488s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils->accelerate.utils.fsdp_utils->torch.distributed.fsdp.fully_sharded_data_parallel->torch.distributed.fsdp 0.488s

tests/test_imports.py:64: AssertionError

If there are specific issues with using tuna-interpreter, please let me know. It's a very quickly hacked-together but working library for what we are doing, and I'm open to improving it further after we battle-test it.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@SunMarc @BenjaminBossan @sayakpaul @Titus-von-Koeller @ArthurZucker @ydshieh @LysandreJik

@muellerzr (Collaborator, Author) commented:

Let me know how I can improve this tool further so we can roll it out to anyone at HF who wants to use it 🤗

@BenjaminBossan (Member) left a comment:

Thanks for adding this workflow to detect regressions in import time.

I have a couple of questions and comments, please check.

Regarding tuna, I did a quick check and from my understanding, this is mostly a package to lightly process the profile data generated by Python's cProfile module and to add visualization similar to snakeviz. Also note that the last commit is already almost 1.5 years old and that read_import_profile comes from a private module. For this reason, I wonder if it wouldn't make more sense to vendor this function with your lib and get rid of the tuna dependency completely.

tests/test_imports.py
sorted_data = sort_nodes_by_total_time(data)
paths_above_threshold = get_paths_above_threshold(sorted_data, 0.1, max_depth=7)
err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
self.assertLess(pct_more, 1.2, err_msg)
Member:

So IIUC, we want to ensure that the import doesn't add more than 20% on top of the PyTorch import. Did you check how robust this is? Should this test perhaps do multiple imports and then average? If so, should we throw away the first timing, as it may require a warm up (disk cache?) and thus make the first import unduly slow (and thus make the test pass trivially)?

@muellerzr (Collaborator, Author):

Correct.

We specifically care about the warmup, because otherwise it will lead to slowdowns via accelerate launch.

I'll check multiple calls, but it shouldn't matter: everything is run via a subprocess, so it shouldn't rely on disk caching.
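If averaging ever turns out to be needed, one rough way to do it, building on the hypothetical helpers sketched earlier (this is not something the PR implements):

    def average_import_time(import_statement: str, runs: int = 5, drop_first: bool = True) -> float:
        # Each run is a fresh subprocess, so the interpreter is cold every time;
        # drop_first only matters if disk caching of the very first run is a concern.
        timings = [
            total_import_time_seconds(run_import_time(import_statement))
            for _ in range(runs)
        ]
        if drop_first and len(timings) > 1:
            timings = timings[1:]
        return sum(timings) / len(timings)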

from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler


if is_megatron_lm_available():
Member:

For my understanding, this change is unrelated, right?

@muellerzr (Collaborator, Author) commented Jun 11, 2024:

It's fully related, see the megatron-related graph in the PR description (in the test examples), it was a source of slowdown

Member:

Ah I missed that, had to scroll right to see the megatron part :D So it's not related to implementing the new workflow, but it is related to actually increasing the startup time.

@muellerzr (Collaborator, Author):

Yes exactly. Can't fully do one without the other for good baselines, hence why they are both in this PR 😅

@stas00 (Contributor) commented Jun 11, 2024:

The check name is somewhat misleading: a user may have Megatron installed (i.e. available) but not be telling accelerate to use it, in which case it shouldn't be loaded.

Perhaps call it is_megatron_needed?

As for the availability logic: shouldn't it fail long before this code anyway if it's not available?

What I'm trying to communicate is that, for the sake of imports, needed and available are orthogonal to each other.

@muellerzr (Collaborator, Author):

No, it is always available and has been available; otherwise we'd call it needed or configured, as you say.

requires_ is for tests: there we know we are using it, so we use it.

_available simply means "bring it into the environment if it is available".

@muellerzr (Collaborator, Author):

It's aggressive importing, yes. As noted in my last comment, I aim to fix that when I can dedicate some time to refactoring our entire import schema in accelerate.

@stas00 (Contributor) commented Jun 11, 2024:

But the first thing the is_megatron_lm_available check does is look at whether it's configured to be used, or am I misreading it?

    def is_megatron_lm_available():
        if str_to_bool(os.environ.get("ACCELERATE_USE_MEGATRON_LM", "False")) == 1:

So it's not available; it's is_wanted_and_available.

@muellerzr (Collaborator, Author) commented Jun 11, 2024:

@BenjaminBossan what's missing here is the initial tuna check, which is why it's required (see the workflow).

We can eventually look at gutting it out, sure. However I think having both the visual option for debugging further + the condensed output here is valuable and is why we should still use tuna itself.

Edit: I may be wrong here, sorry.

@muellerzr (Collaborator, Author) commented:
@BenjaminBossan I fully removed the requirement for tuna there; however, I still think it's useful for further debugging, so I left that in as part of the testcase class description :)

Comment on lines +49 to +52
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
@stas00 (Contributor) commented Jun 11, 2024:

This feels somewhat yucky, since it's going to run the same imports many times over during the life of a trainer, and more so because you replicate this exact import code multiple times below in similar functions.

Why not move these imports into their own file and call the imports only if they haven't been loaded yet?

Moreover, it should be safe to assume that if a user configured accelerate to use FSDP, it'll need all these save/load FSDP functions. So from the performance point of view all of them can be moved into a single place and loaded once, outside of these functions, if the user asked for FSDP, similar to how you did it with megatron in this PR (see the sketch below).
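A rough sketch of that suggestion; the module layout and the names is_fsdp_configured / get_fsdp_handles / ACCELERATE_USE_FSDP gating are illustrative assumptions, not accelerate's actual API:

    import os

    _fsdp_handles = None


    def is_fsdp_configured() -> bool:
        # Illustrative config check only; accelerate's real detection is richer than this.
        return os.environ.get("ACCELERATE_USE_FSDP", "false").lower() in ("1", "true")


    def get_fsdp_handles():
        # Import torch's FSDP machinery once, on first use, and cache the handles
        # so the many save/load helpers can share them instead of re-importing.
        global _fsdp_handles
        if _fsdp_handles is None:
            import torch.distributed.checkpoint as dist_cp
            from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
            from torch.distributed.fsdp.fully_sharded_data_parallel import (
                FullyShardedDataParallel as FSDP,
                StateDictType,
            )

            _fsdp_handles = (dist_cp, DefaultSavePlanner, FSDP, StateDictType)
        return _fsdp_handles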

@muellerzr (Collaborator, Author):

Agreed, it's gross for now. I plan on working with @LysandreJik on adding a lazy-import style next week to fix all of these performance issues fully, which can include the solution you suggest (as we have to do this in multiple areas, which is not ideal).

@muellerzr (Collaborator, Author):

Will need to think of an ideal folder structure, but I'll start with some small version of this in this PR, since we're touching those imports now (specifically sectioning out to a different file)

@muellerzr (Collaborator, Author) commented Jun 11, 2024:

Also, the difference with megatron is that we check for a package being installed. So while we could do something similar for deepspeed, we cannot do it for FSDP: that check doesn't look at configuration at all, only at whether a package is available in the env, and IMO we shouldn't gate on configuration alone either, as some libraries do not rely on it.

Contributor:

But it should be enough to check whether FSDP is configured to be used: configured? Import. Not configured? Don't import.

Contributor:

The dependency chain could be:

this_component_is_needed_to_continue -> do we have it installed -> if so, import the needed bits.

i.e. the logic that decides whether or not to import something is driven by the configuration of accelerate, not by the availability of a 3rd-party package. A generic sketch follows below.
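That chain, as a generic hedged sketch (the helper name import_if_needed is hypothetical and not part of accelerate):

    import importlib
    import importlib.util


    def import_if_needed(module_name: str, needed: bool):
        # Configuration decides "needed"; installation decides "available";
        # only when both hold do we pay the import cost.
        if not needed:
            return None
        if importlib.util.find_spec(module_name) is None:
            raise ImportError(f"{module_name} is required by the current config but is not installed")
        return importlib.import_module(module_name)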

Member:

Hmm, but imports are cached in Python, so the 2nd time this is called, they don't add any overhead. Or am I missing something?
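For reference, a quick way to see the caching behaviour mentioned here (just an illustration, unrelated to the PR's code):

    import sys
    import time


    def timed_import():
        start = time.perf_counter()
        import decimal  # any module not yet loaded works for the demonstration
        return time.perf_counter() - start


    first, second = timed_import(), timed_import()
    assert "decimal" in sys.modules
    print(f"first: {first:.6f}s, cached: {second:.6f}s")  # the second call is just a sys.modules lookup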

Contributor:

Sure, but this just feels weird.

Imports don't belong in methods/functions, and surely not when you repeat the same imports multiple times in multiple functions. Nothing stops you from doing that, of course.

Imports are sort of a singleton thing. Certainly we occasionally see an odd import in some function, but it's usually an exception, and it happens more often in throw-away code.

@BenjaminBossan (Member) left a comment:

Overall this LGTM, thanks for working on this and the tuna-interpreter package.

I have some comments, but none are blockers. Regarding the local imports, personally I don't think it's a huge anti-pattern, so this can be left for later IMO.

One question I have is how test_imports.py is supposed to be run, just manually by devs or some automatic job? If the latter, I'd consider running the checks multiple times and averaging the results (perhaps also discarding outliers) to make the job more robust. If it's manual, then that's not necessary.

src/accelerate/test_utils/testing.py
    def setUpClass(cls):
        super().setUpClass()
        output = run_import_time("import torch")
        with open(cls.tmpdir / "pytorch_results.log", "w") as f:
Member:

Since you now control read_import_profile, you could make it so that it can also receive an iterable as input instead of a file name, that way you can get rid of the indirection of writing the outputs to a file first.
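One possible shape for that, sketched under the assumption that the parser works line by line (this is not the actual tuna-interpreter signature):

    from os import PathLike
    from typing import Iterable, Union


    def read_import_profile(source: Union[str, PathLike, Iterable[str]]):
        # Accept either a path to a log file or the profile lines themselves,
        # so callers such as the tests can skip the temp-file round trip.
        if isinstance(source, (str, PathLike)):
            with open(source) as f:
                lines = f.read().splitlines()
        else:
            lines = list(source)
        # ... parse `lines` into the node structure used by the rest of the tooling
        return lines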

@SunMarc (Member) left a comment:

Thanks for making the import faster!

src/accelerate/test_utils/testing.py
@sayakpaul (Member) left a comment:

This is a good one! Good news for the diffusers team is that we already follow this mechanism :)

But we could definitely include a CI for this. @yiyixuxu @DN6 WDYT?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr merged commit 3086e26 into main on Jul 1, 2024 (28 checks passed)
@muellerzr deleted the speedup-imports branch on July 1, 2024 at 22:50
@ydshieh (Contributor) commented Jul 2, 2024:

🔥
