Skip to content

Commit

Permalink
Removed unused find_array_arg utility and FieldOffload clean
Browse files Browse the repository at this point in the history
  • Loading branch information
wertysas committed Nov 11, 2024
1 parent 561424e commit 4acd385
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 108 deletions.
55 changes: 1 addition & 54 deletions loki/transformations/data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,57 +920,6 @@ def _append_routine_arguments(self, routine, item):
routine.arguments += tuple(sorted(new_arguments, key=lambda symbol: symbol.name))


################################################################################
# Field API helper routines
################################################################################

def find_array_arguments(routine, calls):
"""
Finds all arguments and sorts them by intents
Parameters
----------
routine : `Subroutine`
Subroutine to apply this transformation to.
calls : list of `CallStatement`
Calls to extract arguments from.
This transformation will only apply at the ``'driver'`` level.
Returns
-------
inargs : tuple of arguments that are only inargs
inoutargs : tuples of arguments that are both in and out, or inout
outargs : tuples of arguments that are only outargs
"""
inargs = ()
inoutargs = ()
outargs = ()

for call in calls:
if call.routine is BasicType.DEFERRED:
warning(f'[Loki] Data offload: Routine {routine.name} has not been enriched ' +
f'in {str(call.name).lower()}')
continue
for param, arg in call.arg_iter():
if isinstance(param, Array) and param.type.intent.lower() == 'in':
inargs += (arg, )
if isinstance(param, Array) and param.type.intent.lower() == 'inout':
inoutargs += (arg, )
if isinstance(param, Array) and param.type.intent.lower() == 'out':
outargs += (arg, )

inoutargs += tuple(v for v in inargs if v in outargs)
inargs = tuple(v for v in inargs if v not in inoutargs)
outargs = tuple(v for v in outargs if v not in inoutargs)

# Filter for duplicates TODO: What if we pass different slices of same array!?
inargs = tuple(set(inargs))
inoutargs = tuple(set(inoutargs))
outargs = tuple(set(outargs))
return inargs, inoutargs, outargs


def find_target_calls(region, targets):
"""Returns a list of all calls to targets inside the region
Expand Down Expand Up @@ -1030,19 +979,18 @@ def process_kernel(self, routine):
pass

def process_driver(self, driver, targets):
remove_field_api_view_updates(driver, self.field_group_types + tuple(s.upper() for s in self.field_group_types))
with pragma_regions_attached(driver):
for region in FindNodes(PragmaRegion).visit(driver.body):
# Only work on active `!$loki data` regions
if not DataOffloadTransformation._is_active_loki_data_region(region, targets):
continue
# remove_field_api_view_updates(driver, self.field_group_types) # FIXME: if called here, driver.body will not be updated in subsequent routines
kernel_calls = find_target_calls(region, targets)
offload_variables = self.find_offload_variables(driver, kernel_calls)
device_ptrs = self._declare_device_ptrs(driver, offload_variables)
offload_map = self.FieldPointerMap(device_ptrs, *offload_variables)
self._add_field_offload_calls(driver, region, offload_map)
self._replace_kernel_args(driver, kernel_calls, offload_map)
remove_field_api_view_updates(driver, self.field_group_types) # if called here it works

def find_offload_variables(self, driver, calls):
inargs = ()
Expand Down Expand Up @@ -1078,7 +1026,6 @@ def find_offload_variables(self, driver, calls):
inargs = tuple(v for v in inargs if v not in inoutargs)
outargs = tuple(v for v in outargs if v not in inoutargs)

# Filter for duplicates TODO: What if we pass different slices of same array!?
inargs = tuple(set(inargs))
inoutargs = tuple(set(inoutargs))
outargs = tuple(set(outargs))
Expand Down
54 changes: 0 additions & 54 deletions loki/transformations/tests/test_data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
DataOffloadTransformation, GlobalVariableAnalysis,
GlobalVarOffloadTransformation, GlobalVarHoistTransformation, FieldOffloadTransformation
)
from loki.transformations.data_offload import find_array_arguments, find_target_calls


@pytest.fixture(scope='module', name='here')
Expand Down Expand Up @@ -809,59 +808,6 @@ def test_transformation_global_var_derived_type_hoist(here, config, frontend, ho
assert kernel.variable_map['p0'].type.dtype.name == 'point'



@pytest.mark.parametrize('frontend', available_frontends())
def test_find_array_arguments(frontend):
fcode_driver = """
SUBROUTINE driver_routine(nlon, nlev, a, b, c, d, e, f)
INTEGER, INTENT(INOUT) :: nlon, nlev
REAL, INTENT(INOUT) :: a(nlon,nlev)
REAL, INTENT(INOUT) :: b(nlon)
REAL, INTENT(INOUT) :: c(nlon,nlev,nlon)
REAL, INTENT(INOUT) :: d
REAL, INTENT(INOUT) :: e
REAL, INTENT(INOUT) :: f
INTEGER :: nlon_copy, nlev_copy
REAL :: a_copy(nlon,nlev)
REAL :: b_copy(nlon)
REAL :: c_copy(nlon,nlev,nlon)
REAL :: d_copy
REAL :: e_copy
REAL :: f_copy
call kernel_routine(nlon_copy, nlev_copy, a_copy, b_copy, c_copy, d_copy, e_copy, f_copy)
END SUBROUTINE driver_routine
"""
fcode_kernel = """
SUBROUTINE kernel_routine(nlon, nlev, a, b, c, d, e, f)
INTEGER, INTENT(IN) :: nlon, nlev
REAL, INTENT(IN) :: a(nlon,nlev)
REAL, INTENT(INOUT) :: b(nlon)
REAL, INTENT(OUT) :: c(nlon,nlev,nlon)
REAL, INTENT(IN) :: d
REAL, INTENT(IN) :: e
REAL, INTENT(IN) :: f
do j=1, nlon
do i=1, nlev
b_copy(i) = a(i,j) + 0.1
c(i,j,i) = 0.1
end do
end do
END SUBROUTINE kernel_routine
"""
driver = Sourcefile.from_source(fcode_driver, frontend=frontend)['driver_routine']
kernel = Sourcefile.from_source(fcode_kernel, frontend=frontend)['kernel_routine']
driver.enrich(kernel)
targets = find_target_calls(driver.body, [kernel.name])
in_vars, inout_vars, out_vars = find_array_arguments(driver, targets)

assert len(in_vars) == 1 and in_vars[0].name == 'a_copy'
assert len(inout_vars) == 1 and inout_vars[0].name == 'b_copy'
assert len(out_vars) == 1 and out_vars[0].name == 'c_copy'


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload(frontend):
fcode = """
Expand Down

0 comments on commit 4acd385

Please sign in to comment.