[core] improve cpu offloading implementation #10609
base: main
Conversation
Signed-off-by: youkaichao <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
It seems the original CUDA tensor is still held alive somewhere; the weights are not actually offloaded.
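One way to confirm whether offloading actually happened is to compare allocated device memory before and after loading; a minimal sketch (the loader name here is a placeholder, not from the PR):

```python
import torch

before = torch.cuda.memory_allocated()
model = load_model_with_cpu_offload()  # hypothetical loader under test
after = torch.cuda.memory_allocated()

# If the original CUDA weights are still referenced somewhere, the delta
# stays close to the full weight size instead of dropping.
print(f"GPU bytes retained: {after - before}")
```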
vllm/model_executor/models/utils.py
Outdated
```python
return module
torch.empty = fake_empty
```
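For context, a minimal sketch of the monkey-patching pattern the hunk above appears to implement (the restore step is assumed, not shown in the hunk):

```python
import torch

_original_empty = torch.empty

def fake_empty(*args, **kwargs):
    # Allocate as usual, then wrap in the subclass so the weight
    # can be kept on CPU. `OffloadedTensor` is the class discussed below.
    tensor = _original_empty(*args, **kwargs)
    return OffloadedTensor(tensor)

torch.empty = fake_empty        # patch while materializing weights ...
# ... and restore afterwards:
# torch.empty = _original_empty
```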
One thing to call out is that the monkey patching here will allow you to override `torch.empty`, but not any `at::empty()` calls that come from C++ anywhere in the dispatcher. I'm not sure if the particular code you're running is actually running into this, but the way we normally handle "factory functions that you want to override to return tensor subclasses" is with a `TorchDispatchMode`:
```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OffloadedTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        # Wrap freshly allocated non-CPU tensors in the subclass.
        if func is torch.ops.aten.empty.memory_format and rs.device != "cpu" and ...:
            rs = OffloadedTensor(rs)
        return rs
```
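For illustration, a hedged sketch of how such a mode might be applied while materializing weights (the builder name is a placeholder):

```python
# Any factory op issued while the mode is active goes through
# __torch_dispatch__, including allocations coming from C++ via
# the dispatcher, which plain monkey patching cannot intercept.
with OffloadedTensorMode():
    module = build_model()  # hypothetical; new tensors get wrapped here
```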
this is very helpful!
vllm/model_executor/models/utils.py
Outdated
```python
if requires_grad is None:
    return super().__new__(cls, elem)
else:
    return cls._make_subclass(cls, elem, requires_grad)
```
I think if your tensor subclass internally holds another tensor (`elem` here), you probably want to use the "wrapper" subclass API. Example here:

```python
out_tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
out_tensor.elem = weak_ref_tensor(elem)
```

Side note: I would probably call that constructor unconditionally; any reason you aren't doing it when `requires_grad is None`?
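For reference, a minimal sketch of what a wrapper-subclass constructor could look like for `OffloadedTensor` (the metadata choices and attribute name are assumptions, not the PR's code):

```python
import torch

class OffloadedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, requires_grad=False):
        # The wrapper carries only metadata; the real storage stays in
        # `elem`, which can live on CPU after offloading.
        wrapper = torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            dtype=elem.dtype,
            device=elem.device,
            requires_grad=requires_grad,
        )
        wrapper.elem = elem  # the wrapped tensor holding the actual data
        return wrapper
```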
… use tree_map && add handle aten.uniform_.default && rm handle aten.slice.Tensor &&
```python
class OffloadedTensor(torch.Tensor):
```
We do generally have support for subclasses that implement both `__torch_function__` and `__torch_dispatch__`, although if you only need `__torch_dispatch__` then I agree that you probably want to disable torch_function as linked above.
Let me know if you have any other questions / would like to chat more about the subclass work you're doing!
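As a concrete illustration of "disable torch_function", a sketch using the stock PyTorch helper (this is an assumption about the intended approach, not the PR's code):

```python
import torch

class OffloadedTensor(torch.Tensor):
    # Opt out of __torch_function__ entirely so only __torch_dispatch__
    # intercepts ops; PyTorch ships this helper for exactly this purpose.
    __torch_function__ = torch._C._disabled_torch_function_impl
```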
vllm/model_executor/models/utils.py
Outdated
```python
tensor = func(*args, **kwargs)

if (func is torch.ops.aten.empty.memory_format
        and tensor.device != "cpu"):
```
Maybe use `torch.device("cpu")` instead of `"cpu"`?
Hmm, `tensor.device != "cpu"` should generally be ok.
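A quick check of that claim, assuming a recent PyTorch where `torch.device` compares equal to matching device strings:

```python
import torch

t = torch.empty(1)                      # CPU tensor
print(t.device != "cpu")                # False: string comparison works
print(t.device != torch.device("cpu"))  # False: explicit device object
```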
This pull request has merge conflicts that must be resolved before it can be merged.
make it friendly with torch.compile