Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inline: Fix rescoping of intrinsic procedure symbols in elementals #445

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion loki/expression/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,8 +816,12 @@ def map_variable_symbol(self, expr, *args, **kwargs):
return map_fn(new_expr, *args, **kwargs)

map_deferred_type_symbol = map_variable_symbol
map_procedure_symbol = map_variable_symbol

def map_procedure_symbol(self, expr, *args, **kwargs):
if expr.type and expr.type.is_intrinsic:
# Always rescope intrinsics to the closest scope
return expr.clone(scope=kwargs['scope'])
return self.map_variable_symbol(expr, *args, **kwargs)

class DetachScopesMapper(LokiIdentityMapper):
"""
Expand Down
5 changes: 4 additions & 1 deletion loki/transformations/inline/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from loki.ir import (
Import, Comment, VariableDeclaration, CallStatement, Transformer,
FindNodes, FindVariables, FindInlineCalls, SubstituteExpressions,
pragmas_attached, is_loki_pragma, Interface, Pragma
pragmas_attached, is_loki_pragma, Interface, Pragma, AttachScopes
)
from loki.expression import symbols as sym
from loki.types import BasicType
Expand Down Expand Up @@ -162,6 +162,9 @@ def _map_unbound_dims(var, val):
if is_loki_pragma(pragma, starts_with='routine')}
).visit(callee_body)

# Ensure all symbols are rescoped to the caller
AttachScopes().visit(callee_body, scope=caller)

# Inline substituted body within a pair of marker comments
comment = Comment(f'! [Loki] inlined child subroutine: {callee.name}')
c_line = Comment('! =========================================')
Expand Down
44 changes: 44 additions & 0 deletions loki/transformations/inline/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from loki import Module, Subroutine
from loki.build import jit_compile_lib, Builder, Obj
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import (
nodes as ir, FindNodes, FindVariables, FindInlineCalls
Expand Down Expand Up @@ -405,3 +406,46 @@ def test_inline_statement_functions_inline_call(frontend, provide_myfunc, tmp_pa
# myfunc not inlined
assert assignments[0].rhs == "arr + arr + 1.0 + myfunc(arr) + myfunc(arr)"
assert assignments[1].rhs == "3.0 + 1.0 + myfunc(3.0) + val + 1.0 + myfunc(val)"


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_elemental_functions_intrinsic_procs(frontend):
fcode = """
subroutine test_inline_elementals(a)
implicit none
integer, parameter :: jprb = 8
real(kind=jprb), intent(inout) :: a

a = fminj(0.5, a)
contains
pure elemental function fminj(x,y) result(m)
real(kind=jprb), intent(in) :: x, y
real(kind=jprb) :: m

m = y - 0.5_jprb*(abs(x-y)-(x-y))
end function fminj
end subroutine test_inline_elementals
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 1
assert isinstance(assigns[0].rhs.function, sym.ProcedureSymbol)
assert assigns[0].rhs.function.type.dtype.procedure == routine.members[0]

# Ensure we have an intrinsic in the internal elemental function
inline_calls = tuple(FindInlineCalls().visit(routine.members[0].body))
assert len(inline_calls) == 1
assert inline_calls[0].function.type.is_intrinsic
assert inline_calls[0].function.scope == routine.members[0]

inline_elemental_functions(routine)

assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 2

# Ensure that the intrinsic function has been rescoped
inline_calls = tuple(FindInlineCalls().visit(assigns[0]))
assert len(inline_calls) == 1
assert inline_calls[0].function.type.is_intrinsic
assert inline_calls[0].function.scope == routine
Loading