From 45c9e2072d7bcea867dd0327d7feaca0b6a152e2 Mon Sep 17 00:00:00 2001 From: Johan Ericsson Date: Wed, 20 Nov 2024 12:08:00 +0100 Subject: [PATCH] Field Offload, added missing support for sliced arguments to calls --- loki/transformations/data_offload.py | 4 +- .../tests/test_data_offload.py | 77 +++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/loki/transformations/data_offload.py b/loki/transformations/data_offload.py index 5f963f25a..fb1b4babd 100644 --- a/loki/transformations/data_offload.py +++ b/loki/transformations/data_offload.py @@ -996,7 +996,7 @@ def in_pairs(self): ______ :any:`Array` Original kernel call argument - :any: `Array` + :any:`Array` Corresponding device pointer added by the transformation. """ for i, inarg in enumerate(self.inargs): @@ -1143,7 +1143,7 @@ def _replace_kernel_args(self, driver, kernel_calls, offload_map): change_map = {} offload_idx_expr = driver.variable_map[self.offload_index] for arg, devptr in chain(offload_map.in_pairs, offload_map.inout_pairs, offload_map.out_pairs): - dims = (sym.RangeIndex((None, None)),) * (len(devptr.shape)-1) + (offload_idx_expr,) + dims = arg.dimensions + (offload_idx_expr,) change_map[arg] = devptr.clone(dimensions=dims) arg_transformer = SubstituteExpressions(change_map, inplace=True) for call in kernel_calls: diff --git a/loki/transformations/tests/test_data_offload.py b/loki/transformations/tests/test_data_offload.py index 4dbf915c5..495ef72fe 100644 --- a/loki/transformations/tests/test_data_offload.py +++ b/loki/transformations/tests/test_data_offload.py @@ -67,6 +67,12 @@ def fixture_field_module(tmp_path, frontend): contains procedure :: update_view end type field_3rb + + type field_4rb + real, pointer :: f_ptr(:,:,:) + contains + procedure :: update_view + end type field_4rb contains subroutine update_view(self, idx) @@ -947,6 +953,77 @@ def test_field_offload(frontend, parkind_mod, field_module, tmp_path): assert devptr.name in (arg.name for arg in kernel_call.arguments) +@pytest.mark.parametrize('frontend', available_frontends()) +def test_field_offload_slices(frontend, parkind_mod, field_module, tmp_path): + fcode = """ + module driver_mod + use parkind1, only: jprb + use field_module, only: field_4rb + implicit none + + type state_type + real(kind=jprb), dimension(10,10,10), pointer :: a, b, c + class(field_4rb), pointer :: f_a, f_b, f_c + contains + procedure :: update_view => state_update_view + end type state_type + + contains + + subroutine state_update_view(self, idx) + class(state_type), intent(in) :: self + integer, intent(in) :: idx + end subroutine + + subroutine kernel_routine(nlon, nlev, a, b, c) + integer, intent(in) :: nlon, nlev + real(kind=jprb), intent(in) :: a(nlon,nlev,nlon) + real(kind=jprb), intent(inout) :: b(nlon,nlev) + real(kind=jprb), intent(out) :: c(nlon) + integer :: i, j + end subroutine kernel_routine + + subroutine driver_routine(nlon, nlev, state) + integer, intent(in) :: nlon, nlev + type(state_type), intent(inout) :: state + integer :: i + !$loki data + do i=1,nlev + call kernel_routine(nlon, nlev, state%a(:,:,1), state%b(:,1,1), state%c(1,1,1)) + end do + !$loki end data + + end subroutine driver_routine + end module driver_mod + """ + driver_mod = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])['driver_mod'] + driver = driver_mod['driver_routine'] + deviceptr_prefix = 'loki_devptr_prefix_' + driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix, + offload_index='i', + field_group_types=['state_type']), + role='driver', + targets=['kernel_routine']) + + calls = FindNodes(CallStatement).visit(driver.body) + kernel_call = next(c for c in calls if c.name=='kernel_routine') + # verify that new pointer variables are created and used in driver calls + for var, rank in zip(['state_a', 'state_b', 'state_c'], [4, 3, 2]): + name = deviceptr_prefix + var + assert name in driver.variable_map + devptr = driver.variable_map[name] + assert isinstance(devptr, sym.Array) + assert len(devptr.shape) == 4 + assert devptr.name in (arg.name for arg in kernel_call.arguments) + arg = next(arg for arg in kernel_call.arguments if devptr.name in arg.name) + print(arg.dimensions) + assert arg.dimensions == ((sym.RangeIndex((None,None)),)*(rank-2) + + (sym.IntLiteral(1),)*(5-rank) + + (sym.Scalar(name='i'),)) + + + + @pytest.mark.parametrize('frontend', available_frontends()) def test_field_offload_multiple_calls(frontend, parkind_mod, field_module, tmp_path): fcode = """