Skip to content

Commit

Permalink
INLINE: remove unused interfaces after inlining
Browse files Browse the repository at this point in the history
  • Loading branch information
awnawab committed May 28, 2024
1 parent 2f78e0a commit 067b768
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 6 deletions.
25 changes: 19 additions & 6 deletions loki/transformations/inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from loki.batch import Transformation
from loki.ir import (
Import, Comment, Assignment, VariableDeclaration, CallStatement,
Transformer, FindNodes, pragmas_attached, is_loki_pragma
Transformer, FindNodes, pragmas_attached, is_loki_pragma, Interface
)
from loki.expression import (
symbols as sym, FindVariables, FindInlineCalls, FindLiterals,
Expand Down Expand Up @@ -194,8 +194,8 @@ def map_inline_call(self, expr, *args, **kwargs):

def resolve_sequence_association_for_inlined_calls(routine, inline_internals, inline_marked):
"""
Resolve sequence association in calls to all member procedures (if `inline_internals = True`)
or in calls to procedures that have been marked with an inline pragma (if `inline_marked = True`).
Resolve sequence association in calls to all member procedures (if `inline_internals = True`)
or in calls to procedures that have been marked with an inline pragma (if `inline_marked = True`).
If both `inline_internals` and `inline_marked` are `False`, no processing is done.
"""
call_map = {}
Expand All @@ -211,9 +211,9 @@ def resolve_sequence_association_for_inlined_calls(routine, inline_internals, in
# asked sequence assoc to happen with inlining, so source for routine should be
# found in calls to be inlined.
raise ValueError(
f"Cannot resolve sequence association for call to `{call.name}` " +
f"to be inlined in routine `{routine.name}`, because " +
f"the `CallStatement` referring to `{call.name}` does not contain " +
f"Cannot resolve sequence association for call to `{call.name}` " +
f"to be inlined in routine `{routine.name}`, because " +
f"the `CallStatement` referring to `{call.name}` does not contain " +
"the source code of the procedure. " +
"If running in batch processing mode, please recheck Scheduler configuration."
)
Expand Down Expand Up @@ -606,6 +606,19 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
# Remove import if no further symbols used, otherwise clone with new symbols
import_map[impt] = impt.clone(symbols=new_symbols) if new_symbols else None

# Remove explicit interfaces of inlined routines
for intf in FindNodes(Interface).visit(routine.ir):
if not intf.spec:
_body = []
for b in intf.body:
s = getattr(b, 'procedure_symbol', None)
if not s or (s.name not in callees or s.name in not_inlined):
_body += [b,]
if _body:
import_map[intf] = intf.clone(body=as_tuple(_body))
else:
import_map[intf] = None

# Now move any callee imports we might need over to the caller
new_imports = set()
imported_module_map = CaseInsensitiveDict((im.module, im) for im in routine.imports)
Expand Down
104 changes: 104 additions & 0 deletions loki/transformations/tests/test_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI, OFP
from loki.ir import nodes as ir, FindNodes
from loki.tools import flatten

from loki.transformations.inline import (
inline_elemental_functions, inline_constant_parameters,
Expand Down Expand Up @@ -786,6 +787,109 @@ def test_inline_marked_subroutines(frontend, adjust_imports):
assert imports[0].symbols == ('add_one', 'add_a_to_b')


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_marked_subroutines_with_interfaces(frontend):
""" Test inlining of subroutines with explicit interfaces via marker pragmas. """

fcode_driver = """
subroutine test_pragma_inline(a, b)
implicit none
interface
subroutine add_a_to_b(a, b, n)
real(kind=8), intent(inout) :: a(:), b(:)
integer, intent(in) :: n
end subroutine add_a_to_b
subroutine add_one(a)
real(kind=8), intent(inout) :: a
end subroutine add_one
end interface
interface
subroutine add_two(a)
real(kind=8), intent(inout) :: a
end subroutine add_two
end interface
real(kind=8), intent(inout) :: a(3), b(3)
integer, parameter :: n = 3
integer :: i
do i=1, n
!$loki inline
call add_one(a(i))
end do
!$loki inline
call add_a_to_b(a(:), b(:), 3)
do i=1, n
call add_one(b(i))
!$loki inline
call add_two(b(i))
end do
end subroutine test_pragma_inline
"""

fcode_module = """
module util_mod
implicit none
contains
subroutine add_one(a)
real(kind=8), intent(inout) :: a
a = a + 1
end subroutine add_one
subroutine add_two(a)
real(kind=8), intent(inout) :: a
a = a + 2
end subroutine add_two
subroutine add_a_to_b(a, b, n)
real(kind=8), intent(inout) :: a(:), b(:)
integer, intent(in) :: n
integer :: i
do i = 1, n
a(i) = a(i) + b(i)
end do
end subroutine add_a_to_b
end module util_mod
"""

module = Module.from_source(fcode_module, frontend=frontend)
driver = Subroutine.from_source(fcode_driver, frontend=frontend)
driver.enrich(module.subroutines)

calls = FindNodes(ir.CallStatement).visit(driver.body)
assert calls[0].routine == module['add_one']
assert calls[1].routine == module['add_a_to_b']
assert calls[2].routine == module['add_one']
assert calls[3].routine == module['add_two']

inline_marked_subroutines(routine=driver, allowed_aliases=('I',))

# Check inlined loops and assignments
assert len(FindNodes(ir.Loop).visit(driver.body)) == 3
assign = FindNodes(ir.Assignment).visit(driver.body)
assert len(assign) == 3
assert assign[0].lhs == 'a(i)' and assign[0].rhs == 'a(i) + 1'
assert assign[1].lhs == 'a(i)' and assign[1].rhs == 'a(i) + b(i)'
assert assign[2].lhs == 'b(i)' and assign[2].rhs == 'b(i) + 2'

# Check that the last call is left untouched
calls = FindNodes(ir.CallStatement).visit(driver.body)
assert len(calls) == 1
assert calls[0].routine.name == 'add_one'
assert calls[0].arguments == ('b(i)',)

intfs = FindNodes(ir.Interface).visit(driver.spec)
assert len(intfs) == 1
assert intfs[0].symbols == ('add_one',)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('adjust_imports', [True, False])
def test_inline_marked_routine_with_optionals(frontend, adjust_imports):
Expand Down

0 comments on commit 067b768

Please sign in to comment.