Skip to content

Commit

Permalink
Removed unused functions and added more complete tests of data offloads
Browse files Browse the repository at this point in the history
  • Loading branch information
wertysas committed Nov 18, 2024
1 parent b7f0512 commit 7e53319
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 63 deletions.
37 changes: 14 additions & 23 deletions loki/transformations/data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Transformer, pragma_regions_attached, get_pragma_parameters,
FindInlineCalls, SubstituteExpressions
)
from loki.logging import warning
from loki.logging import warning, error
from loki.tools import as_tuple, flatten, CaseInsensitiveDict, CaseInsensitiveDefaultDict
from loki.types import BasicType, DerivedType
from loki.transformations.parallel import (
Expand Down Expand Up @@ -54,6 +54,11 @@ def __init__(self, **kwargs):
self.assume_deviceptr = kwargs.get('assume_deviceptr', False)
self.assume_acc_mapped = kwargs.get('assume_acc_mapped', False)

if self.assume_deviceptr and self.assume_acc_mapped:
error("[Loki] Data offload: Can't assume both acc_mapped and " +
"non-mapped device pointers for device data offload")
raise RuntimeError

def transform_subroutine(self, routine, **kwargs):
"""
Apply the transformation to a `Subroutine` object.
Expand Down Expand Up @@ -152,7 +157,7 @@ def insert_data_offload_pragmas(self, routine, targets):
outargs = tuple(dict.fromkeys(outargs))
inoutargs = tuple(dict.fromkeys(inoutargs))

# Now geenerate the pre- and post pragmas (OpenACC)
# Now generate the pre- and post pragmas (OpenACC)
if self.assume_deviceptr:
offload_args = inargs + outargs + inoutargs
if offload_args:
Expand Down Expand Up @@ -973,14 +978,9 @@ def __init__(self, **kwargs):
def transform_subroutine(self, routine, **kwargs):
role = kwargs['role']
targets = as_tuple(kwargs.get('targets'), (None))
if role == 'kernel':
self.process_kernel(routine)
if role == 'driver':
self.process_driver(routine, targets)

def process_kernel(self, routine):
pass

def process_driver(self, driver, targets):
remove_field_api_view_updates(driver, self.field_group_types + tuple(s.upper() for s in self.field_group_types))
with pragma_regions_attached(driver):
Expand All @@ -1002,20 +1002,21 @@ def find_offload_variables(self, driver, calls):

for call in calls:
if call.routine is BasicType.DEFERRED:
warning(f'[Loki] Data offload: Routine {driver.name} has not been enriched ' +
error(f'[Loki] Data offload: Routine {driver.name} has not been enriched ' +
f'in {str(call.name).lower()}')
continue
raise RuntimeError
for param, arg in call.arg_iter():
if not isinstance(param, Array):
continue
try:
parent = arg.parent
if parent.type.dtype.name.lower() not in self.field_group_types:
warning(f'[Loki] The parent object {parent.name} of type ' +
warning(f'[Loki] Data offload: The parent object {parent.name} of type ' +
f'{parent.type.dtype} is not in the list of field wrapper types')
continue
except AttributeError:
warning(f'[Loki] Field data offload: Raw array object {arg.name} encountered in'
+ f'{driver.name} that is not wrapped by a Field API object')
warning(f'[Loki] Data offload: Raw array object {arg.name} encountered in'
+ f' {driver.name} that is not wrapped by a Field API object')
continue

if param.type.intent.lower() == 'in':
Expand Down Expand Up @@ -1049,7 +1050,7 @@ def _devptr_from_array(self, driver, a: sym.Array):
base_name = a.name if a.parent is None else '_'.join(a.name.split('%'))
devptr_name = self.deviceptr_prefix + base_name
if devptr_name in driver.variable_map:
warning(f'[Loki] Field data offload: The routine {driver.name} already has a' +
warning(f'[Loki] Data offload: The routine {driver.name} already has a ' +
f'variable named {devptr_name}')
devptr = sym.Variable(name=devptr_name, type=devptr_type, dimensions=shape)
return devptr
Expand All @@ -1066,16 +1067,6 @@ def _add_field_offload_calls(self, driver, region, offload_map):
update_map = {region: host_to_device + (region,) + device_to_host}
Transformer(update_map, inplace=True).visit(driver.body)

def _is_field_group_update(self, driver, call):
try:
*_, parent, _call_name = call.name.name.split('%')
parent = driver.variable_map.get(parent)
if parent is not None and parent.type.dtype.name.lower() in self.field_group_types:
return True
except ValueError:
return False
return False

def _get_field_ptr_from_view(self, field_view):
type_chain = field_view.name.split('%')
field_type_name = 'F_' + type_chain[-1]
Expand Down
12 changes: 1 addition & 11 deletions loki/transformations/parallel/field_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

__all__ = [
'remove_field_api_view_updates', 'add_field_api_view_updates', 'get_field_type',
'field_new', 'field_delete', 'field_get_device_data', 'field_sync_host',
'FieldAPITransferType'
'field_get_device_data', 'field_sync_host', 'FieldAPITransferType'
]


Expand Down Expand Up @@ -182,15 +181,6 @@ def get_field_type(a: sym.Array) -> sym.DerivedType:
return field_type


def field_new(field_ptr, data, scope):
return ir.CallStatement(name=sym.ProcedureSymbol('FIELD_NEW', scope=scope),
arguments=(field_ptr,), kwarguments=(('DATA', data),))


def field_delete(field_ptr, scope):
return ir.CallStatement(name=sym.ProcedureSymbol('FIELD_DELETE', scope=scope),
arguments=(field_ptr,))


class FieldAPITransferType(Enum):
READ_ONLY = 1
Expand Down
14 changes: 13 additions & 1 deletion loki/transformations/parallel/tests/test_field_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from loki.expression import symbols as sym
from loki.scope import Scope
from loki.transformations.parallel import (
remove_field_api_view_updates, add_field_api_view_updates, get_field_type
remove_field_api_view_updates, add_field_api_view_updates, get_field_type,
field_get_device_data, FieldAPITransferType
)
from loki.types import BasicType, SymbolAttributes
from loki.logging import WARNING
Expand Down Expand Up @@ -187,3 +188,14 @@ def generate_fields(types):
assert isinstance(field, sym.DerivedType) and field.name == field_name


def test_field_get_device_data():
scope = Scope()
fptr = sym.Variable(name='fptr_var')
dev_ptr = sym.Variable(name='data_var')
for fttype in FieldAPITransferType:
get_dev_data_call = field_get_device_data(fptr, dev_ptr, fttype, scope)
assert isinstance(get_dev_data_call, ir.CallStatement)
assert get_dev_data_call.name.parent == fptr
with pytest.raises(TypeError):
_ = field_get_device_data(fptr, dev_ptr, "none_transfer_type", scope)

Loading

0 comments on commit 7e53319

Please sign in to comment.