diff --git a/loki/backend/cgen.py b/loki/backend/cgen.py index f74502386..7635f5e1e 100644 --- a/loki/backend/cgen.py +++ b/loki/backend/cgen.py @@ -362,7 +362,7 @@ def visit_CallStatement(self, o, **kwargs): """ args = self.visit_all(o.arguments, **kwargs) assert not o.kwarguments - return self.format_line(o.name, '(', self.join_items(args), ');') + return self.format_line(str(o.name), '(', self.join_items(args), ');') def visit_SymbolAttributes(self, o, **kwargs): # pylint: disable=unused-argument if isinstance(o.dtype, DerivedType): diff --git a/loki/transform/fortran_c_transform.py b/loki/transform/fortran_c_transform.py index ce76fdb52..c96e73939 100644 --- a/loki/transform/fortran_c_transform.py +++ b/loki/transform/fortran_c_transform.py @@ -31,7 +31,8 @@ from loki.module import Module from loki.expression import ( Variable, InlineCall, RangeIndex, Scalar, Array, - ProcedureSymbol, SubstituteExpressions, Dereference, + ProcedureSymbol, SubstituteExpressions, Dereference, Reference, + ExpressionRetriever, SubstituteExpressionsMapper, ) from loki.expression import symbols as sym from loki.tools import as_tuple, flatten @@ -40,6 +41,54 @@ __all__ = ['FortranCTransformation'] +class DeReferenceTrafo(Transformer): + """ + Transformation to apply/insert Dereference = `*` and + Reference/*address-of* = `&` operators. + + Parameters + ---------- + vars2dereference : list + Variables to be dereferenced. Ususally the arguments except + for scalars with `intent=in`. + """ + # pylint: disable=unused-argument + + def __init__(self, vars2dereference): + super().__init__() + self.retriever = ExpressionRetriever(self.is_dereference) + self.vars2dereference = vars2dereference + + @staticmethod + def is_dereference(symbol): + return isinstance(symbol, (DerivedType, Array, Scalar)) and not ( + isinstance(symbol, Array) and symbol.dimensions is not None + and not all(dim == sym.RangeIndex((None, None)) for dim in symbol.dimensions) + ) + + def visit_Expression(self, o, **kwargs): + symbol_map = { + symbol: Dereference(symbol.clone()) for symbol in self.retriever.retrieve(o) + if symbol.name.lower() in self.vars2dereference + } + return SubstituteExpressionsMapper(symbol_map)(o) + + def visit_CallStatement(self, o, **kwargs): + new_args = () + call_arg_map = dict((v,k) for k,v in o.arg_map.items()) + for arg in o.arguments: + if not self.is_dereference(arg) and (isinstance(call_arg_map[arg], Array)\ + or call_arg_map[arg].type.intent.lower() != 'in'): + new_args += (Reference(arg.clone()),) + else: + if isinstance(arg, Scalar) and call_arg_map[arg].type.intent.lower() != 'in': + new_args += (Reference(arg.clone()),) + else: + new_args += (arg,) + o._update(arguments=new_args) + return o + + class FortranCTransformation(Transformation): """ Fortran-to-C transformation that translates the given routine @@ -401,6 +450,19 @@ def generate_c_header(self, module, **kwargs): header_module.rescope_symbols() return header_module + @staticmethod + def apply_de_reference(routine): + """ + Utility method to apply/insert Dereference = `*` and + Reference/*address-of* = `&` operators. + """ + to_be_dereferenced = [] + for arg in routine.arguments: + if not(arg.type.intent.lower() == 'in' and isinstance(arg, Scalar)): + to_be_dereferenced.append(arg.name.lower()) + + routine.body = DeReferenceTrafo(to_be_dereferenced).visit(routine.body) + def generate_c_kernel(self, routine): """ Re-generate the C kernel and insert wrapper-specific peculiarities, @@ -476,8 +538,7 @@ def generate_c_kernel(self, routine): # Force all variables to lower-caps, as C/C++ is case-sensitive convert_to_lower_case(kernel) - # Force pointer on reference-passed arguments - var_map = {} + # Force pointer on reference-passed arguments (and lower case type names for derived types) for arg in kernel.arguments: if not(arg.type.intent.lower() == 'in' and isinstance(arg, Scalar)): _type = arg.type.clone(pointer=True) @@ -485,9 +546,10 @@ def generate_c_kernel(self, routine): # Lower case type names for derived types typedef = _type.dtype.typedef.clone(name=_type.dtype.typedef.name.lower()) _type = _type.clone(dtype=typedef.dtype) - var_map[arg] = Dereference(arg) kernel.symbol_attrs[arg.name] = _type - kernel.body = SubstituteExpressions(var_map).visit(kernel.body) + + # apply dereference and reference where necessary + self.apply_de_reference(kernel) symbol_map = {'epsilon': 'DBL_EPSILON'} function_map = {'min': 'fmin', 'max': 'fmax', 'abs': 'fabs', diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 1191daa4c..4d579b7f0 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -1005,6 +1005,59 @@ def test_transpile_expressions(here, builder, frontend, use_c_ptr): f2c.c_path.unlink() +@pytest.mark.parametrize('use_c_ptr', (False, True)) +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transpile_call(here, frontend, use_c_ptr): + fcode_module = """ +module transpile_call_kernel_mod + implicit none +contains + + subroutine transpile_call_kernel(a, b, c, arr1, len) + integer, intent(inout) :: a, c + integer, intent(in) :: b + integer, intent(in) :: len + integer, intent(inout) :: arr1(len, len) + a = b + c = b + end subroutine transpile_call_kernel +end module transpile_call_kernel_mod +""" + + fcode = """ +subroutine transpile_call_driver(a) + use transpile_call_kernel_mod, only: transpile_call_kernel + integer, intent(inout) :: a + integer, parameter :: len = 5 + integer :: arr1(len, len) + integer :: arr2(len, len) + integer :: b + b = 2 * len + call transpile_call_kernel(a, b, arr2(1, 1), arr1, len) +end subroutine transpile_call_driver +""" + unlink_paths = [] + module = Module.from_source(fcode_module, frontend=frontend) + routine = Subroutine.from_source(fcode, frontend=frontend, definitions=module) + f2c = FortranCTransformation(use_c_ptr=use_c_ptr, path=here) + f2c.apply(source=module.subroutine_map['transpile_call_kernel'], path=here, role='kernel') + unlink_paths.extend([f2c.wrapperpath, f2c.c_path]) + ccode_kernel = f2c.c_path.read_text().replace(' ', '').replace('\n', '') + f2c.apply(source=routine, path=here, role='kernel') + unlink_paths.extend([f2c.wrapperpath, f2c.c_path]) + ccode_driver = f2c.c_path.read_text().replace(' ', '').replace('\n', '') + + assert "int*a,intb,int*c" in ccode_kernel + # check for applied Dereference + assert "(*a)=b;" in ccode_kernel + assert "(*c)=b;" in ccode_kernel + # check for applied Reference + assert "transpile_call_kernel((&a),b,(&arr2[" in ccode_driver + + for path in unlink_paths: + path.unlink() + + @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('f_type', ['integer', 'real']) def test_transpile_inline_functions(here, frontend, f_type):