Skip to content

Commit

Permalink
Inline: Fix rescoping of intrinsic procedure symbols in elementals
Browse files Browse the repository at this point in the history
This caused problems by a race condition, where the elemental scope
that was the original scope of intrisincs could be re-build, making
the scope weakref invalid before the final catch-all rescoping.

To fix this, I'm explicitly foxing the function body to be rescoped,
before it gets inserted. This was problematic, as intrisic procedure
symbols were not updated correctly, so I enforce indiscriminant
rescoping for intrisic procedure symbols to the given, closest scope
in `AttachScopesMapper`.
  • Loading branch information
mlange05 committed Nov 22, 2024
1 parent 72d35fb commit af96233
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
6 changes: 5 additions & 1 deletion loki/expression/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,8 +816,12 @@ def map_variable_symbol(self, expr, *args, **kwargs):
return map_fn(new_expr, *args, **kwargs)

map_deferred_type_symbol = map_variable_symbol
map_procedure_symbol = map_variable_symbol

def map_procedure_symbol(self, expr, *args, **kwargs):
if expr.type.is_intrinsic:
# Always rescope intrinsics to the closest scope
return expr.clone(scope=kwargs['scope'])
return self.map_variable_symbol(expr, *args, **kwargs)

class DetachScopesMapper(LokiIdentityMapper):
"""
Expand Down
5 changes: 4 additions & 1 deletion loki/transformations/inline/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from loki.ir import (
Import, Comment, VariableDeclaration, CallStatement, Transformer,
FindNodes, FindVariables, FindInlineCalls, SubstituteExpressions,
pragmas_attached, is_loki_pragma, Interface, Pragma
pragmas_attached, is_loki_pragma, Interface, Pragma, AttachScopes
)
from loki.expression import symbols as sym
from loki.types import BasicType
Expand Down Expand Up @@ -162,6 +162,9 @@ def _map_unbound_dims(var, val):
if is_loki_pragma(pragma, starts_with='routine')}
).visit(callee_body)

# Ensure all symbols are rescoped to the caller
AttachScopes().visit(callee_body, scope=caller)

# Inline substituted body within a pair of marker comments
comment = Comment(f'! [Loki] inlined child subroutine: {callee.name}')
c_line = Comment('! =========================================')
Expand Down
44 changes: 44 additions & 0 deletions loki/transformations/inline/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from loki import Module, Subroutine
from loki.build import jit_compile_lib, Builder, Obj
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import (
nodes as ir, FindNodes, FindVariables, FindInlineCalls
Expand Down Expand Up @@ -405,3 +406,46 @@ def test_inline_statement_functions_inline_call(frontend, provide_myfunc, tmp_pa
# myfunc not inlined
assert assignments[0].rhs == "arr + arr + 1.0 + myfunc(arr) + myfunc(arr)"
assert assignments[1].rhs == "3.0 + 1.0 + myfunc(3.0) + val + 1.0 + myfunc(val)"


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_elemental_functions_intrinsic_procs(frontend):
fcode = """
subroutine test_inline_elementals(a)
implicit none
integer, parameter :: jprb = 8
real(kind=jprb), intent(inout) :: a
a = fminj(0.5, a)
contains
pure elemental function fminj(x,y) result(m)
real(kind=jprb), intent(in) :: x, y
real(kind=jprb) :: m
m = y - 0.5_jprb*(abs(x-y)-(x-y))
end function fminj
end subroutine test_inline_elementals
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 1
assert isinstance(assigns[0].rhs.function, sym.ProcedureSymbol)
assert assigns[0].rhs.function.type.dtype.procedure == routine.members[0]

# Ensure we have an intrinsic in the internal elemental function
inline_calls = tuple(FindInlineCalls().visit(routine.members[0].body))
assert len(inline_calls) == 1
assert inline_calls[0].function.type.is_intrinsic
assert inline_calls[0].function.scope == routine.members[0]

inline_elemental_functions(routine)

assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 2

# Ensure that the intrinsic function has been rescoped
inline_calls = tuple(FindInlineCalls().visit(assigns[0]))
assert len(inline_calls) == 1
assert inline_calls[0].function.type.is_intrinsic
assert inline_calls[0].function.scope == routine

0 comments on commit af96233

Please sign in to comment.