Skip to content

Commit

Permalink
[WIP] Use libblastrampoline to forward to a user-defined BLAS at runt…
Browse files Browse the repository at this point in the history
…ime (#39455)
  • Loading branch information
ViralBShah authored Feb 25, 2021
1 parent 384c94a commit a0efe87
Showing 1 changed file with 5 additions and 27 deletions.
32 changes: 5 additions & 27 deletions test/distributed_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1005,30 +1005,8 @@ end

# Test addprocs enable_threaded_blas parameter

const get_num_threads = function() # anonymous so it will be serialized when called
blas = LinearAlgebra.BLAS.vendor()
# Wrap in a try to catch unsupported blas versions
try
if blas == :openblas
return ccall((:openblas_get_num_threads, Base.libblas_name), Cint, ())
elseif blas == :openblas64
return ccall((:openblas_get_num_threads64_, Base.libblas_name), Cint, ())
elseif blas == :mkl
return ccall((:MKL_Get_Max_Num_Threads, Base.libblas_name), Cint, ())
end

# OSX BLAS looks at an environment variable
if Sys.isapple()
return tryparse(Cint, get(ENV, "VECLIB_MAXIMUM_THREADS", "1"))
end
catch
end

return nothing
end

function get_remote_num_threads(processes_added)
return [remotecall_fetch(get_num_threads, proc_id) for proc_id in processes_added]
return [remotecall_fetch(BLAS.get_num_threads, proc_id) for proc_id in processes_added]
end

function test_blas_config(pid, expected)
Expand All @@ -1041,7 +1019,7 @@ function test_blas_config(pid, expected)
end

function test_add_procs_threaded_blas()
master_blas_thread_count = get_num_threads()
master_blas_thread_count = BLAS.get_num_threads()
if master_blas_thread_count === nothing
@warn "Skipping blas num threads tests due to unsupported blas version"
return
Expand All @@ -1055,7 +1033,7 @@ function test_add_procs_threaded_blas()
end

# Master thread should not have changed
@test get_num_threads() == master_blas_thread_count
@test BLAS.get_num_threads() == master_blas_thread_count

# Threading disabled in children by default
thread_counts_by_process = get_remote_num_threads(processes_added)
Expand All @@ -1069,9 +1047,9 @@ function test_add_procs_threaded_blas()
test_blas_config(proc_id, true)
end

@test get_num_threads() == master_blas_thread_count
@test BLAS.get_num_threads() == master_blas_thread_count

# BLAS.set_num_threads(`num`) doesn't cause get_num_threads to return `num`
# BLAS.set_num_threads(`num`) doesn't cause BLAS.get_num_threads to return `num`
# depending on the machine, the BLAS version, and BLAS configuration, so
# we need a very lenient test.
thread_counts_by_process = get_remote_num_threads(processes_added)
Expand Down

0 comments on commit a0efe87

Please sign in to comment.