Skip to content

Commit

Permalink
Merge pull request #424 from ecmwf-ifs/nams-continued-effort-cuda-tra…
Browse files Browse the repository at this point in the history
…nspilation

Continued: F2C/CUDA transpilation
  • Loading branch information
reuterbal authored Nov 19, 2024
2 parents 97d0943 + c018c04 commit b88920c
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 12 deletions.
2 changes: 2 additions & 0 deletions loki/backend/cudagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def _subroutine_declaration(self, o, **kwargs):
prefix = ''
if o.prefix and "global" in o.prefix[0].lower():
prefix = '__global__ '
if o.prefix and "device" in o.prefix[0].lower():
prefix = '__device__ '
if o.is_function:
return_type = self.symgen.intrinsic_type_mapper(o.return_type)
else:
Expand Down
36 changes: 36 additions & 0 deletions loki/expression/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,42 @@ def clone(self, **kwargs):
kw_parameters = kwargs.get('kw_parameters', self.kw_parameters)
return InlineCall(function, parameters, kw_parameters)

def _sort_kwarguments(self):
    """
    Return the keyword arguments of this inline call reordered to follow
    the argument order of the called routine (``self.routine.arguments``).

    The keyword arguments are looked up by name through a
    ``CaseInsensitiveDict``; routine arguments that have no matching
    keyword argument are left out of the result.

    The called routine must be available, i.e. ``self.routine`` must not
    be ``BasicType.DEFERRED``; this is enforced with an assertion.

    Returns
    -------
    tuple
        ``(name, value)`` pairs in routine argument order, where ``name``
        is the name of the routine argument.
    """
    routine = self.routine
    assert routine is not BasicType.DEFERRED
    lookup = CaseInsensitiveDict(self.kwarguments)
    ordered = []
    for routine_arg in routine.arguments:
        name = routine_arg.name
        if name in lookup:
            ordered.append((name, lookup[name]))
    return tuple(ordered)

def is_kwargs_order_correct(self):
    """
    Tell whether the keyword arguments are already in the order given by
    the arguments of the called routine (``self.routine.arguments``).

    Returns
    -------
    bool
        ``True`` if ``self.kwarguments`` equals the result of
        ``self._sort_kwarguments()``, ``False`` otherwise.
    """
    current = self.kwarguments
    return current == self._sort_kwarguments()

def clone_with_sorted_kwargs(self):
    """
    Create a copy of this inline call whose keyword arguments follow the
    order of the called routine's arguments (``self.routine.arguments``).

    Returns
    -------
    The object returned by ``self.clone`` when given the reordered
    keyword arguments as ``kw_parameters``.
    """
    return self.clone(kw_parameters=self._sort_kwarguments())

def clone_with_kwargs_as_args(self):
    """
    Create a copy of this inline call in which the keyword arguments have
    been turned into positional arguments.

    The values of the keyword arguments, ordered by
    ``self._sort_kwarguments()``, are appended to the existing positional
    arguments; the clone carries no keyword arguments.

    Returns
    -------
    The object returned by ``self.clone`` with the extended ``parameters``
    and an empty ``kw_parameters`` tuple.
    """
    # Each sorted entry is a (name, value) pair; only the value is kept.
    trailing_args = tuple(value for _, value in self._sort_kwarguments())
    positional = self.arguments + trailing_args
    return self.clone(parameters=positional, kw_parameters=())

