Skip to content

Commit

Permalink
Removed unused functions and added more complete tests of data offloads
Browse files Browse the repository at this point in the history
  • Loading branch information
wertysas committed Nov 18, 2024
1 parent b7f0512 commit 6b9b2e4
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 45 deletions.
28 changes: 7 additions & 21 deletions loki/transformations/data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Transformer, pragma_regions_attached, get_pragma_parameters,
FindInlineCalls, SubstituteExpressions
)
from loki.logging import warning
from loki.logging import warning, error
from loki.tools import as_tuple, flatten, CaseInsensitiveDict, CaseInsensitiveDefaultDict
from loki.types import BasicType, DerivedType
from loki.transformations.parallel import (
Expand Down Expand Up @@ -152,7 +152,7 @@ def insert_data_offload_pragmas(self, routine, targets):
outargs = tuple(dict.fromkeys(outargs))
inoutargs = tuple(dict.fromkeys(inoutargs))

# Now geenerate the pre- and post pragmas (OpenACC)
# Now generate the pre- and post pragmas (OpenACC)
if self.assume_deviceptr:
offload_args = inargs + outargs + inoutargs
if offload_args:
Expand Down Expand Up @@ -973,14 +973,9 @@ def __init__(self, **kwargs):
def transform_subroutine(self, routine, **kwargs):
role = kwargs['role']
targets = as_tuple(kwargs.get('targets'), (None))
if role == 'kernel':
self.process_kernel(routine)
if role == 'driver':
self.process_driver(routine, targets)

def process_kernel(self, routine):
pass

def process_driver(self, driver, targets):
remove_field_api_view_updates(driver, self.field_group_types + tuple(s.upper() for s in self.field_group_types))
with pragma_regions_attached(driver):
Expand All @@ -1002,9 +997,9 @@ def find_offload_variables(self, driver, calls):

for call in calls:
if call.routine is BasicType.DEFERRED:
warning(f'[Loki] Data offload: Routine {driver.name} has not been enriched ' +
error(f'[Loki] Data offload: Routine {driver.name} has not been enriched ' +
f'in {str(call.name).lower()}')
continue
raise RuntimeError
for param, arg in call.arg_iter():
if not isinstance(param, Array):
continue
Expand All @@ -1013,9 +1008,10 @@ def find_offload_variables(self, driver, calls):
if parent.type.dtype.name.lower() not in self.field_group_types:
warning(f'[Loki] The parent object {parent.name} of type ' +
f'{parent.type.dtype} is not in the list of field wrapper types')
continue
except AttributeError:
warning(f'[Loki] Field data offload: Raw array object {arg.name} encountered in'
+ f'{driver.name} that is not wrapped by a Field API object')
+ f' {driver.name} that is not wrapped by a Field API object')
continue

if param.type.intent.lower() == 'in':
Expand Down Expand Up @@ -1049,7 +1045,7 @@ def _devptr_from_array(self, driver, a: sym.Array):
base_name = a.name if a.parent is None else '_'.join(a.name.split('%'))
devptr_name = self.deviceptr_prefix + base_name
if devptr_name in driver.variable_map:
warning(f'[Loki] Field data offload: The routine {driver.name} already has a' +
warning(f'[Loki] Field data offload: The routine {driver.name} already has a ' +
f'variable named {devptr_name}')
devptr = sym.Variable(name=devptr_name, type=devptr_type, dimensions=shape)
return devptr
Expand All @@ -1066,16 +1062,6 @@ def _add_field_offload_calls(self, driver, region, offload_map):
update_map = {region: host_to_device + (region,) + device_to_host}
Transformer(update_map, inplace=True).visit(driver.body)

def _is_field_group_update(self, driver, call):
try:
*_, parent, _call_name = call.name.name.split('%')
parent = driver.variable_map.get(parent)
if parent is not None and parent.type.dtype.name.lower() in self.field_group_types:
return True
except ValueError:
return False
return False

def _get_field_ptr_from_view(self, field_view):
type_chain = field_view.name.split('%')
field_type_name = 'F_' + type_chain[-1]
Expand Down
12 changes: 1 addition & 11 deletions loki/transformations/parallel/field_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

__all__ = [
'remove_field_api_view_updates', 'add_field_api_view_updates', 'get_field_type',
'field_new', 'field_delete', 'field_get_device_data', 'field_sync_host',
'FieldAPITransferType'
'field_get_device_data', 'field_sync_host', 'FieldAPITransferType'
]


Expand Down Expand Up @@ -182,15 +181,6 @@ def get_field_type(a: sym.Array) -> sym.DerivedType:
return field_type


def field_new(field_ptr, data, scope):
return ir.CallStatement(name=sym.ProcedureSymbol('FIELD_NEW', scope=scope),
arguments=(field_ptr,), kwarguments=(('DATA', data),))


def field_delete(field_ptr, scope):
return ir.CallStatement(name=sym.ProcedureSymbol('FIELD_DELETE', scope=scope),
arguments=(field_ptr,))


class FieldAPITransferType(Enum):
READ_ONLY = 1
Expand Down
14 changes: 13 additions & 1 deletion loki/transformations/parallel/tests/test_field_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from loki.expression import symbols as sym
from loki.scope import Scope
from loki.transformations.parallel import (
remove_field_api_view_updates, add_field_api_view_updates, get_field_type
remove_field_api_view_updates, add_field_api_view_updates, get_field_type,
field_get_device_data, FieldAPITransferType
)
from loki.types import BasicType, SymbolAttributes
from loki.logging import WARNING
Expand Down Expand Up @@ -187,3 +188,14 @@ def generate_fields(types):
assert isinstance(field, sym.DerivedType) and field.name == field_name


def test_field_get_device_data():
scope = Scope()
fptr = sym.Variable(name='fptr_var')
dev_ptr = sym.Variable(name='data_var')
for fttype in FieldAPITransferType:
get_dev_data_call = field_get_device_data(fptr, dev_ptr, fttype, scope)
assert isinstance(get_dev_data_call, ir.CallStatement)
assert get_dev_data_call.name.parent == fptr
with pytest.raises(TypeError):
_ = field_get_device_data(fptr, dev_ptr, "none_transfer_type", scope)

152 changes: 140 additions & 12 deletions loki/transformations/tests/test_data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import pytest

from loki import (
Sourcefile, Scheduler, FindInlineCalls
Sourcefile, Scheduler, FindInlineCalls, warning
)
from loki.frontend import available_frontends, OMNI
from loki.logging import WARNING
from loki.ir import (
FindNodes, Pragma, PragmaRegion, Loop, CallStatement, Import,
pragma_regions_attached, get_pragma_parameters
Expand All @@ -22,7 +23,6 @@
GlobalVarOffloadTransformation, GlobalVarHoistTransformation, FieldOffloadTransformation
)


@pytest.fixture(scope='module', name='here')
def fixture_here():
return Path(__file__).parent
Expand All @@ -46,7 +46,8 @@ def fixture_config():

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('assume_deviceptr', [True, False])
def test_data_offload_region_openacc(frontend, assume_deviceptr):
@pytest.mark.parametrize('assume_acc_mapped', [True, False])
def test_data_offload_region_openacc(frontend, assume_deviceptr, assume_acc_mapped):
"""
Test the creation of a simple device data offload region
(`!$acc update`) from a `!$loki data` region with a single
Expand Down Expand Up @@ -86,7 +87,9 @@ def test_data_offload_region_openacc(frontend, assume_deviceptr):
kernel = Sourcefile.from_source(fcode_kernel, frontend=frontend)['kernel_routine']
driver.enrich(kernel)

driver.apply(DataOffloadTransformation(assume_deviceptr=assume_deviceptr), role='driver',
driver.apply(DataOffloadTransformation(assume_deviceptr=assume_deviceptr,
assume_acc_mapped=assume_acc_mapped),
role='driver',
targets=['kernel_routine'])

pragmas = FindNodes(Pragma).visit(driver.body)
Expand All @@ -96,6 +99,10 @@ def test_data_offload_region_openacc(frontend, assume_deviceptr):
assert 'deviceptr' in pragmas[0].content
params = get_pragma_parameters(pragmas[0], only_loki_pragmas=False)
assert all(var in params['deviceptr'] for var in ('a', 'b', 'c'))
elif assume_acc_mapped:
assert 'present' in pragmas[0].content
params = get_pragma_parameters(pragmas[0], only_loki_pragmas=False)
assert all(var in params['present'] for var in ('a', 'b', 'c'))
else:
transformed = driver.to_fortran()
assert 'copyin( a )' in transformed
Expand Down Expand Up @@ -856,11 +863,11 @@ def test_field_offload(frontend):
end module driver_mod
"""

driver_mod = Sourcefile.from_source(fcode)['driver_mod']
driver_mod = Sourcefile.from_source(fcode, frontend=frontend)['driver_mod']
driver = driver_mod['driver_routine']
deviceptr_prefix = 'loki_devptr_prefix_'
driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix,
offload_index='i',
offload_index='i',
field_group_types=['state_type']),
role='driver',
targets=['kernel_routine'])
Expand Down Expand Up @@ -943,17 +950,17 @@ def test_field_offload_multiple_calls(frontend):
end module driver_mod
"""

driver_mod = Sourcefile.from_source(fcode)['driver_mod']
driver_mod = Sourcefile.from_source(fcode, frontend=frontend)['driver_mod']
driver = driver_mod['driver_routine']
deviceptr_prefix = 'loki_devptr_prefix_'
driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix,
offload_index='i',
offload_index='i',
field_group_types=['state_type']),
role='driver',
targets=['kernel_routine'])
calls = FindNodes(CallStatement).visit(driver.body)
kernel_calls = [c for c in calls if c.name=='kernel_routine']

# verify that field offloads are generated properly
in_calls = [c for c in calls if 'get_device_data_rdonly' in c.name.name.lower()]
assert len(in_calls) == 1
Expand All @@ -967,7 +974,7 @@ def test_field_offload_multiple_calls(frontend):
pragmas = FindNodes(Pragma).visit(driver.body)
assert len(pragmas) == 2
assert all(p.keyword=='loki' and p.content==c for p, c in zip(pragmas, ['data', 'end data']))

# verify that new pointer variables are created and used in driver calls
for var in ['state_a', 'state_b', 'state_c']:
name = deviceptr_prefix + var
Expand Down Expand Up @@ -1028,11 +1035,11 @@ def test_field_offload_no_targets(frontend):
end module driver_mod
"""

