
Make Magma optional for cuda builds? #275

Open
zklaus opened this issue Oct 15, 2024 · 6 comments · May be fixed by #298
Labels
question Further information is requested

Comments

zklaus commented Oct 15, 2024

Comment:

A conda-forge environment with nothing but pytorch for cuda in it currently weighs in at 7.2 GB, which is perceived to be rather on the heavy side of things.

Looking at potential for slimming things down, libmagma at ~2GB looks like a good candidate.

The pytorch docs seem to suggest that libmagma is used as an alternative to cusolver, which is included anyway at a much more modest ~150 MB.

In the past, magma was significantly faster than cusolver, as demonstrated by [1]. However, a recent 2024 paper by the magma authors [2] shows that cusolver has made progress and is now faster for some of the most important problems.
Magma still offers significant performance benefits for certain workloads, but given that pytorch has the ability to switch between the available libraries, we could make magma an optional dependency, i.e. merely include it in run_constrained and leave it up to the user to choose space or performance optimization based on their use-case.
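Since the backend choice is a runtime switch, the dispatch pattern involved can be sketched in a few lines of self-contained Python. This is a toy model only: it mirrors the shape of PyTorch's `torch.backends.cuda.preferred_linalg_library()` switch, but the registry and solver functions below are invented for illustration and are not PyTorch code.

```python
# Toy model of a runtime linalg-backend switch. Everything here is
# illustrative; only the API *shape* resembles PyTorch's
# torch.backends.cuda.preferred_linalg_library().

_BACKENDS = {}          # backend name -> solver callable
_preferred = "default"  # currently preferred backend name

def register_backend(name, solver):
    _BACKENDS[name] = solver

def preferred_linalg_library(name=None):
    """Get, or set and then get, the preferred backend name."""
    global _preferred
    if name is not None:
        _preferred = name
    return _preferred

def solve(problem):
    # Use the preferred backend if it is installed; otherwise fall back
    # to any available one. The fallback is what would make a Magma-free
    # build viable: calls still succeed, just via cusolver.
    solver = _BACKENDS.get(_preferred) or next(iter(_BACKENDS.values()))
    return solver(problem)

register_backend("cusolver", lambda p: ("cusolver", p))
register_backend("magma", lambda p: ("magma", p))

preferred_linalg_library("cusolver")
print(solve("gesv"))   # -> ('cusolver', 'gesv')
```

With only one backend registered, the fallback branch keeps every call working, which is the property that would let a cusolver-only build function without magma present.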

Is this feasible or am I missing something about the use of magma in pytorch?

Do you think this is desirable?

[1] S. Abdelfattah, A. Haidar, S. Tomov and J. Dongarra, "Analysis and Design Techniques towards High-Performance and Energy-Efficient Dense Linear Solvers on GPUs," IEEE Transactions on Parallel and Distributed Systems, vol. 29, no. 12, pp. 2700-2712, Dec. 2018, doi: 10.1109/TPDS.2018.2842785.

[2] Abdelfattah A, Beams N, Carson R, et al. MAGMA: Enabling exascale performance with accelerated BLAS and LAPACK for diverse GPU architectures. The International Journal of High Performance Computing Applications. 2024;38(5):468-490. doi:10.1177/10943420241261960

@zklaus zklaus added the question Further information is requested label Oct 15, 2024
isuruf (Member) commented Oct 15, 2024

we could make magma an optional dependency, i.e. merely include it in run_constrained and leave it up to the user to choose space or performance optimization based on their use-case.

If you build with magma, you need it at runtime even if you don't use magma.

mgorny (Contributor) commented Nov 28, 2024

Ok, I have some good news. The use of Magma is entirely limited to libtorch_cuda_linalg.so, which is ~640k, and replacing the Magma-enabled library with one built without Magma seems to work just fine.

I think we can split it into a separate subpackage and build Magma and non-Magma variants to choose from. On the minus side, we probably can't avoid building most of libtorch twice, though it could be possible to use ccache to minimize the cost of doing that.
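A conda recipe sketch of what such a split might look like. This is a hypothetical fragment, not the feedstock's actual meta.yaml: output names, build strings, and dependency lines are all assumptions for illustration.

```yaml
# Hypothetical outputs section: libtorch_cuda_linalg split into its own
# output, with two interchangeable variants built from the same sources.
outputs:
  - name: libtorch-cuda-linalg
    build:
      string: magma_h{{ PKG_HASH }}      # Magma-enabled variant
    requirements:
      run:
        - libmagma                       # heavy (~2 GB) run dependency
  - name: libtorch-cuda-linalg
    build:
      string: nomagma_h{{ PKG_HASH }}    # cusolver-only variant, no Magma
  - name: pytorch
    requirements:
      run:
        - libtorch-cuda-linalg           # either variant satisfies this
```

Because both variants provide the same output name, the solver can install either, and users who want the smaller environment could pin the nomagma build string.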

isuruf (Member) commented Nov 28, 2024

The use of Magma is entirely limited to libtorch_cuda_linalg.so which is ~640k

Isn't it used by libtorch_cuda.so ?

mgorny (Contributor) commented Nov 28, 2024

No, it's dynamically loaded:

// In that case load library is dynamically loaded when first linalg call is made
// This helps reduce size of GPU memory context if linear algebra functions are not used

https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp

mgorny added a commit to mgorny/pytorch-cpu-feedstock that referenced this issue Dec 4, 2024
Upstream keeps all magma-related routines in a separate
libtorch_cuda_linalg library that is loaded dynamically whenever linalg
functions are used.  Given the library is relatively small, splitting it
makes it possible to provide "magma" and "nomagma" variants that can
be alternated between.

Fixes conda-forge#275

Co-authored-by: Isuru Fernando <[email protected]>
@mgorny mgorny linked a pull request Dec 4, 2024 that will close this issue