class Cast(pmbl.Call):
"""
Expand Down
51 changes: 51 additions & 0 deletions loki/expression/tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,57 @@ def test_no_arg_inline_call(frontend, tmp_path):
assert isinstance(assignment.rhs.function, sym.ProcedureSymbol)


@pytest.mark.parametrize('frontend', available_frontends())
def test_kwargs_inline_call(frontend, tmp_path):
"""
Test an inline call that uses keyword arguments in three stages:
the unmodified call, the call with keyword arguments sorted into
the order of the function's arguments, and the call with keyword
arguments converted to positional arguments. The routine is compiled
and executed after each stage and must yield the same result.
"""
# The inline call passes the keyword arguments in the order c, b, a, d,
# i.e. deliberately not in the declaration order (a, b, c, d).
fcode_routine = """
subroutine my_kwargs_routine(var, v_a, v_b, v_c, v_d)
implicit none
integer, intent(out) :: var
integer, intent(in) :: v_a, v_b, v_c, v_d
var = my_kwargs_func(c=v_c, b=v_b, a=v_a, d=v_d)
contains
function my_kwargs_func(a, b, c, d)
integer, intent(in) :: a, b, c, d
integer :: my_kwargs_func
my_kwargs_func = a - b - c - d
end function my_kwargs_func
end subroutine my_kwargs_routine
"""
# Test the original implementation
filepath = tmp_path/(f'orig_expression_kwargs_call_{frontend}.f90')
routine = Subroutine.from_source(fcode_routine, frontend=frontend, xmods=[tmp_path])
function = jit_compile(routine, filepath=filepath, objname='my_kwargs_routine')
# NOTE(review): the wrapper is called without ``var``; presumably the
# intent(out) argument is returned as the result - confirm against jit_compile.
res_orig = function(100, 10, 5, 2)
# Expected value: a - b - c - d = 100 - 10 - 5 - 2 = 83
assert res_orig == 83

# Sort the kwargs and test the transformed code
inline_call = list(FindInlineCalls().visit(routine.body))[0]
call_map = {inline_call: inline_call.clone_with_sorted_kwargs()}
routine.body = SubstituteExpressions(call_map).visit(routine.body)
# Re-fetch the call, since the substitution replaced it in the body
inline_call = list(FindInlineCalls().visit(routine.body))[0]
assert inline_call.is_kwargs_order_correct()
# Sorting must keep everything as keyword arguments
assert not inline_call.arguments
assert inline_call.kwarguments == (('a', 'v_a'), ('b', 'v_b'), ('c', 'v_c'), ('d', 'v_d'))
filepath = tmp_path/(f'sorted_expression_kwargs_call_{frontend}.f90')
function = jit_compile(routine, filepath=filepath, objname='my_kwargs_routine')
res_sorted = function(100, 10, 5, 2)
assert res_sorted == 83

# Convert kwargs to args and test the transformed code
call_map = {inline_call: inline_call.clone_with_kwargs_as_args()}
routine.body = SubstituteExpressions(call_map).visit(routine.body)
inline_call = list(FindInlineCalls().visit(routine.body))[0]
# After the conversion no keyword arguments may remain
assert not inline_call.kwarguments
filepath = tmp_path/(f'converted_expression_kwargs_call_{frontend}.f90')
function = jit_compile(routine, filepath=filepath, objname='my_kwargs_routine')
res_args = function(100, 10, 5, 2)
assert res_args == 83


@pytest.mark.parametrize('frontend', available_frontends())
def test_inline_call_derived_type_arguments(frontend, tmp_path):
"""
Expand Down
2 changes: 1 addition & 1 deletion loki/ir/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ def _sort_kwarguments(self):
new_kwarguments = tuple((arg_name, kwargs[arg_name]) for arg_name in r_arg_names)
return new_kwarguments