driver_mod = Sourcefile.from_source(fcode)['driver_mod']
driver_mod = Sourcefile.from_source(fcode, frontend=frontend)['driver_mod']
driver = driver_mod['driver_routine']
deviceptr_prefix = 'loki_devptr_prefix_'
driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix,
offload_index='i',
offload_index='i',
field_group_types=['state_type']),
role='driver',
targets=['kernel_routine'])
Expand All @@ -1054,3 +1061,124 @@ def test_field_offload_no_targets(frontend):
assert len(pragmas) == 2
assert all(p.keyword=='loki' and p.content==c for p, c in zip(pragmas, ['data', 'end data']))


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_unknown_kernel(caplog, frontend):
fcode = """
module driver_mod
use parkind1, only: jprb
use another_module, only: another_kernel
implicit none
type state_type
real(kind=jprb), dimension(10,10), pointer :: a, b, c
class(field_3rb), pointer :: f_a, f_b, f_c
end type state_type
contains
subroutine driver_routine(nlon, nlev, state, state2)
integer, intent(in) :: nlon, nlev
type(state_type), intent(inout) :: state
integer :: i
!$loki data
do i=1,nlev
call state%update_view(i)
call another_kernel(nlon, nlev, state%a, state%b, state%c)
end do
!$loki end data
end subroutine driver_routine
end module driver_mod
"""

