Skip to content

Commit

Permalink
DEPENDENCY TRAFO: statement functions included via c-style imports pr…
Browse files Browse the repository at this point in the history
…eserved
  • Loading branch information
awnawab committed Mar 26, 2024
1 parent 023b46b commit 86c2f97
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
4 changes: 2 additions & 2 deletions loki/transform/dependency_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def rename_imports(self, source, imports, targets=None):
for im in imports:
if im.c_import:
target_symbol = im.module.split('.')[0].lower()
if targets and target_symbol.lower() in targets:
if targets and target_symbol.lower() in targets and 'intfb' in im.module.lower():
# Modify the the basename of the C-style header import
s = '.'.join(im.module.split('.')[1:])
im._update(module=f'{target_symbol}{self.suffix}.{s}')
Expand Down Expand Up @@ -487,7 +487,7 @@ def _update_item(proc_name, module_name):
for im in imports:
if im.c_import:
target_symbol = im.module.split('.')[0].lower()
if targets and target_symbol.lower() in targets:
if targets and target_symbol.lower() in targets and 'intfb' in im.module.lower():
# Create a new module import with explicitly qualified symbol
modname = f'{target_symbol}{self.module_suffix}'
_update_item(target_symbol.lower(), modname)
Expand Down
10 changes: 8 additions & 2 deletions tests/test_transform_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def test_dependency_transformation_header_includes(here, frontend):
INTEGER, INTENT(INOUT) :: a, b, c
#include "kernel.intfb.h"
#include "kernel.func.h"
CALL kernel(a, b ,c)
END SUBROUTINE driver
Expand Down Expand Up @@ -245,6 +246,9 @@ def test_dependency_transformation_header_includes(here, frontend):
assert '#include "kernel.intfb.h"' not in driver.to_fortran()
assert '#include "kernel_test.intfb.h"' in driver.to_fortran()

# Check that imported function was not modified
assert '#include "kernel.func.h"' in driver.to_fortran()

# Check that header file was generated and clean up
assert header_file.exists()
header_file.unlink()
Expand All @@ -262,6 +266,7 @@ def test_dependency_transformation_module_wrap(frontend, use_scheduler, tempdir,
SUBROUTINE driver(a, b, c)
INTEGER, INTENT(INOUT) :: a, b, c
#include "kernel.func.h"
#include "kernel.intfb.h"
CALL kernel(a, b ,c)
Expand Down Expand Up @@ -320,10 +325,11 @@ def test_dependency_transformation_module_wrap(frontend, use_scheduler, tempdir,
calls = FindNodes(CallStatement).visit(driver['driver'].body)
assert len(calls) == 1
assert calls[0].name == 'kernel_test'
imports = FindNodes(Import).visit(driver['driver'].spec)
assert len(imports) == 1
imports = FindNodes(Import).visit(driver['driver'].ir)
assert len(imports) == 2
assert imports[0].module == 'kernel_test_mod'
assert 'kernel_test' in [str(s) for s in imports[0].symbols]
assert imports[1].module == 'kernel.func.h'


@pytest.mark.parametrize('frontend', available_frontends())
Expand Down

0 comments on commit 86c2f97

Please sign in to comment.