def check_kwarguments_order(self):
def is_kwargs_order_correct(self):
"""
Check whether kwarguments are correctly ordered
in respect to the arguments (``self.routine.arguments``).
Expand Down
2 changes: 1 addition & 1 deletion loki/tests/test_subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2190,7 +2190,7 @@ def test_call_args_kwargs_conversion(frontend):

# sort kwargs
for i_call, call in enumerate(FindNodes(ir.CallStatement).visit(driver.body)):
assert call.check_kwarguments_order() == kwargs_in_order[i_call]
assert call.is_kwargs_order_correct() == kwargs_in_order[i_call]
call.sort_kwarguments()

# check calls with sorted kwargs
Expand Down
4 changes: 2 additions & 2 deletions loki/transformations/block_index_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,8 +521,8 @@ def process(self, routine, targets, role):
"""
processed_routines = ()
variable_map = routine.variable_map
block_dim_index = variable_map[self.block_dim.index]
block_dim_size = variable_map[self.block_dim.size]
block_dim_index = get_integer_variable(routine, self.block_dim.index)
block_dim_size = get_integer_variable(routine, self.block_dim.size)
for call in FindNodes(ir.CallStatement).visit(routine.body):
if str(call.name).lower() not in targets:
continue
Expand Down
7 changes: 4 additions & 3 deletions loki/transformations/single_column/scc_cuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,14 +342,15 @@ def kernel_cuf(self, routine, horizontal, vertical, block_dim,
if horizontal_index.name in call.routine.variables:
call.routine.symbol_attrs.update({horizontal_index.name:\
call.routine.variable_map[horizontal_index.name].type.clone(intent='in')})
additional_args += (horizontal_index.clone(),)
additional_args += (horizontal_index.clone(type=horizontal_index.type.clone(intent='in'),
scope=call.routine),)
if horizontal_index.name not in call.arg_map:
additional_kwargs += ((horizontal_index.name, horizontal_index.clone()),)
additional_kwargs += ((horizontal_index.name, horizontal_index.clone(scope=routine)),)

if block_dim_index.name not in call.routine.arguments:
additional_args += (block_dim_index.clone(type=block_dim_index.type.clone(intent='in',
scope=call.routine)),)
additional_kwargs += ((block_dim_index.name, block_dim_index.clone()),)
additional_kwargs += ((block_dim_index.name, block_dim_index.clone(scope=routine)),)
if additional_kwargs:
call._update(kwarguments=call.kwarguments+additional_kwargs)
if additional_args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ SUBROUTINE kernel(start, iend, nlon, nz, q, t, z)
INTEGER :: jl, jk
REAL :: c

c = 5.345
c = SOME_FUNC(A=5.345)
DO jk = 2, nz
DO jl = start, iend
call ELEMENTAL_DEVICE(z(jl, jk))
Expand Down Expand Up @@ -54,4 +54,11 @@ SUBROUTINE DEVICE(nlon, nz, jk_start, start, iend, x)
END DO
END SUBROUTINE DEVICE

! Helper function that returns its single REAL argument unchanged.
! NOTE(review): presumably present so that the kernel contains an inline
! function call with a keyword argument (c = SOME_FUNC(A=5.345)) - confirm
! against the transpilation tests that use this module.
FUNCTION SOME_FUNC(A)
REAL, INTENT(IN) :: A
REAL :: SOME_FUNC
!$loki routine seq
SOME_FUNC = A
END FUNCTION SOME_FUNC

END MODULE KERNEL_MOD
37 changes: 33 additions & 4 deletions loki/transformations/transpile/fortran_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from loki.ir import (
Section, Import, Intrinsic, Interface, CallStatement,
VariableDeclaration, TypeDef, Assignment, Transformer, FindNodes,
Pragma, Comment, SubstituteExpressions
Pragma, Comment, SubstituteExpressions, FindInlineCalls
)
from loki.logging import debug
from loki.module import Module
Expand Down Expand Up @@ -188,9 +188,8 @@ def transform_subroutine(self, routine, **kwargs):
if isinstance(arg.type.dtype, DerivedType):
self.c_structs[arg.type.dtype.name.lower()] = self.c_struct_typedef(arg.type)

for call in FindNodes(CallStatement).visit(routine.body):
if str(call.name).lower() in as_tuple(targets):
call.convert_kwargs_to_args()
# for calls and inline calls: convert kwarguments to arguments
self.convert_kwargs_to_args(routine, targets)

if role == 'kernel':
# Generate Fortran wrapper module
Expand Down Expand Up @@ -231,6 +230,19 @@ def transform_subroutine(self, routine, **kwargs):
header_path = (path/c_kernel.name.lower()).with_suffix('.h')
Sourcefile.to_file(source=self.codegen(c_kernel, header=True), path=header_path)

def convert_kwargs_to_args(self, routine, targets):
    """
    Turn keyword arguments into positional arguments for the calls in
    ``routine.body`` whose lower-cased name is contained in ``targets``.

    Subroutine calls (``CallStatement`` nodes) are converted through
    their own ``convert_kwargs_to_args`` method. Inline (function) calls
    are replaced by the clone returned from ``clone_with_kwargs_as_args``,
    but only if their routine is not ``BasicType.DEFERRED``. The body of
    ``routine`` is reassigned only if at least one inline call is replaced.

    Parameters
    ----------
    routine :
        The routine whose body is searched and updated.
    targets :
        Name or names of the called routines to process.
    """
    target_names = as_tuple(targets)

    # calls (to subroutines)
    for call_stmt in as_tuple(FindNodes(CallStatement).visit(routine.body)):
        if str(call_stmt.name).lower() in target_names:
            call_stmt.convert_kwargs_to_args()

    # inline calls (to functions)
    replacements = {
        icall: icall.clone_with_kwargs_as_args()
        for icall in as_tuple(FindInlineCalls().visit(routine.body))
        if str(icall.name).lower() in target_names and icall.routine is not BasicType.DEFERRED
    }
    if replacements:
        routine.body = SubstituteExpressions(replacements).visit(routine.body)

def c_struct_typedef(self, derived):
"""
Create the :class:`TypeDef` for the C-wrapped struct definition.
Expand Down Expand Up @@ -627,6 +639,9 @@ def generate_c_kernel(self, routine, targets, **kwargs):
# apply dereference and reference where necessary
self.apply_de_reference(kernel)

# adapt call and inline call names -> '<call name>_c'
self.convert_call_names(kernel, targets)

symbol_map = {'epsilon': 'DBL_EPSILON'}
function_map = {'min': 'fmin', 'max': 'fmax', 'abs': 'fabs',
'exp': 'exp', 'sqrt': 'sqrt', 'sign': 'copysign'}
Expand All @@ -637,6 +652,20 @@ def generate_c_kernel(self, routine, targets, **kwargs):

return kernel

def convert_call_names(self, routine, targets):
    """
    Rename the calls in ``routine.body`` that refer to ``targets`` by
    appending the suffix ``_c`` to the call name.

    Subroutine calls (``CallStatement`` nodes) are updated in place with
    a new, lower-cased name. For inline (function) calls the function
    symbol is replaced by a clone carrying the new name; inline calls
    whose routine is ``BasicType.DEFERRED`` are left untouched. The body
    of ``routine`` is always reassigned with the substituted expressions.

    NOTE(review): with ``targets=None`` every resolved inline call is
    renamed, whereas subroutine calls depend solely on the membership
    test against ``as_tuple(targets)`` - confirm this asymmetry is
    intended.

    Parameters
    ----------
    routine :
        The routine whose body is searched and updated.
    targets :
        Name or names of the called routines to rename.
    """
    target_names = as_tuple(targets)

    # calls (to subroutines)
    for call_stmt in FindNodes(CallStatement).visit(routine.body):
        if call_stmt.name in target_names:
            call_stmt._update(name=Variable(name=f'{call_stmt.name}_c'.lower()))

    # inline calls (to functions)
    renamed = {}
    for icall in FindInlineCalls(unique=False).visit(routine.body):
        if icall.routine is BasicType.DEFERRED:
            continue
        if targets is None or icall.name in target_names:
            renamed[icall.function] = icall.function.clone(name=f'{icall.name}_c')
    routine.body = SubstituteExpressions(renamed).visit(routine.body)

def generate_c_kernel_launch(self, kernel_launch, kernel, **kwargs):
import_map = {}
for im in FindNodes(Import).visit(kernel_launch.spec):
Expand Down
44 changes: 44 additions & 0 deletions loki/transformations/transpile/tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,12 +1434,23 @@ def test_scc_cuda_parametrise(tmp_path, here, frontend, config, horizontal, vert
f2c_transformation = FortranCTransformation(path=tmp_path, language='cuda', use_c_ptr=True)
scheduler.process(transformation=f2c_transformation)

kernel = scheduler['kernel_mod#kernel'].ir
kernel_variable_map = kernel.variable_map
assert kernel_variable_map[horizontal.index].type.intent is None
assert kernel_variable_map[horizontal.index].scope == kernel
device = scheduler['kernel_mod#device'].ir
device_variable_map = device.variable_map
assert device_variable_map[horizontal.index].type.intent.lower() == 'in'
assert device_variable_map[horizontal.index].scope == device

fc_kernel = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_fc.F90'))
c_kernel = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_c.c'))
c_kernel_header = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_c.h'))
c_kernel_launch = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_c_launch.h'))
c_device = remove_whitespace_linebreaks(read_file(tmp_path/'device_c.c'))
c_elemental_device = remove_whitespace_linebreaks(read_file(tmp_path/'elemental_device_c.c'))
c_some_func = remove_whitespace_linebreaks(read_file(tmp_path/'some_func_c.c'))
c_some_func_header = remove_whitespace_linebreaks(read_file(tmp_path/'some_func_c.h'))

calls = FindNodes(ir.CallStatement).visit(scheduler["driver_mod#driver"].ir.body)
assert len(calls) == 3
Expand All @@ -1459,9 +1470,13 @@ def test_scc_cuda_parametrise(tmp_path, here, frontend, config, horizontal, vert
assert '#include"kernel_c_launch.h"' in c_kernel
assert 'include"elemental_device_c.h"' in c_kernel
assert 'include"device_c.h"' in c_kernel
assert 'include"some_func_c.h"' in c_kernel
assert '__global__voidkernel_c' in c_kernel
assert 'jl=threadidx.x;' in c_kernel
assert 'b=blockidx.x;' in c_kernel
assert 'device_c(' in c_kernel
assert 'elemental_device_c(' in c_kernel
assert '=some_func_c(' in c_kernel
# kernel_c.h
assert '__global__voidkernel_c' in c_kernel_header
assert 'jl=threadidx.x;' not in c_kernel_header
Expand All @@ -1479,9 +1494,17 @@ def test_scc_cuda_parametrise(tmp_path, here, frontend, config, horizontal, vert
assert '#include<cuda.h>' in c_device
assert '#include<cuda_runtime.h>' in c_device
assert '#include"device_c.h"' in c_device
# elemental_device_c.c
assert '__device__voiddevice_c(' in c_device
assert '#include<cuda.h>' in c_elemental_device
assert '#include<cuda_runtime.h>' in c_elemental_device
assert '#include"elemental_device_c.h"' in c_elemental_device
# some_func_c.c
assert 'doublesome_func_c(doublea)' in c_some_func
assert 'returnsome_func' in c_some_func
# some_func_c.h
assert 'doublesome_func_c(doublea);' in c_some_func_header


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_cuda_hoist(tmp_path, here, frontend, config, horizontal, vertical, blocking):
Expand All @@ -1503,12 +1526,23 @@ def test_scc_cuda_hoist(tmp_path, here, frontend, config, horizontal, vertical,
f2c_transformation = FortranCTransformation(path=tmp_path, language='cuda', use_c_ptr=True)
scheduler.process(transformation=f2c_transformation)

kernel = scheduler['kernel_mod#kernel'].ir
kernel_variable_map = kernel.variable_map
assert kernel_variable_map[horizontal.index].type.intent is None
assert kernel_variable_map[horizontal.index].scope == kernel
device = scheduler['kernel_mod#device'].ir
device_variable_map = device.variable_map
assert device_variable_map[horizontal.index].type.intent.lower() == 'in'
assert device_variable_map[horizontal.index].scope == device

fc_kernel = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_fc.F90'))
c_kernel = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_c.c'))
c_kernel_header = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_c.h'))
c_kernel_launch = remove_whitespace_linebreaks(read_file(tmp_path/'kernel_c_launch.h'))
c_device = remove_whitespace_linebreaks(read_file(tmp_path/'device_c.c'))
c_elemental_device = remove_whitespace_linebreaks(read_file(tmp_path/'elemental_device_c.c'))
c_some_func = remove_whitespace_linebreaks(read_file(tmp_path/'some_func_c.c'))
c_some_func_header = remove_whitespace_linebreaks(read_file(tmp_path/'some_func_c.h'))

calls = FindNodes(ir.CallStatement).visit(scheduler["driver_mod#driver"].ir.body)
assert len(calls) == 3
Expand All @@ -1529,9 +1563,13 @@ def test_scc_cuda_hoist(tmp_path, here, frontend, config, horizontal, vertical,
assert '#include"kernel_c_launch.h"' in c_kernel
assert '#include"elemental_device_c.h"' in c_kernel
assert '#include"device_c.h"' in c_kernel
assert 'include"some_func_c.h"' in c_kernel
assert '__global__voidkernel_c' in c_kernel
assert 'jl=threadidx.x;' in c_kernel
assert 'b=blockidx.x;' in c_kernel
assert 'device_c(' in c_kernel
assert 'elemental_device_c(' in c_kernel
assert '=some_func_c(' in c_kernel
# kernel_c.h
assert '__global__voidkernel_c' in c_kernel_header
assert 'jl=threadidx.x;' not in c_kernel_header
Expand All @@ -1549,10 +1587,16 @@ def test_scc_cuda_hoist(tmp_path, here, frontend, config, horizontal, vertical,
assert '#include<cuda.h>' in c_device
assert '#include<cuda_runtime.h>' in c_device
assert '#include"device_c.h"' in c_device
assert '__device__voiddevice_c(' in c_device
# elemental_device_c.c
assert '#include<cuda.h>' in c_elemental_device
assert '#include<cuda_runtime.h>' in c_elemental_device
assert '#include"elemental_device_c.h"' in c_elemental_device
# some_func_c.c
assert 'doublesome_func_c(doublea)' in c_some_func
assert 'returnsome_func' in c_some_func
# some_func_c.h
assert 'doublesome_func_c(doublea);' in c_some_func_header


@pytest.mark.parametrize('frontend', available_frontends())
Expand Down

0 comments on commit b88920c

Please sign in to comment.