Skip to content

Commit

Permalink
Field Offload, added missing support for sliced arguments to calls
Browse files Browse the repository at this point in the history
  • Loading branch information
wertysas committed Nov 20, 2024
1 parent 67d1a30 commit 45c9e20
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
4 changes: 2 additions & 2 deletions loki/transformations/data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ def in_pairs(self):
______
:any:`Array`
Original kernel call argument
:any: `Array`
:any:`Array`
Corresponding device pointer added by the transformation.
"""
for i, inarg in enumerate(self.inargs):
Expand Down Expand Up @@ -1143,7 +1143,7 @@ def _replace_kernel_args(self, driver, kernel_calls, offload_map):
change_map = {}
offload_idx_expr = driver.variable_map[self.offload_index]
for arg, devptr in chain(offload_map.in_pairs, offload_map.inout_pairs, offload_map.out_pairs):
dims = (sym.RangeIndex((None, None)),) * (len(devptr.shape)-1) + (offload_idx_expr,)
dims = arg.dimensions + (offload_idx_expr,)
change_map[arg] = devptr.clone(dimensions=dims)
arg_transformer = SubstituteExpressions(change_map, inplace=True)
for call in kernel_calls:
Expand Down
77 changes: 77 additions & 0 deletions loki/transformations/tests/test_data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ def fixture_field_module(tmp_path, frontend):
contains
procedure :: update_view
end type field_3rb
type field_4rb
real, pointer :: f_ptr(:,:,:)
contains
procedure :: update_view
end type field_4rb
contains
subroutine update_view(self, idx)
Expand Down Expand Up @@ -947,6 +953,77 @@ def test_field_offload(frontend, parkind_mod, field_module, tmp_path):
assert devptr.name in (arg.name for arg in kernel_call.arguments)


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_slices(frontend, parkind_mod, field_module, tmp_path):
fcode = """
module driver_mod
use parkind1, only: jprb
use field_module, only: field_4rb
implicit none
type state_type
real(kind=jprb), dimension(10,10,10), pointer :: a, b, c
class(field_4rb), pointer :: f_a, f_b, f_c
contains
procedure :: update_view => state_update_view
end type state_type
contains
subroutine state_update_view(self, idx)
class(state_type), intent(in) :: self
integer, intent(in) :: idx
end subroutine
subroutine kernel_routine(nlon, nlev, a, b, c)
integer, intent(in) :: nlon, nlev
real(kind=jprb), intent(in) :: a(nlon,nlev,nlon)
real(kind=jprb), intent(inout) :: b(nlon,nlev)
real(kind=jprb), intent(out) :: c(nlon)
integer :: i, j
end subroutine kernel_routine
subroutine driver_routine(nlon, nlev, state)
integer, intent(in) :: nlon, nlev
type(state_type), intent(inout) :: state
integer :: i
!$loki data
do i=1,nlev
call kernel_routine(nlon, nlev, state%a(:,:,1), state%b(:,1,1), state%c(1,1,1))
end do
!$loki end data
end subroutine driver_routine
end module driver_mod
"""
driver_mod = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])['driver_mod']
driver = driver_mod['driver_routine']
deviceptr_prefix = 'loki_devptr_prefix_'
driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix,
offload_index='i',
field_group_types=['state_type']),
role='driver',
targets=['kernel_routine'])

calls = FindNodes(CallStatement).visit(driver.body)
kernel_call = next(c for c in calls if c.name=='kernel_routine')
# verify that new pointer variables are created and used in driver calls
for var, rank in zip(['state_a', 'state_b', 'state_c'], [4, 3, 2]):
name = deviceptr_prefix + var
assert name in driver.variable_map
devptr = driver.variable_map[name]
assert isinstance(devptr, sym.Array)
assert len(devptr.shape) == 4
assert devptr.name in (arg.name for arg in kernel_call.arguments)
arg = next(arg for arg in kernel_call.arguments if devptr.name in arg.name)
print(arg.dimensions)
assert arg.dimensions == ((sym.RangeIndex((None,None)),)*(rank-2) +
(sym.IntLiteral(1),)*(5-rank) +
(sym.Scalar(name='i'),))




@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_multiple_calls(frontend, parkind_mod, field_module, tmp_path):
fcode = """
Expand Down

0 comments on commit 45c9e20

Please sign in to comment.