driver_mod = Sourcefile.from_source(fcode, frontend=frontend)['driver_mod']
driver = driver_mod['driver_routine']
deviceptr_prefix = 'loki_devptr_prefix_'

with caplog.at_level(WARNING):
with pytest.raises(RuntimeError):
driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix,
offload_index='i',
field_group_types=['state_type']),
role='driver',
targets=['another_kernel'])
assert len(caplog.records) == 1
assert ('[Loki] Data offload: Routine driver_routine has not been enriched '+
'in another_kernel') in caplog.records[0].message


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_warnings(caplog, frontend):
fcode = """
module driver_mod
use state_type_mod, only: state_type2
use parkind1, only: jprb
use field_module, only: field_2rb, field_3rb
use another_module, only: another_kernel
implicit none
type state_type
real(kind=jprb), dimension(10,10), pointer :: a, b, c
class(field_3rb), pointer :: f_a, f_b, f_c
end type state_type
contains
subroutine kernel_routine(nlon, nlev, a, b, c)
integer, intent(in) :: nlon, nlev
real(kind=jprb), intent(in) :: a(nlon,nlev)
real(kind=jprb), intent(inout) :: b(nlon,nlev)
real(kind=jprb), intent(out) :: c(nlon,nlev)
integer :: i, j
do j=1, nlon
do i=1, nlev
b(i,j) = a(i,j) + 0.1
c(i,j) = 0.1
end do
end do
end subroutine kernel_routine
subroutine driver_routine(nlon, nlev, state, state2)
integer, intent(in) :: nlon, nlev
type(state_type), intent(inout) :: state
type(state_type2), intent(inout) :: state2
integer :: i
real(kind=jprb) :: a(nlon,nlev)
real, pointer :: loki_devptr_prefix_state_b
!$loki data
do i=1,nlev
call state%update_view(i)
call kernel_routine(nlon, nlev, a, state%b, state2%c)
end do
!$loki end data
end subroutine driver_routine
end module driver_mod
"""

driver_mod = Sourcefile.from_source(fcode, frontend=frontend)['driver_mod']
driver = driver_mod['driver_routine']
deviceptr_prefix = 'loki_devptr_prefix_'

with caplog.at_level(WARNING):
driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix,
offload_index='i',
field_group_types=['state_type']),
role='driver',
targets=['kernel_routine'])
assert len(caplog.records) == 3

assert (('[Loki] Field data offload: Raw array object a encountered in'
+' driver_routine that is not wrapped by a Field API object')
in caplog.records[0].message)
assert ('[Loki] The parent object state2 of type state_type2 is not in the' +
' list of field wrapper types') in caplog.records[1].message
assert ('[Loki] Field data offload: The routine driver_routine already has a' +
' variable named loki_devptr_prefix_state_b') in caplog.records[2].message

0 comments on commit 6b9b2e4

Please sign in to comment.