Skip to content

Commit

Permalink
allow ccall library name to be non-constant (#37123)
Browse files Browse the repository at this point in the history
Fixes #36458
  • Loading branch information
JeffBezanson authored Sep 2, 2020
1 parent 0bff5bf commit 5cba88c
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 28 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ New language features
and associativity as other arrow-like operators ([#36666]).
* Compilation and type inference can now be enabled or disabled at the module level
using the experimental macro `Base.Experimental.@compiler_options` ([#37041]).
* The library name passed to `ccall` or `@ccall` can now be an expression involving
global variables and function calls. The expression will be evaluated the first
time the `ccall` executes ([#36458]).

Language changes
----------------
Expand Down
10 changes: 9 additions & 1 deletion doc/src/manual/calling-c-and-fortran-code.md
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,15 @@ it must be handled in other ways.

## Non-constant Function Specifications

A `(name, library)` function specification must be a constant expression. However, it is possible
In some cases, the exact name or path of the needed library is not known in advance and must
be computed at run time. To handle such cases, the library component of a `(name, library)`
specification can be a function call, e.g. `(:dgemm_, find_blas())`. The call expression will
be executed when the `ccall` itself is executed. However, it is assumed that the library
location does not change once it is determined, so the result of the call can be cached and
reused. Therefore, the number of times the expression executes is unspecified, and returning
different values for multiple calls results in undefined behavior.

If even more flexibility is needed, it is possible
to use computed values as function names by staging through [`eval`](@ref) as follows:

```
Expand Down
8 changes: 8 additions & 0 deletions src/ast.scm
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@
;; Predicate: is `e` a pair whose head is the symbol `ssavalue`?
;; Non-pairs yield #f without being inspected further.
(define (ssavalue? e)
  (if (pair? e)
      (eq? (car e) 'ssavalue)
      #f))

;; Predicate: is `e` a pair whose head is the symbol `slot`?
;; Non-pairs yield #f without being inspected further.
(define (slot? e)
  (if (pair? e)
      (eq? (car e) 'slot)
      #f))

;; Predicate: is `e` a pair whose head is the symbol `globalref`?
;; Non-pairs yield #f without being inspected further.
(define (globalref? e)
  (if (pair? e)
      (eq? (car e) 'globalref)
      #f))

Expand Down Expand Up @@ -439,6 +442,11 @@
(let ((x (cadr e)))
(not (simple-atom? x)))))

;; Predicate: is `e` a call form whose callee is `(core tuple)`,
;; i.e. does it look like `(call (core tuple) ...)`?
;;
;; The leading `pair?` test makes the predicate total: an atom is rejected
;; up front instead of being handed to `length>`, so callers do not need
;; their own `atom?` guard before asking.
(define (tuple-call? e)
  (and (pair? e)
       (length> e 1)
       (eq? (car e) 'call)
       (equal? (cadr e) '(core tuple))))

;; Two values are treated as the same symbol when they are `eq?`, or when
;; both are ssavalues and their tails compare `eqv?`.
;; NOTE(review): the ssavalue comparison is on `(cdr ...)` (the list tail),
;; not `(cadr ...)` (the id itself) -- confirm this is the intended test.
(define (eq-sym? a b)
  (cond ((eq? a b) #t)
        ((and (ssavalue? a) (ssavalue? b))
         (eqv? (cdr a) (cdr b)))
        (else #f)))

Expand Down
103 changes: 83 additions & 20 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ static bool runtime_sym_gvs(jl_codegen_params_t &emission_context, const char *f
static Value *runtime_sym_lookup(
jl_codegen_params_t &emission_context,
IRBuilder<> &irbuilder,
PointerType *funcptype, const char *f_lib,
jl_codectx_t *ctx,
PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr,
const char *f_name, Function *f,
GlobalVariable *libptrgv,
GlobalVariable *llvmgv, bool runtime_lib)
Expand Down Expand Up @@ -106,16 +107,25 @@ static Value *runtime_sym_lookup(
assert(f->getParent() != NULL);
f->getBasicBlockList().push_back(dlsym_lookup);
irbuilder.SetInsertPoint(dlsym_lookup);
Value *libname;
if (runtime_lib) {
libname = stringConstPtr(emission_context, irbuilder, f_lib);
Instruction *llvmf;
Value *nameval = stringConstPtr(emission_context, irbuilder, f_name);
if (lib_expr) {
jl_cgval_t libval = emit_expr(*ctx, lib_expr);
llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jllazydlsym_func),
{ boxed(*ctx, libval), nameval });
}
else {
// f_lib is actually one of the special sentinel values
libname = ConstantExpr::getIntToPtr(ConstantInt::get(T_size, (uintptr_t)f_lib), T_pint8);
Value *libname;
if (runtime_lib) {
libname = stringConstPtr(emission_context, irbuilder, f_lib);
}
else {
// f_lib is actually one of the special sentinel values
libname = ConstantExpr::getIntToPtr(ConstantInt::get(T_size, (uintptr_t)f_lib), T_pint8);
}
llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jldlsym_func),
{ libname, nameval, libptrgv });
}
Value *llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jldlsym_func),
{ libname, stringConstPtr(emission_context, irbuilder, f_name), libptrgv });
StoreInst *store = irbuilder.CreateAlignedStore(llvmf, llvmgv, Align(sizeof(void*)));
store->setAtomic(AtomicOrdering::Release);
irbuilder.CreateBr(ccall_bb);
Expand All @@ -124,21 +134,49 @@ static Value *runtime_sym_lookup(
irbuilder.SetInsertPoint(ccall_bb);
PHINode *p = irbuilder.CreatePHI(T_pvoidfunc, 2);
p->addIncoming(llvmf_orig, enter_bb);
p->addIncoming(llvmf, dlsym_lookup);
p->addIncoming(llvmf, llvmf->getParent());
return irbuilder.CreateBitCast(p, funcptype);
}

static Value *runtime_sym_lookup(
jl_codectx_t &ctx,
PointerType *funcptype, const char *f_lib,
PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr,
const char *f_name, Function *f,
GlobalVariable *libptrgv,
GlobalVariable *llvmgv, bool runtime_lib)
{
return runtime_sym_lookup(ctx.emission_context, ctx.builder, &ctx, funcptype, f_lib, lib_expr,
f_name, f, libptrgv, llvmgv, runtime_lib);
}

static Value *runtime_sym_lookup(
jl_codectx_t &ctx,
PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr,
const char *f_name, Function *f)
{
GlobalVariable *libptrgv;
GlobalVariable *llvmgv;
bool runtime_lib = runtime_sym_gvs(ctx.emission_context, f_lib, f_name, libptrgv, llvmgv);
libptrgv = prepare_global_in(jl_Module, libptrgv);
bool runtime_lib;
if (lib_expr) {
// for computed library names, generate a global variable to cache the function
// pointer just for this call site.
runtime_lib = true;
libptrgv = NULL;
std::string gvname = "libname_";
gvname += f_name;
gvname += "_";
gvname += std::to_string(globalUnique++);
Module *M = ctx.emission_context.shared_module(jl_LLVMContext);
llvmgv = new GlobalVariable(*M, T_pvoidfunc, false,
GlobalVariable::ExternalLinkage,
Constant::getNullValue(T_pvoidfunc), gvname);
}
else {
runtime_lib = runtime_sym_gvs(ctx.emission_context, f_lib, f_name, libptrgv, llvmgv);
libptrgv = prepare_global_in(jl_Module, libptrgv);
}
llvmgv = prepare_global_in(jl_Module, llvmgv);
return runtime_sym_lookup(ctx.emission_context, ctx.builder, funcptype, f_lib, f_name, f, libptrgv, llvmgv, runtime_lib);
return runtime_sym_lookup(ctx, funcptype, f_lib, lib_expr, f_name, f, libptrgv, llvmgv, runtime_lib);
}

// Emit a "PLT" entry that will be lazily initialized
Expand Down Expand Up @@ -169,7 +207,7 @@ static GlobalVariable *emit_plt_thunk(
fname);
BasicBlock *b0 = BasicBlock::Create(jl_LLVMContext, "top", plt);
IRBuilder<> irbuilder(b0);
Value *ptr = runtime_sym_lookup(emission_context, irbuilder, funcptype, f_lib, f_name, plt, libptrgv,
Value *ptr = runtime_sym_lookup(emission_context, irbuilder, NULL, funcptype, f_lib, NULL, f_name, plt, libptrgv,
llvmgv, runtime_lib);
StoreInst *store = irbuilder.CreateAlignedStore(irbuilder.CreateBitCast(ptr, T_pvoidfunc), got, Align(sizeof(void*)));
store->setAtomic(AtomicOrdering::Release);
Expand Down Expand Up @@ -475,6 +513,7 @@ typedef struct {
void (*fptr)(void); // if the argument is a constant pointer
const char *f_name; // if the symbol name is known
const char *f_lib; // if a library name is specified
jl_value_t *lib_expr; // expression to compute library path lazily
jl_value_t *gcroot;
} native_sym_arg_t;

Expand All @@ -488,6 +527,24 @@ static void interpret_symbol_arg(jl_codectx_t &ctx, native_sym_arg_t &out, jl_va

jl_value_t *ptr = static_eval(ctx, arg);
if (ptr == NULL) {
if (jl_is_expr(arg) && ((jl_expr_t*)arg)->head == call_sym && jl_expr_nargs(arg) == 3 &&
jl_is_globalref(jl_exprarg(arg,0)) && jl_globalref_mod(jl_exprarg(arg,0)) == jl_core_module &&
jl_globalref_name(jl_exprarg(arg,0)) == jl_symbol("tuple")) {
// attempt to interpret a non-constant 2-tuple expression as (func_name, lib_name()), where
// `lib_name()` will be executed when first used.
jl_value_t *name_val = static_eval(ctx, jl_exprarg(arg,1));
if (name_val && jl_is_symbol(name_val)) {
f_name = jl_symbol_name((jl_sym_t*)name_val);
out.lib_expr = jl_exprarg(arg, 2);
return;
}
else if (name_val && jl_is_string(name_val)) {
f_name = jl_string_data(name_val);
out.gcroot = name_val;
out.lib_expr = jl_exprarg(arg, 2);
return;
}
}
jl_cgval_t arg1 = emit_expr(ctx, arg);
jl_value_t *ptr_ty = arg1.typ;
if (!jl_is_cpointer_type(ptr_ty)) {
Expand Down Expand Up @@ -586,8 +643,11 @@ static jl_cgval_t emit_cglobal(jl_codectx_t &ctx, jl_value_t **args, size_t narg
jl_printf(JL_STDERR,"WARNING: literal address used in cglobal for %s; code cannot be statically compiled\n", sym.f_name);
}
else {
if (imaging_mode) {
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), sym.f_lib, sym.f_name, ctx.f);
if (sym.lib_expr) {
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), NULL, sym.lib_expr, sym.f_name, ctx.f);
}
else if (imaging_mode) {
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), sym.f_lib, NULL, sym.f_name, ctx.f);
res = ctx.builder.CreatePtrToInt(res, lrt);
}
else {
Expand All @@ -597,7 +657,7 @@ static jl_cgval_t emit_cglobal(jl_codectx_t &ctx, jl_value_t **args, size_t narg
if (!libsym || !jl_dlsym(libsym, sym.f_name, &symaddr, 0)) {
// Error mode, either the library or the symbol couldn't be found at compile time.
// Fallback to a runtime symbol lookup.
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), sym.f_lib, sym.f_name, ctx.f);
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), sym.f_lib, NULL, sym.f_name, ctx.f);
res = ctx.builder.CreatePtrToInt(res, lrt);
} else {
// since we aren't saving this code, there's no sense in
Expand Down Expand Up @@ -1737,11 +1797,14 @@ jl_cgval_t function_sig_t::emit_a_ccall(
else {
assert(symarg.f_name != NULL);
PointerType *funcptype = PointerType::get(functype, 0);
if (imaging_mode) {
if (symarg.lib_expr) {
llvmf = runtime_sym_lookup(ctx, funcptype, NULL, symarg.lib_expr, symarg.f_name, ctx.f);
}
else if (imaging_mode) {
// vararg requires musttail,
// but musttail is incompatible with noreturn.
if (functype->isVarArg())
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, symarg.f_name, ctx.f);
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, NULL, symarg.f_name, ctx.f);
else
llvmf = emit_plt(ctx, functype, attributes, cc, symarg.f_lib, symarg.f_name);
}
Expand All @@ -1751,7 +1814,7 @@ jl_cgval_t function_sig_t::emit_a_ccall(
if (!libsym || !jl_dlsym(libsym, symarg.f_name, &symaddr, 0)) {
// either the library or the symbol could not be found, place a runtime
// lookup here instead.
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, symarg.f_name, ctx.f);
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, NULL, symarg.f_name, ctx.f);
} else {
// since we aren't saving this code, there's no sense in
// putting anything complicated here: just JIT the function address
Expand Down
6 changes: 6 additions & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,12 @@ static const auto jldlsym_func = new JuliaFunction{
{T_pint8, T_pint8, PointerType::get(T_pint8, 0)}, false); },
nullptr,
};
// Descriptor for the runtime helper `jl_lazy_load_and_lookup`, which codegen
// calls when a ccall's library is given as an expression instead of a constant.
// The LLVM signature built below takes a boxed Julia value (T_prjlvalue) and a
// C string pointer (T_pint8) and returns a generic function pointer
// (T_pvoidfunc); it is not variadic.
// NOTE(review): the trailing `nullptr` mirrors the neighbouring descriptors;
// its meaning depends on the JuliaFunction struct layout -- confirm there.
static const auto jllazydlsym_func = new JuliaFunction{
"jl_lazy_load_and_lookup",
[](LLVMContext &C) { return FunctionType::get(T_pvoidfunc,
{T_prjlvalue, T_pint8}, false); },
nullptr,
};
static const auto jltypeassert_func = new JuliaFunction{
"jl_typeassert",
[](LLVMContext &C) { return FunctionType::get(T_void,
Expand Down
18 changes: 11 additions & 7 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -3906,11 +3906,9 @@ f(x) = yt(x)
(cond ((eq? (car e) 'foreigncall)
;; NOTE: 2nd to 5th arguments of ccall must be left in place
;; the 1st should be compiled if an atom.
(append (if (or (atom? (cadr e))
(let ((fptr (cadr e)))
(not (and (length> fptr 1)
(eq? (car fptr) 'call)
(equal? (cadr fptr) '(core tuple))))))
(append (if (let ((fptr (cadr e)))
(or (atom? fptr)
(not (tuple-call? fptr))))
(compile-args (list (cadr e)) break-labels)
(list (cadr e)))
(list-head (cddr e) 4)
Expand Down Expand Up @@ -4466,8 +4464,14 @@ f(x) = yt(x)
`(gotoifnot ,(renumber-stuff (cadr e)) ,(get label-table (caddr e))))
((eq? (car e) 'lambda)
(renumber-lambda e 'none 0))
(else (cons (car e)
(map renumber-stuff (cdr e))))))
(else
(let ((e (cons (car e)
(map renumber-stuff (cdr e)))))
(if (and (eq? (car e) 'foreigncall)
(tuple-call? (cadr e))
(expr-contains-p (lambda (x) (or (ssavalue? x) (slot? x))) (cadr e)))
(error "ccall function name and library expression cannot reference local variables"))
e))))
(let ((body (renumber-stuff (lam:body lam)))
(vi (lam:vinfo lam)))
(listify-lambda
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,7 @@ void *jl_get_library_(const char *f_lib, int throw_err) JL_NOTSAFEPOINT;
#define jl_get_library(f_lib) jl_get_library_(f_lib, 1)
JL_DLLEXPORT void *jl_load_and_lookup(const char *f_lib, const char *f_name,
void **hnd) JL_NOTSAFEPOINT;
JL_DLLEXPORT void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name);
JL_DLLEXPORT jl_value_t *jl_get_cfunction_trampoline(
jl_value_t *fobj, jl_datatype_t *result, htable_t *cache, jl_svec_t *fill,
void *(*init_trampoline)(void *tramp, void **nval),
Expand Down
17 changes: 17 additions & 0 deletions src/runtime_ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,23 @@ void *jl_load_and_lookup(const char *f_lib, const char *f_name, void **hnd) JL_N
return ptr;
}

// Variant of jl_load_and_lookup for call sites whose library is computed at
// run time: `lib_val` is the already-evaluated library expression and must be
// a Symbol or a String; anything else raises a "ccall" type error.
// Returns the address of `f_name` looked up in that library.
extern "C" JL_DLLEXPORT
void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name)
{
    // Reject unsupported values first so the lookup below only ever sees a name.
    if (!jl_is_symbol(lib_val) && !jl_is_string(lib_val))
        jl_type_error("ccall", (jl_value_t*)jl_symbol_type, lib_val);

    const char *f_lib = jl_is_symbol(lib_val)
        ? jl_symbol_name((jl_sym_t*)lib_val)
        : jl_string_data(lib_val);

    void *ptr;
    // Final argument 1 mirrors the original call (error-raising lookup mode).
    jl_dlsym(jl_get_library(f_lib), f_name, &ptr, 1);
    return ptr;
}

// miscellany
std::string jl_get_cpu_name_llvm(void)
{
Expand Down
8 changes: 8 additions & 0 deletions test/ccall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1701,3 +1701,11 @@ end
str = GC.@preserve buffer unsafe_string(Cwstring(pointer(buffer)))
@test str == "α+β=15"
end

# issue #36458
# The library component of the ccall tuple is a function call, so the library
# name is only produced when `compute_lib_name()` runs.
compute_lib_name() = "libcc" * "alltest"
ccall_lazy_lib_name(x) = ccall((:testUcharX, compute_lib_name()), Int32, (UInt8,), x % UInt8)
@test ccall_lazy_lib_name(0) == 0
@test ccall_lazy_lib_name(3) == 1
# Here the library is named by a global that is never defined; calling the
# function is expected to raise UndefVarError for that name.
ccall_with_undefined_lib() = ccall((:time, xx_nOt_DeFiNeD_xx), Cint, (Ptr{Cvoid},), C_NULL)
@test_throws UndefVarError(:xx_nOt_DeFiNeD_xx) ccall_with_undefined_lib()
2 changes: 2 additions & 0 deletions test/syntax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,8 @@ end
# #6080
@test Meta.lower(@__MODULE__, :(ccall(:a, Cvoid, (Cint,), &x))) == Expr(:error, "invalid syntax &x")

@test Meta.lower(@__MODULE__, :(f(x) = (y = x + 1; ccall((:a, y), Cvoid, ())))) == Expr(:error, "ccall function name and library expression cannot reference local variables")

@test_throws ParseError Meta.parse("x.'")
@test_throws ParseError Meta.parse("0.+1")

Expand Down

2 comments on commit 5cba88c

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily benchmark build, I will reply here when finished:

@nanosoldier runbenchmarks(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here. cc @ararslan

Please sign in to comment.