Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CMake plan and enrichment bugfix #441

Merged
merged 7 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions loki/batch/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,8 +635,15 @@ def write_cmake_plan(self, filepath, mode, buildpath, rootpath):
sources_to_transform = []

# Filter the SGraph to get a pure call-tree
item_filter = None if self.config.enable_imports else ProcedureItem
for item in SFilter(self.sgraph, item_filter=item_filter):
item_filter = ProcedureItem
if self.config.enable_imports:
item_filter = as_tuple(item_filter) + (ModuleItem,)
graph = self.sgraph.as_filegraph(
self.item_factory, self.config, item_filter=item_filter,
exclude_ignored=True
)
traversal = SFilter(graph, reverse=False, include_external=False)
for item in traversal:
if item.is_ignored:
continue

Expand Down
15 changes: 13 additions & 2 deletions loki/batch/sgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,21 @@ def _populate_filegraph(self, sgraph, item_factory, config=None, item_filter=Non
if file_item not in self._graph:
self.add_node(file_item)

# Update the "is_ignored" attribute for file items
# Update the "is_ignored" and "replicate" attributes for file items
for items in file_item_2_item_map.values():
file_item = item_2_file_item_map[items[0]]
is_ignored = all(item.is_ignored for item in items)
item_2_file_item_map[items[0]].config['is_ignored'] = is_ignored
file_item.config['is_ignored'] = is_ignored

replicate = any(item.replicate for item in items)
if replicate:
non_replicate_items = [item for item in items if not item.replicate]
if non_replicate_items:
warning((
f'File {file_item.name} will be replicated but contains items '
f'that are marked as non-replicated: {", ".join(item.name for item in non_replicate_items)}'
))
file_item.config['replicate'] = replicate

# Insert edges to the file items corresponding to the successors of the items
for item in SFilter(sgraph, item_filter, exclude_ignored=exclude_ignored):
Expand Down
104 changes: 103 additions & 1 deletion loki/batch/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@
from loki.frontend import (
available_frontends, OMNI, FP, REGEX, HAVE_FP, HAVE_OMNI
)
from loki.ir import nodes as ir, FindNodes, FindInlineCalls
from loki.ir import (
nodes as ir, FindNodes, FindInlineCalls, FindVariables
)
from loki.transformations import (
DependencyTransformation, ModuleWrapTransformation
)
Expand Down Expand Up @@ -2920,3 +2922,103 @@ def test_pipeline_config_compose(config):
assert pipeline.transformations[2].directive == 'openacc'
assert pipeline.transformations[3].trim_vector_sections is True
assert pipeline.transformations[7].replace_ignore_items is True

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('enable_imports', [False, True])
@pytest.mark.parametrize('import_level', ['module', 'subroutine'])
def test_scheduler_indirect_import(frontend, tmp_path, enable_imports, import_level):
fcode_mod_a = """
module a_mod
implicit none
public
integer :: global_a = 1
end module a_mod
"""

fcode_mod_b = """
module b_mod
use a_mod
implicit none
public
type type_b
integer :: val
end type type_b
end module b_mod
"""

if import_level == 'module':
module_import_stmt = "use b_mod, only: type_b, global_a"
routine_import_stmt = ""
elif import_level == 'subroutine':
module_import_stmt = ""
routine_import_stmt = "use b_mod, only: type_b, global_a"

fcode_mod_c = f"""
module c_mod
{module_import_stmt}
implicit none
contains
subroutine c(b)
{routine_import_stmt}
implicit none
type(type_b), intent(inout) :: b
b%val = global_a
end subroutine c
end module c_mod
"""

# Set-up paths and write sources
src_path = tmp_path/'src'
src_path.mkdir()
out_path = tmp_path/'build'
out_path.mkdir()

(src_path/'a.F90').write_text(fcode_mod_a)
(src_path/'b.F90').write_text(fcode_mod_b)
(src_path/'c.F90').write_text(fcode_mod_c)

# Create the Scheduler
config = SchedulerConfig.from_dict({
'default': {
'role': 'kernel',
'expand': True,
'strict': True,
'enable_imports': enable_imports
},
'routines': {'c': {'role': 'driver'}}
})
try:
scheduler = Scheduler(
paths=[src_path], config=config, frontend=frontend,
output_dir=out_path, xmods=[out_path]
)
except CalledProcessError as e:
if frontend == OMNI and not enable_imports:
# Without taking care of imports, OMNI will fail to parse the files
# because it is missing the xmod files for the header modules
pytest.xfail('Without parsing imports, OMNI does not have the xmod for imported modules')
raise e

# Check for all items in the dependency graph
expected_items = {'a_mod', 'b_mod', 'b_mod#type_b', 'c_mod#c'}
assert expected_items == {item.name for item in scheduler.items}

# Verify the type information for the imported symbols:
# They will have enriched information if the imports are enabled
# and deferred type otherwise
type_b = scheduler['b_mod#type_b'].ir
c_mod_c = scheduler['c_mod#c'].ir
var_map = CaseInsensitiveDict(
(v.name, v) for v in FindVariables().visit(c_mod_c.body)
)
global_a = var_map['global_a']
b_dtype = var_map['b'].type.dtype

if enable_imports:
assert global_a.type.dtype is BasicType.INTEGER
assert global_a.type.initial == '1'
assert b_dtype.typedef is type_b
else:
assert global_a.type.dtype is BasicType.DEFERRED
assert global_a.type.initial is None
assert b_dtype.typedef is BasicType.DEFERRED
11 changes: 7 additions & 4 deletions loki/program_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,17 @@ def enrich(self, definitions, recurse=False):
# Take care of renaming upon import
local_name = symbol.name
remote_name = symbol.type.use_name or local_name
remote_node = module[remote_name]
try:
remote_node = module[remote_name]
except KeyError:
remote_node = None

if hasattr(remote_node, 'procedure_type'):
if remote_node and hasattr(remote_node, 'procedure_type'):
# This is a subroutine/function defined in the remote module
updated_symbol_attrs[local_name] = symbol.type.clone(
dtype=remote_node.procedure_type, imported=True, module=module
)
elif hasattr(remote_node, 'dtype'):
elif remote_node and hasattr(remote_node, 'dtype'):
# This is a derived type defined in the remote module
updated_symbol_attrs[local_name] = symbol.type.clone(
dtype=remote_node.dtype, imported=True, module=module
Expand All @@ -368,7 +371,7 @@ def enrich(self, definitions, recurse=False):
if getattr(type_.dtype, 'name') == remote_node.dtype.name
}
updated_symbol_attrs.update(variables_with_this_type)
elif hasattr(remote_node, 'type'):
elif remote_node and hasattr(remote_node, 'type'):
# This is a global variable or interface import
updated_symbol_attrs[local_name] = remote_node.type.clone(
imported=True, module=module, use_name=symbol.type.use_name
Expand Down
2 changes: 1 addition & 1 deletion loki/transformations/build_system/file_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def item_filter(self):
imports are honoured in the :any:`Scheduler` traversal.
"""
if self.include_module_var_imports:
return (ProcedureItem, ModuleItem) #GlobalVariableItem)
return (ProcedureItem, ModuleItem)
return ProcedureItem

def transform_file(self, sourcefile, **kwargs):
Expand Down
Loading
Loading