Skip to content

Commit

Permalink
codegen: make more judicious use of global rooting (#56625)
Browse files Browse the repository at this point in the history
Only temporarily root objects during codegen, so that it is the
responsibility of the caller to ensure the values live (including write
barriers and old generations) only as long as necessary for correct
execution, and not preserve values that never make it into the IR.
  • Loading branch information
vtjnash authored Nov 22, 2024
1 parent 5fab51a commit 7354be3
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 49 deletions.
123 changes: 112 additions & 11 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,106 @@ static jl_code_instance_t *jl_ci_cache_lookup(jl_method_instance_t *mi, size_t w
return codeinst;
}

namespace { // file-local namespace
class egal_set {
public:
jl_genericmemory_t *list = (jl_genericmemory_t*)jl_an_empty_memory_any;
jl_genericmemory_t *keyset = (jl_genericmemory_t*)jl_an_empty_memory_any;
egal_set(egal_set&) = delete;
egal_set(egal_set&&) = delete;
egal_set() = default;
void insert(jl_value_t *val)
{
jl_value_t *rval = jl_idset_get(list, keyset, val);
if (rval == NULL) {
ssize_t idx;
list = jl_idset_put_key(list, val, &idx);
keyset = jl_idset_put_idx(list, keyset, idx);
}
}
jl_value_t *get(jl_value_t *val)
{
return jl_idset_get(list, keyset, val);
}
};
}
using ::egal_set;
typedef DenseMap<jl_code_instance_t*, std::pair<orc::ThreadSafeModule, jl_llvm_functions_t>> jl_compiled_functions_t;
static void compile_workqueue(jl_codegen_params_t &params, CompilationPolicy policy, jl_compiled_functions_t &compiled_functions)

// Copy the roots of the method behind `mi` into `method_roots`.
// Does nothing when `mi->def.method` is not a method. Entries that are
// already globally rooted are skipped. The method's `writelock` is held
// while its `roots` array is read.
static void record_method_roots(egal_set &method_roots, jl_method_instance_t *mi)
{
    jl_method_t *meth = mi->def.method;
    if (jl_is_method(meth)) {
        JL_LOCK(&meth->writelock);
        if (meth->roots) {
            size_t nroots = jl_array_dim0(meth->roots);
            for (size_t idx = 0; idx < nroots; idx++) {
                jl_value_t *root = jl_array_ptr_ref(meth->roots, idx);
                // only values that are not already globally rooted are recorded
                if (!jl_is_globally_rooted(root))
                    method_roots.insert(root);
            }
        }
        JL_UNLOCK(&meth->writelock);
    }
}

// Walk every value in `params.temporary_roots` that also has an entry in
// `params.global_targets` and re-point that entry at a longer-lived
// equivalent value:
//   - the value itself, if it is already globally rooted (nothing changes);
//   - otherwise the matching entry found in `method_roots`, if any;
//   - otherwise the result of `jl_as_global_root(val, 1)`.
// When the chosen value differs from the original, the `global_targets` entry
// is re-keyed, and if the chosen value already had its own global variable,
// the two globals are merged (by name) in every module of `compiled_functions`.
static void aot_optimize_roots(jl_codegen_params_t &params, egal_set &method_roots, jl_compiled_functions_t &compiled_functions)
{
for (size_t i = 0; i < jl_array_dim0(params.temporary_roots); i++) {
jl_value_t *val = jl_array_ptr_ref(params.temporary_roots, i);
// only temporary roots that were actually emitted as a global target matter here
auto ref = params.global_targets.find((void*)val);
if (ref == params.global_targets.end())
continue;
// pick the value that should stand in for `val` (see the function comment for the order)
auto get_global_root = [val, &method_roots]() {
if (jl_is_globally_rooted(val))
return val;
jl_value_t *mval = method_roots.get(val);
if (mval)
return mval;
return jl_as_global_root(val, 1);
};
jl_value_t *mval = get_global_root();
if (mval != val) {
// grab the GlobalVariable before erasing the map entry it came from
GlobalVariable *GV = ref->second;
params.global_targets.erase(ref);
auto mref = params.global_targets.find((void*)mval);
if (mref != params.global_targets.end()) {
// replace ref with mref in all Modules
// OldName is copied into an owned string; NOTE(review): presumably because GV may be renamed or erased inside the loop below
std::string OldName(GV->getName());
StringRef NewName(mref->second->getName());
for (auto &def : compiled_functions) {
orc::ThreadSafeModule &TSM = std::get<0>(def.second);
Module &M = *TSM.getModuleUnlocked();
if (GlobalValue *GV2 = M.getNamedValue(OldName)) {
// clear GV once the module that owns it has been visited, so the assert below can check it was found
if (GV2 == GV)
GV = nullptr;
// either replace or rename the old value to use the other equivalent name
if (GlobalValue *GV3 = M.getNamedValue(NewName)) {
GV2->replaceAllUsesWith(GV3);
GV2->eraseFromParent();
}
else {
GV2->setName(NewName);
}
}
}
// the original GV is expected to have been found in one of the modules
assert(GV == nullptr);
}
else {
// no existing global for the replacement value: reuse GV under the new key
params.global_targets[(void*)mval] = GV;
}
}
}
}

static void compile_workqueue(jl_codegen_params_t &params, egal_set &method_roots, CompilationPolicy policy, jl_compiled_functions_t &compiled_functions)
{
decltype(params.workqueue) workqueue;
std::swap(params.workqueue, workqueue);
jl_code_info_t *src = NULL;
jl_code_instance_t *codeinst = NULL;
JL_GC_PUSH2(&src, &codeinst);
JL_GC_PUSH1(&codeinst);
assert(!params.cache);
while (!workqueue.empty()) {
auto it = workqueue.pop_back_val();
Expand All @@ -352,6 +444,7 @@ static void compile_workqueue(jl_codegen_params_t &params, CompilationPolicy pol
jl_create_ts_module(name_from_method_instance(codeinst->def),
params.tsctx, params.DL, params.TargetTriple);
auto decls = jl_emit_codeinst(result_m, codeinst, NULL, params);
record_method_roots(method_roots, codeinst->def);
if (result_m)
it = compiled_functions.insert(std::make_pair(codeinst, std::make_pair(std::move(result_m), std::move(decls)))).first;
}
Expand Down Expand Up @@ -432,7 +525,6 @@ static void compile_workqueue(jl_codegen_params_t &params, CompilationPolicy pol
JL_GC_POP();
}


// takes the running content that has collected in the shadow module and dump it to disk
// this builds the object file portion of the sysimage files for fast startup, and can
// also be used by extern consumers like GPUCompiler.jl to obtain a module containing
Expand All @@ -454,8 +546,6 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
CompilationPolicy policy = (CompilationPolicy) _policy;
bool imaging = imaging_default() || _imaging_mode == 1;
jl_method_instance_t *mi = NULL;
jl_code_info_t *src = NULL;
JL_GC_PUSH1(&src);
auto ct = jl_current_task;
bool timed = (ct->reentrant_timing & 1) == 0;
if (timed)
Expand All @@ -479,11 +569,14 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
auto target_info = clone.withModuleDo([&](Module &M) {
return std::make_pair(M.getDataLayout(), Triple(M.getTargetTriple()));
});
egal_set method_roots;
jl_codegen_params_t params(ctxt, std::move(target_info.first), std::move(target_info.second));
params.params = cgparams;
params.imaging_mode = imaging;
params.debug_level = cgparams->debug_info_level;
params.external_linkage = _external_linkage;
params.temporary_roots = jl_alloc_array_1d(jl_array_any_type, 0);
JL_GC_PUSH3(&params.temporary_roots, &method_roots.list, &method_roots.keyset);
size_t compile_for[] = { jl_typeinf_world, _world };
int worlds = 0;
if (jl_options.trim != JL_TRIM_NO)
Expand All @@ -508,13 +601,13 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
continue;
}
mi = (jl_method_instance_t*)item;
src = NULL;
// if this method is generally visible to the current compilation world,
// and this is either the primary world, or not applicable in the primary world
// then we want to compile and emit this
if (jl_atomic_load_relaxed(&mi->def.method->primary_world) <= this_world && this_world <= jl_atomic_load_relaxed(&mi->def.method->deleted_world)) {
// find and prepare the source code to compile
jl_code_instance_t *codeinst = jl_ci_cache_lookup(mi, this_world, lookup);
JL_GC_PROMISE_ROOTED(codeinst);
if (jl_options.trim != JL_TRIM_NO && !codeinst) {
// If we're building a small image, we need to compile everything
// to ensure that we have all the information we need.
Expand All @@ -529,11 +622,12 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
data->jl_fvar_map[codeinst] = std::make_tuple((uint32_t)-3, (uint32_t)-3);
}
else {
JL_GC_PROMISE_ROOTED(codeinst->rettype);
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(codeinst->def),
params.tsctx, clone.getModuleUnlocked()->getDataLayout(),
Triple(clone.getModuleUnlocked()->getTargetTriple()));
jl_llvm_functions_t decls = jl_emit_codeinst(result_m, codeinst, NULL, params);
JL_GC_PROMISE_ROOTED(codeinst->def); // analyzer seems confused
record_method_roots(method_roots, codeinst->def);
if (result_m)
compiled_functions[codeinst] = {std::move(result_m), std::move(decls)};
else if (jl_options.trim != JL_TRIM_NO) {
Expand All @@ -555,9 +649,11 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
}
}
}
JL_GC_POP();
// finally, make sure all referenced methods also get compiled or fixed up
compile_workqueue(params, policy, compiled_functions);
compile_workqueue(params, method_roots, policy, compiled_functions);
aot_optimize_roots(params, method_roots, compiled_functions);
params.temporary_roots = nullptr;
JL_GC_POP();

// process the globals array, before jl_merge_module destroys them
SmallVector<std::string, 0> gvars(params.global_targets.size());
Expand Down Expand Up @@ -2161,7 +2257,11 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, jl_
// This would also be nice, but it seems to cause OOMs on the windows32 builder
// To get correct names in the IR this needs to be at least 2
output.debug_level = params.debug_info_level;
output.temporary_roots = jl_alloc_array_1d(jl_array_any_type, 0);
JL_GC_PUSH1(&output.temporary_roots);
auto decls = jl_emit_code(m, mi, src, output);
output.temporary_roots = nullptr;
JL_GC_POP(); // GC the global_targets array contents now since reflection doesn't need it

Function *F = NULL;
if (m) {
Expand All @@ -2171,7 +2271,8 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, jl_
for (auto &global : output.global_targets) {
if (jl_options.image_codegen) {
global.second->setLinkage(GlobalValue::ExternalLinkage);
} else {
}
else {
auto p = literal_static_pointer_val(global.first, global.second->getValueType());
#if JL_LLVM_VERSION >= 170000
Type *elty = PointerType::get(output.getContext(), 0);
Expand Down
9 changes: 6 additions & 3 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1548,7 +1548,7 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
return jl_cgval_t();
}
if (rt != args[2] && rt != (jl_value_t*)jl_any_type)
rt = jl_ensure_rooted(ctx, rt);
jl_temporary_root(ctx, rt);
function_sig_t sig("ccall", lrt, rt, retboxed,
(jl_svec_t*)at, unionall, nreqargs,
cc, llvmcall, &ctx.emission_context);
Expand Down Expand Up @@ -2036,8 +2036,11 @@ jl_cgval_t function_sig_t::emit_a_ccall(
if (ctx.spvals_ptr == NULL && !toboxed && unionall_env && jl_has_typevar_from_unionall(jargty, unionall_env) &&
jl_svec_len(ctx.linfo->sparam_vals) > 0) {
jargty_in_env = jl_instantiate_type_in_env(jargty_in_env, unionall_env, jl_svec_data(ctx.linfo->sparam_vals));
if (jargty_in_env != jargty)
jargty_in_env = jl_ensure_rooted(ctx, jargty_in_env);
if (jargty_in_env != jargty) {
JL_GC_PUSH1(&jargty_in_env);
jl_temporary_root(ctx, jargty_in_env);
JL_GC_POP();
}
}

Value *v;
Expand Down
12 changes: 7 additions & 5 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,8 @@ static llvm::SmallVector<Value*,0> get_gc_roots_for(jl_codectx_t &ctx, const jl_

// --- emitting pointers directly into code ---


static void jl_temporary_root(jl_codegen_params_t &ctx, jl_value_t *val);
static void jl_temporary_root(jl_codectx_t &ctx, jl_value_t *val);
static inline Constant *literal_static_pointer_val(const void *p, Type *T);

static Constant *julia_pgv(jl_codectx_t &ctx, const char *cname, void *addr)
Expand Down Expand Up @@ -777,7 +778,8 @@ static Type *_julia_struct_to_llvm(jl_codegen_params_t *ctx, LLVMContext &ctxt,
if (ntypes == 0 || jl_datatype_nbits(jst) == 0)
return getVoidTy(ctxt);
Type *_struct_decl = NULL;
// TODO: we should probably make a temporary root for `jst` somewhere
if (ctx)
jl_temporary_root(*ctx, jt);
// don't use pre-filled struct_decl for llvmcall (f16, etc. may be different)
Type *&struct_decl = (ctx && !llvmcall ? ctx->llvmtypes[jst] : _struct_decl);
if (struct_decl)
Expand Down Expand Up @@ -3506,8 +3508,6 @@ static Value *call_with_attrs(jl_codectx_t &ctx, JuliaFunction<TypeFn_t> *intr,
return Call;
}

static jl_value_t *jl_ensure_rooted(jl_codectx_t &ctx, jl_value_t *val);

static Value *as_value(jl_codectx_t &ctx, Type *to, const jl_cgval_t &v)
{
assert(!v.isboxed);
Expand Down Expand Up @@ -3540,7 +3540,9 @@ static Value *_boxed_special(jl_codectx_t &ctx, const jl_cgval_t &vinfo, Type *t
if (Constant *c = dyn_cast<Constant>(vinfo.V)) {
jl_value_t *s = static_constant_instance(jl_Module->getDataLayout(), c, jt);
if (s) {
s = jl_ensure_rooted(ctx, s);
JL_GC_PUSH1(&s);
jl_temporary_root(ctx, s);
JL_GC_POP();
return track_pjlvalue(ctx, literal_pointer_val(ctx, s));
}
}
Expand Down
46 changes: 19 additions & 27 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3410,27 +3410,20 @@ static void simple_use_analysis(jl_codectx_t &ctx, jl_value_t *expr)

// ---- Get Element Pointer (GEP) instructions within the GC frame ----

static jl_value_t *jl_ensure_rooted(jl_codectx_t &ctx, jl_value_t *val)
{
if (jl_is_globally_rooted(val))
return val;
jl_method_t *m = ctx.linfo->def.method;
if (!jl_options.strip_ir && jl_is_method(m)) {
// the method might have a root for this already; use it if so
JL_LOCK(&m->writelock);
if (m->roots) {
size_t i, len = jl_array_dim0(m->roots);
for (i = 0; i < len; i++) {
jl_value_t *mval = jl_array_ptr_ref(m->roots, i);
if (mval == val || jl_egal(mval, val)) {
JL_UNLOCK(&m->writelock);
return mval;
}
}
static void jl_temporary_root(jl_codegen_params_t &ctx, jl_value_t *val)
{
if (!jl_is_globally_rooted(val)) {
jl_array_t *roots = ctx.temporary_roots;
for (size_t i = 0; i < jl_array_dim0(roots); i++) {
if (jl_array_ptr_ref(roots, i) == val)
return;
}
JL_UNLOCK(&m->writelock);
jl_array_ptr_1d_push(roots, val);
}
return jl_as_global_root(val, 1);
}
// Convenience overload for callers that hold a jl_codectx_t: forwards `val`
// to the jl_codegen_params_t overload using `ctx.emission_context`.
static void jl_temporary_root(jl_codectx_t &ctx, jl_value_t *val)
{
jl_temporary_root(ctx.emission_context, val);
}

// --- generating function calls ---
Expand Down Expand Up @@ -5060,7 +5053,7 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
jl_value_t *ty = static_apply_type(ctx, argv, nargs + 1);
if (ty != NULL) {
JL_GC_PUSH1(&ty);
ty = jl_ensure_rooted(ctx, ty);
jl_temporary_root(ctx, ty);
JL_GC_POP();
*ret = mark_julia_const(ctx, ty);
return true;
Expand Down Expand Up @@ -6785,7 +6778,7 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
val = jl_fieldref_noalloc(expr, 0);
// Toplevel exprs are rooted but because codegen assumes this is constant, it removes the write barriers for this code.
// This means we have to globally root the value here. (The other option would be to change how we optimize toplevel code)
val = jl_ensure_rooted(ctx, val);
jl_temporary_root(ctx, val);
return mark_julia_const(ctx, val);
}

Expand Down Expand Up @@ -7905,7 +7898,7 @@ static jl_cgval_t emit_cfunction(jl_codectx_t &ctx, jl_value_t *output_type, con
return jl_cgval_t();
}
if (rt != declrt && rt != (jl_value_t*)jl_any_type)
rt = jl_ensure_rooted(ctx, rt);
jl_temporary_root(ctx, rt);

function_sig_t sig("cfunction", lrt, rt, retboxed, argt, unionall_env, false, CallingConv::C, false, &ctx.emission_context);
assert(sig.fargt.size() + sig.sret == sig.fargt_sig.size());
Expand Down Expand Up @@ -7975,12 +7968,12 @@ static jl_cgval_t emit_cfunction(jl_codectx_t &ctx, jl_value_t *output_type, con
if (closure_types) {
assert(ctx.spvals_ptr);
size_t n = jl_array_nrows(closure_types);
jl_svec_t *fill_i = jl_alloc_svec_uninit(n);
fill = jl_alloc_svec_uninit(n);
for (size_t i = 0; i < n; i++) {
jl_svecset(fill_i, i, jl_array_ptr_ref(closure_types, i));
jl_svecset(fill, i, jl_array_ptr_ref(closure_types, i));
}
JL_GC_PUSH1(&fill_i);
fill = (jl_svec_t*)jl_ensure_rooted(ctx, (jl_value_t*)fill_i);
JL_GC_PUSH1(&fill);
jl_temporary_root(ctx, (jl_value_t*)fill);
JL_GC_POP();
}
Type *T_htable = ArrayType::get(ctx.types().T_size, sizeof(htable_t) / sizeof(void*));
Expand Down Expand Up @@ -10106,7 +10099,6 @@ static jl_llvm_functions_t jl_emit_oc_wrapper(orc::ThreadSafeModule &m, jl_codeg
return declarations;
}


jl_llvm_functions_t jl_emit_codeinst(
orc::ThreadSafeModule &m,
jl_code_instance_t *codeinst,
Expand Down
Loading

0 comments on commit 7354be3

Please sign in to comment.