diff --git a/loki/transformations/data_offload.py b/loki/transformations/data_offload.py index ebd587280..1c662d36c 100644 --- a/loki/transformations/data_offload.py +++ b/loki/transformations/data_offload.py @@ -10,20 +10,23 @@ from loki.analyse import dataflow_analysis_attached from loki.batch import Transformation, ProcedureItem, ModuleItem -from loki.expression import Scalar, Array +from loki.expression import Scalar, Array, symbols as sym from loki.ir import ( FindNodes, PragmaRegion, CallStatement, Pragma, Import, Comment, Transformer, pragma_regions_attached, get_pragma_parameters, FindInlineCalls, SubstituteExpressions ) -from loki.logging import warning +from loki.logging import warning, error from loki.tools import as_tuple, flatten, CaseInsensitiveDict, CaseInsensitiveDefaultDict from loki.types import BasicType, DerivedType - +from loki.transformations.parallel import ( + FieldAPITransferType, field_get_device_data, field_sync_host, remove_field_api_view_updates +) __all__ = [ 'DataOffloadTransformation', 'GlobalVariableAnalysis', - 'GlobalVarOffloadTransformation', 'GlobalVarHoistTransformation' + 'GlobalVarOffloadTransformation', 'GlobalVarHoistTransformation', + 'FieldOffloadTransformation' ] @@ -38,6 +41,8 @@ class DataOffloadTransformation(Transformation): ---------- remove_openmp : bool Remove any existing OpenMP pragmas inside the marked region. + present_on_device : bool + Assume arrays are already offloaded and present on device. assume_deviceptr : bool Mark all offloaded arrays as true device-pointers if data offload is being managed outside of structured OpenACC data regions. 
@@ -49,6 +54,12 @@ def __init__(self, **kwargs): self.has_data_regions = False self.remove_openmp = kwargs.get('remove_openmp', False) self.assume_deviceptr = kwargs.get('assume_deviceptr', False) + self.present_on_device = kwargs.get('present_on_device', False) + + if self.assume_deviceptr and not self.present_on_device: + error("[Loki] Data offload: Can't assume device pointer arrays without arrays being marked" + + "present on device.") + raise RuntimeError def transform_subroutine(self, routine, **kwargs): """ @@ -148,14 +159,22 @@ def insert_data_offload_pragmas(self, routine, targets): outargs = tuple(dict.fromkeys(outargs)) inoutargs = tuple(dict.fromkeys(inoutargs)) - # Now geenerate the pre- and post pragmas (OpenACC) - if self.assume_deviceptr: - offload_args = inargs + outargs + inoutargs - if offload_args: - deviceptr = f' deviceptr({", ".join(offload_args)})' + # Now generate the pre- and post pragmas (OpenACC) + if self.present_on_device: + if self.assume_deviceptr: + offload_args = inargs + outargs + inoutargs + if offload_args: + deviceptr = f' deviceptr({", ".join(offload_args)})' + else: + deviceptr = '' + pragma = Pragma(keyword='acc', content=f'data{deviceptr}') else: - deviceptr = '' - pragma = Pragma(keyword='acc', content=f'data{deviceptr}') + offload_args = inargs + outargs + inoutargs + if offload_args: + present = f' present({", ".join(offload_args)})' + else: + present = '' + pragma = Pragma(keyword='acc', content=f'data{present}') else: copyin = f'copyin({", ".join(inargs)})' if inargs else '' copy = f'copy({", ".join(inoutargs)})' if inoutargs else '' @@ -908,3 +927,228 @@ def _append_routine_arguments(self, routine, item): )) for arg in new_arguments ] routine.arguments += tuple(sorted(new_arguments, key=lambda symbol: symbol.name)) + + +def find_target_calls(region, targets): + """ + Returns a list of all calls to targets inside the region. 
+ + Parameters + ---------- + :region: :any:`PragmaRegion` + :targets: collection of :any:`Subroutine` + Iterable object of subroutines or functions called + :returns: list of :any:`CallStatement` + """ + calls = FindNodes(CallStatement).visit(region) + calls = [c for c in calls if str(c.name).lower() in targets] + return calls + + +class FieldOffloadTransformation(Transformation): + """ + + Transformation to offload arrays owned by Field API fields to the device. **This transformation is IFS specific.** + + The transformation assumes that fields are wrapped in derived types specified in + ``field_group_types`` and will only offload arrays that are members of such derived types. + In the process this transformation removes calls to Field API ``update_view`` and adds + declarations for the device pointers to the driver subroutine. + + The transformation acts on ``!$loki data`` regions and offloads all :any:`Array` + symbols that satisfy the following conditions: + + 1. The array is a member of an object that is of a type specified in ``field_group_types``. + + 2. The array is passed as a parameter to at least one of the kernel targets passed to ``transform_subroutine``. + + Parameters + ---------- + devptr_prefix: str, optional + The prefix of device pointers added by this transformation (defaults to ``'loki_devptr_'``). + field_group_types: list or tuple of str, optional + Names of the field group types with members that may be offloaded (defaults to ``['']``). + offload_index: str, optional + Name of the index variable to inject in the outermost dimension of offloaded arrays in the kernel + calls (defaults to ``'IBL'``). + """ + + class FieldPointerMap: + """ + Helper class to :any:`FieldOffloadTransformation` that is used to store arrays passed to + target kernel calls and the corresponding device pointers added by the transformation. + The pointer/array variable pairs are exposed through the class properties, based on + the intent of the kernel argument. 
+ """ + def __init__(self, devptrs, inargs, inoutargs, outargs): + self.inargs = inargs + self.inoutargs = inoutargs + self.outargs = outargs + self.devptrs = devptrs + + + @property + def in_pairs(self): + """ + Iterator that yields array/pointer pairs for kernel arguments of intent(in). + + Yields + ------ + :any:`Array` + Original kernel call argument + :any:`Array` + Corresponding device pointer added by the transformation. + """ + for i, inarg in enumerate(self.inargs): + yield inarg, self.devptrs[i] + + @property + def inout_pairs(self): + """ + Iterator that yields array/pointer pairs for arguments with intent(inout). + + Yields + ------ + :any:`Array` + Original kernel call argument + :any:`Array` + Corresponding device pointer added by the transformation. + """ + start = len(self.inargs) + for i, inoutarg in enumerate(self.inoutargs): + yield inoutarg, self.devptrs[i+start] + + @property + def out_pairs(self): + """ + Iterator that yields array/pointer pairs for arguments with intent(out). + + Yields + ------ + :any:`Array` + Original kernel call argument + :any:`Array` + Corresponding device pointer added by the transformation. 
+ """ + + start = len(self.inargs)+len(self.inoutargs) + for i, outarg in enumerate(self.outargs): + yield outarg, self.devptrs[i+start] + + + def __init__(self, devptr_prefix=None, field_group_types=None, offload_index=None): + self.deviceptr_prefix = 'loki_devptr_' if devptr_prefix is None else devptr_prefix + field_group_types = [''] if field_group_types is None else field_group_types + self.field_group_types = tuple(typename.lower() for typename in field_group_types) + self.offload_index = 'IBL' if offload_index is None else offload_index + + def transform_subroutine(self, routine, **kwargs): + role = kwargs['role'] + targets = as_tuple(kwargs.get('targets'), (None)) + if role == 'driver': + self.process_driver(routine, targets) + + def process_driver(self, driver, targets): + remove_field_api_view_updates(driver, self.field_group_types) + with pragma_regions_attached(driver): + for region in FindNodes(PragmaRegion).visit(driver.body): + # Only work on active `!$loki data` regions + if not DataOffloadTransformation._is_active_loki_data_region(region, targets): + continue + kernel_calls = find_target_calls(region, targets) + offload_variables = self.find_offload_variables(driver, kernel_calls) + device_ptrs = self._declare_device_ptrs(driver, offload_variables) + offload_map = self.FieldPointerMap(device_ptrs, *offload_variables) + self._add_field_offload_calls(driver, region, offload_map) + self._replace_kernel_args(driver, kernel_calls, offload_map) + + def find_offload_variables(self, driver, calls): + inargs = () + inoutargs = () + outargs = () + + for call in calls: + if call.routine is BasicType.DEFERRED: + error(f'[Loki] Data offload: Routine {driver.name} has not been enriched ' + + f'in {str(call.name).lower()}') + raise RuntimeError + for param, arg in call.arg_iter(): + if not isinstance(param, Array): + continue + try: + parent = arg.parent + if parent.type.dtype.name.lower() not in self.field_group_types: + warning(f'[Loki] Data offload: The parent 
object {parent.name} of type ' + + f'{parent.type.dtype} is not in the list of field wrapper types') + continue + except AttributeError: + warning(f'[Loki] Data offload: Raw array object {arg.name} encountered in' + + f' {driver.name} that is not wrapped by a Field API object') + continue + + if param.type.intent.lower() == 'in': + inargs += (arg, ) + if param.type.intent.lower() == 'inout': + inoutargs += (arg, ) + if param.type.intent.lower() == 'out': + outargs += (arg, ) + + inoutargs += tuple(v for v in inargs if v in outargs) + inargs = tuple(v for v in inargs if v not in inoutargs) + outargs = tuple(v for v in outargs if v not in inoutargs) + + inargs = tuple(set(inargs)) + inoutargs = tuple(set(inoutargs)) + outargs = tuple(set(outargs)) + return inargs, inoutargs, outargs + + + def _declare_device_ptrs(self, driver, offload_variables): + device_ptrs = tuple(self._devptr_from_array(driver, a) for a in chain(*offload_variables)) + driver.variables += device_ptrs + return device_ptrs + + def _devptr_from_array(self, driver, a: sym.Array): + """ + Returns a contiguous pointer :any:`Variable` with types matching the array a + """ + shape = (sym.RangeIndex((None, None)),) * (len(a.shape)+1) + devptr_type = a.type.clone(pointer=True, contiguous=True, shape=shape, intent=None) + base_name = a.name if a.parent is None else '_'.join(a.name.split('%')) + devptr_name = self.deviceptr_prefix + base_name + if devptr_name in driver.variable_map: + warning(f'[Loki] Data offload: The routine {driver.name} already has a ' + + f'variable named {devptr_name}') + devptr = sym.Variable(name=devptr_name, type=devptr_type, dimensions=shape) + return devptr + + def _add_field_offload_calls(self, driver, region, offload_map): + host_to_device = tuple(field_get_device_data(self._get_field_ptr_from_view(inarg), devptr, + FieldAPITransferType.READ_ONLY, driver) for inarg, devptr in offload_map.in_pairs) + host_to_device += 
tuple(field_get_device_data(self._get_field_ptr_from_view(inarg), devptr, + FieldAPITransferType.READ_WRITE, driver) for inarg, devptr in offload_map.inout_pairs) + host_to_device += tuple(field_get_device_data(self._get_field_ptr_from_view(inarg), devptr, + FieldAPITransferType.READ_WRITE, driver) for inarg, devptr in offload_map.out_pairs) + device_to_host = tuple(field_sync_host(self._get_field_ptr_from_view(inarg), driver) + for inarg, _ in chain(offload_map.inout_pairs, offload_map.out_pairs)) + update_map = {region: host_to_device + (region,) + device_to_host} + Transformer(update_map, inplace=True).visit(driver.body) + + def _get_field_ptr_from_view(self, field_view): + type_chain = field_view.name.split('%') + field_type_name = 'F_' + type_chain[-1] + return field_view.parent.get_derived_type_member(field_type_name) + + def _replace_kernel_args(self, driver, kernel_calls, offload_map): + change_map = {} + offload_idx_expr = driver.variable_map[self.offload_index] + for arg, devptr in chain(offload_map.in_pairs, offload_map.inout_pairs, offload_map.out_pairs): + if len(arg.dimensions) != 0: + dims = arg.dimensions + (offload_idx_expr,) + else: + dims = (sym.RangeIndex((None, None)),) * (len(devptr.shape)-1) + (offload_idx_expr,) + change_map[arg] = devptr.clone(dimensions=dims) + + arg_transformer = SubstituteExpressions(change_map, inplace=True) + for call in kernel_calls: + arg_transformer.visit(call) diff --git a/loki/transformations/parallel/field_api.py b/loki/transformations/parallel/field_api.py index e1e06b491..4b5705770 100644 --- a/loki/transformations/parallel/field_api.py +++ b/loki/transformations/parallel/field_api.py @@ -9,16 +9,18 @@ Transformation utilities to manage and inject FIELD-API boilerplate code. 
""" +from enum import Enum from loki.expression import symbols as sym from loki.ir import ( nodes as ir, FindNodes, FindVariables, Transformer ) +from loki.scope import Scope from loki.logging import warning from loki.tools import as_tuple - __all__ = [ - 'remove_field_api_view_updates', 'add_field_api_view_updates' + 'remove_field_api_view_updates', 'add_field_api_view_updates', 'get_field_type', + 'field_get_device_data', 'field_sync_host', 'FieldAPITransferType' ] @@ -150,3 +152,85 @@ def visit_Loop(self, loop, **kwargs): # pylint: disable=unused-argument return loop routine.body = InsertFieldAPIViewsTransformer().visit(routine.body, scope=routine) + + +def get_field_type(a: sym.Array) -> sym.DerivedType: + """ + Returns the corresponding FIELD API type for an array. + + This function is IFS specific and assumes that the + type is an array declared with one of the IFS type specifiers, e.g. KIND=JPRB + """ + type_map = ["jprb", + "jpit", + "jpis", + "jpim", + "jpib", + "jpia", + "jprt", + "jprs", + "jprm", + "jprd", + "jplm"] + type_name = a.type.kind.name + + assert type_name.lower() in type_map, ('Error array type kind is: ' + f'"{type_name}" which is not a valid IFS type specifier') + rank = len(a.shape) + field_type = sym.DerivedType(name="field_" + str(rank) + type_name[2:4].lower()) + return field_type + + + +class FieldAPITransferType(Enum): + READ_ONLY = 1 + READ_WRITE = 2 + WRITE_ONLY = 3 + + +def field_get_device_data(field_ptr, dev_ptr, transfer_type: FieldAPITransferType, scope: Scope): + """ + Utility function to generate a :any:`CallStatement` corresponding to a Field API + ``GET_DEVICE_DATA`` call. + + Parameters + ---------- + field_ptr: pointer to field object + Pointer to the field to call ``GET_DEVICE_DATA`` from. + dev_ptr: :any:`Array` + Device pointer array + transfer_type: :any:`FieldAPITransferType` + Field API transfer type to determine which ``GET_DEVICE_DATA`` method to call. 
+ scope: :any:`Scope` + Scope of the created :any:`CallStatement` + """ + if not isinstance(transfer_type, FieldAPITransferType): + raise TypeError(f"transfer_type must be of type FieldAPITransferType, but is of type {type(transfer_type)}") + if transfer_type == FieldAPITransferType.READ_ONLY: + suffix = 'RDONLY' + elif transfer_type == FieldAPITransferType.READ_WRITE: + suffix = 'RDWR' + elif transfer_type == FieldAPITransferType.WRITE_ONLY: + suffix = 'WRONLY' + else: + suffix = '' + procedure_name = 'GET_DEVICE_DATA_' + suffix + return ir.CallStatement(name=sym.ProcedureSymbol(procedure_name, parent=field_ptr, scope=scope), + arguments=(dev_ptr.clone(dimensions=None),), ) + + +def field_sync_host(field_ptr, scope): + """ + Utility function to generate a :any:`CallStatement` corresponding to a Field API + ``SYNC_HOST`` call. + + Parameters + ---------- + field_ptr: pointer to field object + Pointer to the field to call ``SYNC_HOST`` from. + scope: :any:`Scope` + Scope of the created :any:`CallStatement` + """ + + procedure_name = 'SYNC_HOST_RDWR' + return ir.CallStatement(name=sym.ProcedureSymbol(procedure_name, parent=field_ptr, scope=scope), arguments=()) diff --git a/loki/transformations/parallel/tests/test_field_api.py b/loki/transformations/parallel/tests/test_field_api.py index b752dec85..5e08f9c4c 100644 --- a/loki/transformations/parallel/tests/test_field_api.py +++ b/loki/transformations/parallel/tests/test_field_api.py @@ -10,11 +10,14 @@ from loki import Subroutine, Module, Dimension from loki.frontend import available_frontends, OMNI from loki.ir import nodes as ir, FindNodes -from loki.logging import WARNING - +from loki.expression import symbols as sym +from loki.scope import Scope from loki.transformations.parallel import ( - remove_field_api_view_updates, add_field_api_view_updates + remove_field_api_view_updates, add_field_api_view_updates, get_field_type, + field_get_device_data, FieldAPITransferType ) +from loki.types import BasicType, 
SymbolAttributes +from loki.logging import WARNING @pytest.mark.parametrize('frontend', available_frontends( @@ -136,3 +139,63 @@ def test_field_api_add_view_updates(frontend): assert calls[3].name == 'STATE%UPDATE_VIEW' and calls[3].arguments == ('IBL',) assert len(FindNodes(ir.Loop).visit(routine.body)) == 1 + + +def test_get_field_type(): + type_map = ["jprb", + "jpit", + "jpis", + "jpim", + "jpib", + "jpia", + "jprt", + "jprs", + "jprm", + "jprd", + "jplm"] + field_types = [ + "field_1rb", "field_2rb", "field_3rb", + "field_1it", "field_2it", "field_3it", + "field_1is", "field_2is", "field_3is", + "field_1im", "field_2im", "field_3im", + "field_1ib", "field_2ib", "field_3ib", + "field_1ia", "field_2ia", "field_3ia", + "field_1rt", "field_2rt", "field_3rt", + "field_1rs", "field_2rs", "field_3rs", + "field_1rm", "field_2rm", "field_3rm", + "field_1rd", "field_2rd", "field_3rd", + "field_1lm", "field_2lm", "field_3lm", + ] + + def generate_fields(types): + generated = [] + for type_name in types: + for dim in range(1, 4): + shape = tuple(None for _ in range(dim)) + a = sym.Variable(name='test_array', + type=SymbolAttributes(BasicType.REAL, + shape=shape, + kind=sym.Variable(name=type_name))) + generated.append(get_field_type(a)) + return generated + + generated = generate_fields(type_map) + for field, field_name in zip(generated, field_types): + assert isinstance(field, sym.DerivedType) and field.name == field_name + + generated = generate_fields([t.upper() for t in type_map]) + for field, field_name in zip(generated, field_types): + assert isinstance(field, sym.DerivedType) and field.name == field_name + + +def test_field_get_device_data(): + scope = Scope() + fptr = sym.Variable(name='fptr_var') + dev_ptr = sym.Variable(name='data_var') + for fttype in FieldAPITransferType: + get_dev_data_call = field_get_device_data(fptr, dev_ptr, fttype, scope) + assert isinstance(get_dev_data_call, ir.CallStatement) + assert get_dev_data_call.name.parent == fptr + with 
pytest.raises(TypeError): + _ = field_get_device_data(fptr, dev_ptr, "none_transfer_type", scope) + diff --git a/loki/transformations/tests/test_data_offload.py b/loki/transformations/tests/test_data_offload.py index 81dad4dee..c6cb5cf11 100644 --- a/loki/transformations/tests/test_data_offload.py +++ b/loki/transformations/tests/test_data_offload.py @@ -9,24 +9,25 @@ import pytest from loki import ( - Sourcefile, Scheduler, FindInlineCalls + Sourcefile, Scheduler, FindInlineCalls, warning ) from loki.frontend import available_frontends, OMNI +from loki.logging import log_levels, logger from loki.ir import ( FindNodes, Pragma, PragmaRegion, Loop, CallStatement, Import, pragma_regions_attached, get_pragma_parameters ) +import loki.expression.symbols as sym +from loki.module import Module from loki.transformations import ( DataOffloadTransformation, GlobalVariableAnalysis, - GlobalVarOffloadTransformation, GlobalVarHoistTransformation + GlobalVarOffloadTransformation, GlobalVarHoistTransformation, FieldOffloadTransformation ) - @pytest.fixture(scope='module', name='here') def fixture_here(): return Path(__file__).parent - @pytest.fixture(name='config') def fixture_config(): """ @@ -42,10 +43,51 @@ def fixture_config(): }, } +@pytest.fixture(name="parkind_mod") +def fixture_parkind_mod(tmp_path, frontend): + fcode = """ + module parkind1 + integer, parameter :: jprb=4 + end module + """ + return Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) + +@pytest.fixture(name="field_module") +def fixture_field_module(tmp_path, frontend): + fcode = """ + module field_module + implicit none + + type field_2rb + real, pointer :: f_ptr(:,:,:) + end type field_2rb + + type field_3rb + real, pointer :: f_ptr(:,:,:) + contains + procedure :: update_view + end type field_3rb + + type field_4rb + real, pointer :: f_ptr(:,:,:) + contains + procedure :: update_view + end type field_4rb + + contains + subroutine update_view(self, idx) + class(field_3rb), intent(in) :: self + 
integer, intent(in) :: idx + end subroutine + end module + """ + return Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) + @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('assume_deviceptr', [True, False]) -def test_data_offload_region_openacc(frontend, assume_deviceptr): +@pytest.mark.parametrize('present_on_device', [True, False]) +def test_data_offload_region_openacc(caplog, frontend, assume_deviceptr, present_on_device): """ Test the creation of a simple device data offload region (`!$acc update`) from a `!$loki data` region with a single @@ -84,9 +126,21 @@ def test_data_offload_region_openacc(frontend, assume_deviceptr): driver = Sourcefile.from_source(fcode_driver, frontend=frontend)['driver_routine'] kernel = Sourcefile.from_source(fcode_kernel, frontend=frontend)['kernel_routine'] driver.enrich(kernel) - - driver.apply(DataOffloadTransformation(assume_deviceptr=assume_deviceptr), role='driver', - targets=['kernel_routine']) + + if assume_deviceptr and not present_on_device: + caplog.clear() + with caplog.at_level(log_levels['ERROR']): + with pytest.raises(RuntimeError): + data_offload_trafo = DataOffloadTransformation(assume_deviceptr=assume_deviceptr, + present_on_device=present_on_device) + assert len(caplog.records) == 1 + assert ("[Loki] Data offload: Can't assume device pointer arrays without arrays being marked" + + "present on device.") in caplog.records[0].message + return + + data_offload_trafo = DataOffloadTransformation(assume_deviceptr=assume_deviceptr, + present_on_device=present_on_device) + driver.apply(data_offload_trafo, role='driver', targets=['kernel_routine']) pragmas = FindNodes(Pragma).visit(driver.body) assert len(pragmas) == 2 @@ -95,6 +149,10 @@ def test_data_offload_region_openacc(frontend, assume_deviceptr): assert 'deviceptr' in pragmas[0].content params = get_pragma_parameters(pragmas[0], only_loki_pragmas=False) assert all(var in params['deviceptr'] for var in ('a', 'b', 'c')) 
+ elif present_on_device: + assert 'present' in pragmas[0].content + params = get_pragma_parameters(pragmas[0], only_loki_pragmas=False) + assert all(var in params['present'] for var in ('a', 'b', 'c')) else: transformed = driver.to_fortran() assert 'copyin( a )' in transformed @@ -805,3 +863,528 @@ def test_transformation_global_var_derived_type_hoist(here, config, frontend, ho assert kernel.variable_map['p'].type.dtype.name == 'point' assert kernel.variable_map['p0'].type.intent == 'in' assert kernel.variable_map['p0'].type.dtype.name == 'point' + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_field_offload(frontend, parkind_mod, field_module, tmp_path): + fcode = """ + module driver_mod + use parkind1, only: jprb + use field_module, only: field_2rb, field_3rb + implicit none + + type state_type + real(kind=jprb), dimension(10,10), pointer :: a, b, c + class(field_3rb), pointer :: f_a, f_b, f_c + contains + procedure :: update_view => state_update_view + end type state_type + + contains + + subroutine state_update_view(self, idx) + class(state_type), intent(in) :: self + integer, intent(in) :: idx + end subroutine + + subroutine kernel_routine(nlon, nlev, a, b, c) + integer, intent(in) :: nlon, nlev + real(kind=jprb), intent(in) :: a(nlon,nlev) + real(kind=jprb), intent(inout) :: b(nlon,nlev) + real(kind=jprb), intent(out) :: c(nlon,nlev) + integer :: i, j + + do j=1, nlon + do i=1, nlev + b(i,j) = a(i,j) + 0.1 + c(i,j) = 0.1 + end do + end do + end subroutine kernel_routine + + subroutine driver_routine(nlon, nlev, state) + integer, intent(in) :: nlon, nlev + type(state_type), intent(inout) :: state + integer :: i + + !$loki data + do i=1,nlev + call state%update_view(i) + call kernel_routine(nlon, nlev, state%a, state%b, state%c) + end do + !$loki end data + + end subroutine driver_routine + end module driver_mod + """ + driver_mod = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])['driver_mod'] + driver = 
driver_mod['driver_routine'] + deviceptr_prefix = 'loki_devptr_prefix_' + driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix, + offload_index='i', + field_group_types=['state_type']), + role='driver', + targets=['kernel_routine']) + + calls = FindNodes(CallStatement).visit(driver.body) + kernel_call = next(c for c in calls if c.name=='kernel_routine') + + # verify that field offloads are generated properly + in_calls = [c for c in calls if 'get_device_data_rdonly' in c.name.name.lower()] + assert len(in_calls) == 1 + inout_calls = [c for c in calls if 'get_device_data_rdwr' in c.name.name.lower()] + assert len(inout_calls) == 2 + # verify that field sync host calls are generated properly + sync_calls = [c for c in calls if 'sync_host_rdwr' in c.name.name.lower()] + assert len(sync_calls) == 2 + + # verify that data offload pragmas remain + pragmas = FindNodes(Pragma).visit(driver.body) + assert len(pragmas) == 2 + assert all(p.keyword=='loki' and p.content==c for p, c in zip(pragmas, ['data', 'end data'])) + + # verify that new pointer variables are created and used in driver calls + for var in ['state_a', 'state_b', 'state_c']: + name = deviceptr_prefix + var + assert name in driver.variable_map + devptr = driver.variable_map[name] + assert isinstance(devptr, sym.Array) + assert len(devptr.shape) == 3 + assert devptr.name in (arg.name for arg in kernel_call.arguments) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_field_offload_slices(frontend, parkind_mod, field_module, tmp_path): + fcode = """ + module driver_mod + use parkind1, only: jprb + use field_module, only: field_4rb + implicit none + + type state_type + real(kind=jprb), dimension(10,10,10), pointer :: a, b, c, d + class(field_4rb), pointer :: f_a, f_b, f_c, f_d + contains + procedure :: update_view => state_update_view + end type state_type + + contains + + subroutine state_update_view(self, idx) + class(state_type), intent(in) :: self + integer, 
intent(in) :: idx + end subroutine + + subroutine kernel_routine(nlon, nlev, a, b, c, d) + integer, intent(in) :: nlon, nlev + real(kind=jprb), intent(in) :: a(nlon,nlev,nlon) + real(kind=jprb), intent(inout) :: b(nlon,nlev) + real(kind=jprb), intent(out) :: c(nlon) + real(kind=jprb), intent(in) :: d(nlon,nlev,nlon) + integer :: i, j + end subroutine kernel_routine + + subroutine driver_routine(nlon, nlev, state) + integer, intent(in) :: nlon, nlev + type(state_type), intent(inout) :: state + integer :: i + !$loki data + do i=1,nlev + call kernel_routine(nlon, nlev, state%a(:,:,1), state%b(:,1,1), state%c(1,1,1), state%d) + end do + !$loki end data + + end subroutine driver_routine + end module driver_mod + """ + driver_mod = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])['driver_mod'] + driver = driver_mod['driver_routine'] + deviceptr_prefix = 'loki_devptr_prefix_' + driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix, + offload_index='i', + field_group_types=['state_type']), + role='driver', + targets=['kernel_routine']) + + calls = FindNodes(CallStatement).visit(driver.body) + kernel_call = next(c for c in calls if c.name=='kernel_routine') + # verify that new pointer variables are created and used in driver calls + for var, rank in zip(['state_d', 'state_a', 'state_b', 'state_c',], [4, 3, 2, 1]): + name = deviceptr_prefix + var + assert name in driver.variable_map + devptr = driver.variable_map[name] + assert isinstance(devptr, sym.Array) + assert len(devptr.shape) == 4 + assert devptr.name in (arg.name for arg in kernel_call.arguments) + arg = next(arg for arg in kernel_call.arguments if devptr.name in arg.name) + assert arg.dimensions == ((sym.RangeIndex((None,None)),)*(rank-1) + + (sym.IntLiteral(1),)*(4-rank) + + (sym.Scalar(name='i'),)) + + + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_field_offload_multiple_calls(frontend, parkind_mod, field_module, tmp_path): + fcode = """ + module 
driver_mod + use parkind1, only: jprb + use field_module, only: field_2rb, field_3rb + implicit none + + type state_type + real(kind=jprb), dimension(10,10), pointer :: a, b, c + class(field_3rb), pointer :: f_a, f_b, f_c + contains + procedure :: update_view => state_update_view + end type state_type + + contains + + subroutine state_update_view(self, idx) + class(state_type), intent(in) :: self + integer, intent(in) :: idx + end subroutine + + subroutine kernel_routine(nlon, nlev, a, b, c) + integer, intent(in) :: nlon, nlev + real(kind=jprb), intent(in) :: a(nlon,nlev) + real(kind=jprb), intent(inout) :: b(nlon,nlev) + real(kind=jprb), intent(out) :: c(nlon,nlev) + integer :: i, j + + do j=1, nlon + do i=1, nlev + b(i,j) = a(i,j) + 0.1 + c(i,j) = 0.1 + end do + end do + end subroutine kernel_routine + + subroutine driver_routine(nlon, nlev, state) + integer, intent(in) :: nlon, nlev + type(state_type), intent(inout) :: state + integer :: i + + !$loki data + do i=1,nlev + call state%update_view(i) + + call kernel_routine(nlon, nlev, state%a, state%b, state%c) + + call kernel_routine(nlon, nlev, state%a, state%b, state%c) + end do + !$loki end data + + end subroutine driver_routine + end module driver_mod + """ + + driver_mod = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])['driver_mod'] + driver = driver_mod['driver_routine'] + deviceptr_prefix = 'loki_devptr_prefix_' + driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix, + offload_index='i', + field_group_types=['state_type']), + role='driver', + targets=['kernel_routine']) + calls = FindNodes(CallStatement).visit(driver.body) + kernel_calls = [c for c in calls if c.name=='kernel_routine'] + + # verify that field offloads are generated properly + in_calls = [c for c in calls if 'get_device_data_rdonly' in c.name.name.lower()] + assert len(in_calls) == 1 + inout_calls = [c for c in calls if 'get_device_data_rdwr' in c.name.name.lower()] + assert len(inout_calls) == 2 + # 
verify that field sync host calls are generated properly + sync_calls = [c for c in calls if 'sync_host_rdwr' in c.name.name.lower()] + assert len(sync_calls) == 2 + + # verify that data offload pragmas remain + pragmas = FindNodes(Pragma).visit(driver.body) + assert len(pragmas) == 2 + assert all(p.keyword=='loki' and p.content==c for p, c in zip(pragmas, ['data', 'end data'])) + + # verify that new pointer variables are created and used in driver calls + for var in ['state_a', 'state_b', 'state_c']: + name = deviceptr_prefix + var + assert name in driver.variable_map + devptr = driver.variable_map[name] + assert isinstance(devptr, sym.Array) + assert len(devptr.shape) == 3 + assert devptr.name in (arg.name for kernel_call in kernel_calls for arg in kernel_call.arguments) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_field_offload_no_targets(frontend, parkind_mod, field_module, tmp_path): + fother = """ + module another_module + implicit none + contains + subroutine another_kernel(nlon, nlev, a, b, c) + integer, intent(in) :: nlon, nlev + real, intent(in) :: a(nlon,nlev) + real, intent(inout) :: b(nlon,nlev) + real, intent(out) :: c(nlon,nlev) + integer :: i, j + end subroutine + end module + """ + fcode = """ + module driver_mod + use parkind1, only: jprb + use field_module, only: field_2rb, field_3rb + use another_module, only: another_kernel + + implicit none + + type state_type + real(kind=jprb), dimension(10,10), pointer :: a, b, c + class(field_3rb), pointer :: f_a, f_b, f_c + contains + procedure :: update_view => state_update_view + end type state_type + + contains + + subroutine state_update_view(self, idx) + class(state_type), intent(in) :: self + integer, intent(in) :: idx + end subroutine + + subroutine kernel_routine(nlon, nlev, a, b, c) + integer, intent(in) :: nlon, nlev + real(kind=jprb), intent(in) :: a(nlon,nlev) + real(kind=jprb), intent(inout) :: b(nlon,nlev) + real(kind=jprb), intent(out) :: c(nlon,nlev) + integer 
:: i, j + + do j=1, nlon + do i=1, nlev + b(i,j) = a(i,j) + 0.1 + c(i,j) = 0.1 + end do + end do + end subroutine kernel_routine + + subroutine driver_routine(nlon, nlev, state) + integer, intent(in) :: nlon, nlev + type(state_type), intent(inout) :: state + integer :: i + + !$loki data + do i=1,nlev + call state%update_view(i) + call another_kernel(nlon, state%a, state%b, state%c) + end do + !$loki end data + + end subroutine driver_routine + end module driver_mod + """ + + Sourcefile.from_source(fother, frontend=frontend, xmods=[tmp_path]) + driver_mod = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])['driver_mod'] + driver = driver_mod['driver_routine'] + deviceptr_prefix = 'loki_devptr_prefix_' + driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix, + offload_index='i', + field_group_types=['state_type']), + role='driver', + targets=['kernel_routine']) + + calls = FindNodes(CallStatement).visit(driver.body) + assert not any(c for c in calls if c.name=='kernel_routine') + + # verify that no field offloads are generated + in_calls = [c for c in calls if 'get_device_data_rdonly' in c.name.name.lower()] + assert len(in_calls) == 0 + inout_calls = [c for c in calls if 'get_device_data_rdwr' in c.name.name.lower()] + assert len(inout_calls) == 0 + # verify that no field sync host calls are generated + sync_calls = [c for c in calls if 'sync_host_rdwr' in c.name.name.lower()] + assert len(sync_calls) == 0 + + # verify that data offload pragmas remain + pragmas = FindNodes(Pragma).visit(driver.body) + assert len(pragmas) == 2 + assert all(p.keyword=='loki' and p.content==c for p, c in zip(pragmas, ['data', 'end data'])) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_field_offload_unknown_kernel(caplog, frontend, parkind_mod, field_module, tmp_path): + fother = """ + module another_module + implicit none + contains + subroutine another_kernel(nlon, nlev, a, b, c) + integer, intent(in) :: nlon, nlev + real, 
intent(in) :: a(nlon,nlev) + real, intent(inout) :: b(nlon,nlev) + real, intent(out) :: c(nlon,nlev) + integer :: i, j + end subroutine + end module + """ + fcode = """ + module driver_mod + use parkind1, only: jprb + use another_module, only: another_kernel + implicit none + + type state_type + real(kind=jprb), dimension(10,10), pointer :: a, b, c + class(field_3rb), pointer :: f_a, f_b, f_c + contains + procedure :: update_view => state_update_view + end type state_type + + contains + + subroutine state_update_view(self, idx) + class(state_type), intent(in) :: self + integer, intent(in) :: idx + end subroutine + + subroutine driver_routine(nlon, nlev, state) + integer, intent(in) :: nlon, nlev + type(state_type), intent(inout) :: state + integer :: i + + !$loki data + do i=1,nlev + call state%update_view(i) + call another_kernel(nlon, nlev, state%a, state%b, state%c) + end do + !$loki end data + + end subroutine driver_routine + end module driver_mod + """ + + Sourcefile.from_source(fother, frontend=frontend, xmods=[tmp_path]) + driver_mod = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])['driver_mod'] + driver = driver_mod['driver_routine'] + deviceptr_prefix = 'loki_devptr_prefix_' + + field_offload_trafo = FieldOffloadTransformation(devptr_prefix=deviceptr_prefix, + offload_index='i', + field_group_types=['state_type']) + caplog.clear() + with caplog.at_level(log_levels['ERROR']): + with pytest.raises(RuntimeError): + driver.apply(field_offload_trafo, role='driver', targets=['another_kernel']) + assert len(caplog.records) == 1 + assert ('[Loki] Data offload: Routine driver_routine has not been enriched '+ + 'in another_kernel') in caplog.records[0].message + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_field_offload_warnings(caplog, frontend, parkind_mod, field_module, tmp_path): + fother_state = """ + module state_type_mod + implicit none + type state_type2 + real, dimension(10,10), pointer :: a, b, c + contains 
+ procedure :: update_view => state_update_view + end type state_type2 + + contains + + subroutine state_update_view(self, idx) + class(state_type2), intent(in) :: self + integer, intent(in) :: idx + end subroutine + end module + """ + fother_mod= """ + module another_module + implicit none + contains + subroutine another_kernel(nlon, nlev, a, b, c) + integer, intent(in) :: nlon, nlev + real, intent(in) :: a(nlon,nlev) + real, intent(inout) :: b(nlon,nlev) + real, intent(out) :: c(nlon,nlev) + integer :: i, j + end subroutine + end module + """ + fcode = """ + module driver_mod + use state_type_mod, only: state_type2 + use parkind1, only: jprb + use field_module, only: field_2rb, field_3rb + use another_module, only: another_kernel + + implicit none + + type state_type + real(kind=jprb), dimension(10,10), pointer :: a, b, c + class(field_3rb), pointer :: f_a, f_b, f_c + contains + procedure :: update_view => state_update_view + end type state_type + + contains + + subroutine state_update_view(self, idx) + class(state_type), intent(in) :: self + integer, intent(in) :: idx + end subroutine + + subroutine kernel_routine(nlon, nlev, a, b, c) + integer, intent(in) :: nlon, nlev + real(kind=jprb), intent(in) :: a(nlon,nlev) + real(kind=jprb), intent(inout) :: b(nlon,nlev) + real(kind=jprb), intent(out) :: c(nlon,nlev) + integer :: i, j + + do j=1, nlon + do i=1, nlev + b(i,j) = a(i,j) + 0.1 + c(i,j) = 0.1 + end do + end do + end subroutine kernel_routine + + subroutine driver_routine(nlon, nlev, state, state2) + integer, intent(in) :: nlon, nlev + type(state_type), intent(inout) :: state + type(state_type2), intent(inout) :: state2 + + integer :: i + real(kind=jprb) :: a(nlon,nlev) + real, pointer :: loki_devptr_prefix_state_b + + !$loki data + do i=1,nlev + call state%update_view(i) + call kernel_routine(nlon, nlev, a, state%b, state2%c) + end do + !$loki end data + + end subroutine driver_routine + end module driver_mod + """ + Sourcefile.from_source(fother_state, 
frontend=frontend, xmods=[tmp_path]) + Sourcefile.from_source(fother_mod, frontend=frontend, xmods=[tmp_path]) + driver_mod = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])['driver_mod'] + driver = driver_mod['driver_routine'] + deviceptr_prefix = 'loki_devptr_prefix_' + + field_offload_trafo = FieldOffloadTransformation(devptr_prefix=deviceptr_prefix, + offload_index='i', + field_group_types=['state_type']) + caplog.clear() + with caplog.at_level(log_levels['WARNING']): + driver.apply(field_offload_trafo, role='driver', targets=['kernel_routine']) + assert len(caplog.records) == 3 + assert (('[Loki] Data offload: Raw array object a encountered in' + +' driver_routine that is not wrapped by a Field API object') + in caplog.records[0].message) + assert ('[Loki] Data offload: The parent object state2 of type state_type2 is not in the' + + ' list of field wrapper types') in caplog.records[1].message + assert ('[Loki] Data offload: The routine driver_routine already has a' + + ' variable named loki_devptr_prefix_state_b') in caplog.records[2].message