Skip to content

Commit

Permalink
Merge pull request #265 from ecmwf-ifs/naan-globalvar
Browse files Browse the repository at this point in the history
Skip driver routine in `GlobalVariableAnalysis`
  • Loading branch information
reuterbal authored Apr 11, 2024
2 parents e1c385a + f720c11 commit 9aa9a37
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 17 deletions.
40 changes: 34 additions & 6 deletions transformations/tests/test_data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,24 +388,20 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
nval_dim = '1:5'
nfld_data = set()
nval_data = set()
nval_offload = set()
nfld_offload = set()
else:
nfld_dim = 'nfld'
nval_dim = 'nval'
nfld_data = {('nfld', 'global_var_analysis_header_mod')}
nval_data = {('nval', 'global_var_analysis_header_mod')}
nval_offload = {'nval'}
nfld_offload = {'nfld'}

expected_trafo_data = {
'global_var_analysis_header_mod': {
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
'offload': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'} | nval_offload | nfld_offload,
'offload': {}
},
'global_var_analysis_data_mod': {
'declares': {'rdata(:, :, :)', 'tt'},
'offload': {'rdata(:, :, :)', 'tt', 'tt%vals'}
'offload': {}
},
'global_var_analysis_data_mod#some_routine': {'defines_symbols': set(), 'uses_symbols': set()},
'global_var_analysis_kernel_mod#kernel_a': {
Expand Down Expand Up @@ -454,6 +450,14 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
'driver': {'role': 'driver'}
}

# OMNI handles array indices and parameters differently
if frontend == OMNI:
nfld_dim = '1:3'
nval_dim = '1:5'
else:
nfld_dim = 'nfld'
nval_dim = 'nval'

scheduler = Scheduler(
paths=(global_variable_analysis_code,), config=config, seed_routines='driver',
frontend=frontend, xmods=(global_variable_analysis_code,)
Expand All @@ -462,6 +466,30 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
scheduler.process(GlobalVarOffloadTransformation(key=key))
driver = scheduler['#driver'].ir

if key is None:
key = GlobalVariableAnalysis._key

expected_trafo_data = {
'global_var_analysis_header_mod': {
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
'offload': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}
},
'global_var_analysis_data_mod': {
'declares': {'rdata(:, :, :)', 'tt'},
'offload': {'rdata(:, :, :)', 'tt', 'tt%vals'}
},
}

# Verify module offload sets
for item in [scheduler['global_var_analysis_header_mod'], scheduler['global_var_analysis_data_mod']]:
for trafo_data_key, trafo_data_value in item.trafo_data[key].items():
assert (
sorted(
tuple(str(vv) for vv in v) if isinstance(v, tuple) else str(v)
for v in trafo_data_value
) == sorted(expected_trafo_data[item.name][trafo_data_key])
)

# Verify imports have been added to the driver
expected_imports = {
'global_var_analysis_header_mod': {'iarr', 'rarr'},
Expand Down
30 changes: 19 additions & 11 deletions transformations/transformations/data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,17 +323,6 @@ def _map_var_to_module(var):
_map_var_to_module(var) for var in defines_imported_symbols
}

# Propagate offload requirement to the items of the global variables
successors_map = CaseInsensitiveDict(
(item.name, item) for item in successors if isinstance(item, ModuleItem)
)
for var, module in chain(
item.trafo_data[self._key]['uses_symbols'],
item.trafo_data[self._key]['defines_symbols']
):
if successor := successors_map.get(module):
successor.trafo_data[self._key]['offload'].add(var)

# Amend analysis data with data from successors
# Note: This is a temporary workaround for the incomplete list of successor items
# provided by the current scheduler implementation
Expand Down Expand Up @@ -476,9 +465,28 @@ def transform_subroutine(self, routine, **kwargs):
"""
role = kwargs.get('role')
successors = kwargs.get('successors', ())
item = kwargs['item']

if role == 'driver':
self.process_driver(routine, successors)
elif role == 'kernel':
self.process_kernel(item, successors)

def process_kernel(self, item, successors):
"""
Propagate offload requirement to the items of the global variables
"""
successors_map = CaseInsensitiveDict(
(item.name, item) for item in successors if isinstance(item, ModuleItem)
)
for var, module in chain(
item.trafo_data[self._key]['uses_symbols'],
item.trafo_data[self._key]['defines_symbols']
):
if var.type.parameter:
continue
if successor := successors_map.get(module):
successor.trafo_data[self._key]['offload'].add(var)

def process_driver(self, routine, successors):
"""
Expand Down

0 comments on commit 9aa9a37

Please sign in to comment.