Skip to content

Commit

Permalink
PR fixes to hoist/inline trafos
Browse files Browse the repository at this point in the history
  • Loading branch information
awnawab committed Jun 12, 2024
1 parent 9836069 commit 78d9f45
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
6 changes: 4 additions & 2 deletions loki/transformations/hoist_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,9 @@ def transform_subroutine(self, routine, **kwargs):
variables = self.find_variables(routine)
item.trafo_data[self._key]["to_hoist"] = variables
dims = flatten([getattr(v, 'shape', []) for v in variables])
import_map = routine.import_map
item.trafo_data[self._key]["imported_sizes"] = [(d.type.module, d) for d in dims
if str(d) in routine.import_map]
if str(d) in import_map]
item.trafo_data[self._key]["hoist_variables"] = [var.clone(name=f'{routine.name}_{var.name}')
for var in variables]
else:
Expand Down Expand Up @@ -281,8 +282,9 @@ def transform_subroutine(self, routine, **kwargs):

# Add imports used to define hoisted
missing_imports_map = defaultdict(set)
import_map = routine.import_map
for module, var in item.trafo_data[self._key]["imported_sizes"]:
if not var.name in routine.import_map:
if not var.name in import_map:
missing_imports_map[module] |= {var}

if missing_imports_map:
Expand Down
11 changes: 5 additions & 6 deletions loki/transformations/inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,9 @@ def map_inline_call(self, expr, *args, **kwargs):

def resolve_sequence_association_for_inlined_calls(routine, inline_internals, inline_marked):
"""
Resolve sequence association in calls to all member procedures (if `inline_internals = True`)
or in calls to procedures that have been marked with an inline pragma (if `inline_marked = True`).
If both `inline_internals` and `inline_marked` are `False`, no processing is done.
Resolve sequence association in calls to all member procedures (if ``inline_internals = True``)
or in calls to procedures that have been marked with an inline pragma (if ``inline_marked = True``).
If both ``inline_internals`` and ``inline_marked`` are ``False``, no processing is done.
"""
call_map = {}
with pragmas_attached(routine, node_type=CallStatement):
Expand Down Expand Up @@ -636,10 +636,9 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
intf_symbols = routine.interface_symbols
for callee in call_sets.keys():
for intf in callee.interfaces:
for b in intf.body:
s = getattr(b, 'procedure_symbol', None)
for s in intf.symbols:
if not s in intf_symbols:
new_intfs += [b,]
new_intfs += [s.type.dtype.procedure,]

if new_intfs:
routine.spec.append(Interface(body=as_tuple(new_intfs)))
Expand Down

0 comments on commit 78d9f45

Please sign in to comment.