Enable the use of PyTorch 'mps' device for inference on Apple Silicon #5
Comments
Related: when I try to run the
I tried getting past the OP issue (I found and fixed huggingface/accelerate#1297 along the way).
But then I get this puzzling error:
I couldn't find much by googling for this, not helped by the fact that PyTorch itself seems to be part of Meta (Facebook) now. There is this: https://pytorch.org/docs/stable/generated/torch.Tensor.is_meta.html and https://pytorch.org/torchdistx/latest/fake_tensor.html
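For anyone else landing here: the `is_meta` docs linked above refer to PyTorch's "meta" device, which stores only shape/dtype metadata and no actual data. A minimal sketch (my own illustration, not code from this repo) of why computations fail when weights are accidentally left on it:

```python
import torch

# A tensor on the "meta" device carries shape and dtype metadata but has
# no backing storage, so any real computation on it will raise an error.
# This is typically what happens when checkpoint weights never get
# materialized onto a real device during dispatch.
t = torch.empty(3, 4, device="meta")
print(t.is_meta)   # True
print(t.shape)     # torch.Size([3, 4])
```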
Not sure why we are on the … It sounds like maybe … Adding an explicit device map in …:

```python
device = torch.device("cpu")
if torch.has_cuda:
    device = torch.device("cuda")
elif torch.has_mps:
    device = torch.device("mps")

if device_map is None:
    modules = (
        "transformer",
        "tok_embeddings",
        "layers",
        "norm",
        "output",
    )
    device_map = {module: device for module in modules}
```

...this gets further! Now I get this error:
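As a side note, `torch.has_cuda` and `torch.has_mps` are deprecated in recent PyTorch releases; an equivalent device check using the current API (a sketch, not this repo's code) would be:

```python
import torch

def pick_device() -> torch.device:
    # Prefer CUDA, then Apple's MPS backend, falling back to CPU.
    # hasattr() guards against PyTorch builds predating torch.backends.mps.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

print(pick_device())
```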
Finally I tried with the …, which starts to look like running on the MPS device is a dead end.
Some great sleuthing there @anentropic 👌 that last error is mentioned in this thread
Others have got this working, it seems:
The LLaMA.cpp project enables LLaMA inference on Apple Silicon devices by using the CPU, but faster inference should be possible by supporting the M1/Pro/Max GPU in vanilla-llama, given that PyTorch is now M1 compatible using the 'mps' device.

I'm new to Python, but my observations:

In both `generation.py` and `model.py` there are uses of the function `.cuda()`, which can be replaced with … When attempting to run `example.py` after this, it's the Accelerate framework which throws an error with `RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'` - something is trying to use `cpu` instead of `mps`.

I wonder if this is because the call into accelerate is `load_checkpoint_and_dispatch` with `auto` provided as the device map - is PyTorch preferring `cpu` over `mps` here for some reason? Edit: This …