From 96a812b9526144aa518fefefec6889a983550ec2 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Fri, 5 Apr 2024 10:53:37 +0000 Subject: [PATCH 1/2] F2C: 'DeReferenceTrafo' to apply 'Dereference' and 'Reference' where needed --- loki/backend/cgen.py | 2 +- loki/transform/fortran_c_transform.py | 43 +++++++++++++++++++--- tests/test_transpile.py | 52 +++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 5 deletions(-) diff --git a/loki/backend/cgen.py b/loki/backend/cgen.py index 282c05969..d5ac7ecd7 100644 --- a/loki/backend/cgen.py +++ b/loki/backend/cgen.py @@ -349,7 +349,7 @@ def visit_CallStatement(self, o, **kwargs): """ args = self.visit_all(o.arguments, **kwargs) assert not o.kwarguments - return self.format_line(o.name, '(', self.join_items(args), ');') + return self.format_line(str(o.name), '(', self.join_items(args), ');') def visit_SymbolAttributes(self, o, **kwargs): # pylint: disable=unused-argument if isinstance(o.dtype, DerivedType): diff --git a/loki/transform/fortran_c_transform.py b/loki/transform/fortran_c_transform.py index dd0aa2e00..c66d1d427 100644 --- a/loki/transform/fortran_c_transform.py +++ b/loki/transform/fortran_c_transform.py @@ -31,7 +31,8 @@ from loki.module import Module from loki.expression import ( Variable, InlineCall, RangeIndex, Scalar, Array, - ProcedureSymbol, SubstituteExpressions, Dereference, + ProcedureSymbol, SubstituteExpressions, Dereference, Reference, + ExpressionRetriever, SubstituteExpressionsMapper, ) from loki.expression import symbols as sym from loki.tools import as_tuple, flatten @@ -477,7 +478,7 @@ def generate_c_kernel(self, routine): convert_to_lower_case(kernel) # Force pointer on reference-passed arguments - var_map = {} + to_be_dereferenced = [] for arg in kernel.arguments: if not(arg.type.intent.lower() == 'in' and isinstance(arg, Scalar)): _type = arg.type.clone(pointer=True) @@ -485,9 +486,43 @@ def generate_c_kernel(self, routine): # Lower case type names for derived types typedef = _type.dtype.typedef.clone(name=_type.dtype.typedef.name.lower()) _type = _type.clone(dtype=typedef.dtype) - var_map[arg] = Dereference(arg) + to_be_dereferenced.append(arg.name.lower()) kernel.symbol_attrs[arg.name] = _type - kernel.body = SubstituteExpressions(var_map).visit(kernel.body) + + class DeReferenceTrafo(Transformer): + + def __init__(self, vars2dereference): + super().__init__() + self.retriever = ExpressionRetriever(lambda e: isinstance(e, (DerivedType, Array, Scalar))\ + and e.name.lower() in vars2dereference) + + def visit_Expression(self, o, **kwargs): + symbols = self.retriever.retrieve(o) + symbol_map = {} + for symbol in symbols: + if isinstance(symbol, Array) and symbol.dimensions is not None\ + and not all(dim == sym.RangeIndex((None, None)) for dim in symbol.dimensions): + continue + symbol_map[symbol] = Dereference(symbol.clone()) + return SubstituteExpressionsMapper(symbol_map)(o) + + def visit_CallStatement(self, o, **kwargs): + new_args = () + call_arg_map = dict((v,k) for k,v in o.arg_map.items()) + for arg in o.arguments: + if isinstance(arg, Array) and arg.dimensions\ + and all(dim != sym.RangeIndex((None, None)) for dim in arg.dimensions) \ + and (isinstance(call_arg_map[arg], Array) or call_arg_map[arg].type.intent.lower() != 'in'): + new_args += (Reference(arg.clone()),) + else: + if isinstance(arg, Scalar) and call_arg_map[arg].type.intent.lower() != 'in': + new_args += (Reference(arg.clone()),) + else: + new_args += (arg,) + o._update(arguments=new_args) + return o + + kernel.body = DeReferenceTrafo(to_be_dereferenced).visit(kernel.body) symbol_map = {'epsilon': 'DBL_EPSILON'} function_map = {'min': 'fmin', 'max': 'fmax', 'abs': 'fabs', diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 7be234366..5d518e1a6 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -1003,3 +1003,55 @@ def test_transpile_expressions(here, builder, frontend, use_c_ptr): clean_test(filepath) f2c.wrapperpath.unlink() f2c.c_path.unlink() + +@pytest.mark.parametrize('use_c_ptr', (False, True)) +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transpile_call(here, frontend, use_c_ptr): + fcode_module = """ +module transpile_call_kernel_mod + implicit none +contains + + subroutine transpile_call_kernel(a, b, c, arr1, len) + integer, intent(inout) :: a, c + integer, intent(in) :: b + integer, intent(in) :: len + integer, intent(inout) :: arr1(len, len) + a = b + c = b + end subroutine transpile_call_kernel +end module transpile_call_kernel_mod +""" + + fcode = """ +subroutine transpile_call_driver(a) + use transpile_call_kernel_mod, only: transpile_call_kernel + integer, intent(inout) :: a + integer, parameter :: len = 5 + integer :: arr1(len, len) + integer :: arr2(len, len) + integer :: b + b = 2 * len + call transpile_call_kernel(a, b, arr2(1, 1), arr1, len) +end subroutine transpile_call_driver +""" + unlink_paths = [] + module = Module.from_source(fcode_module, frontend=frontend) + routine = Subroutine.from_source(fcode, frontend=frontend, definitions=module) + f2c = FortranCTransformation(use_c_ptr=use_c_ptr, path=here) + f2c.apply(source=module.subroutine_map['transpile_call_kernel'], path=here, role='kernel') + unlink_paths.extend([f2c.wrapperpath, f2c.c_path]) + ccode_kernel = f2c.c_path.read_text().replace(' ', '').replace('\n', '') + f2c.apply(source=routine, path=here, role='kernel') + unlink_paths.extend([f2c.wrapperpath, f2c.c_path]) + ccode_driver = f2c.c_path.read_text().replace(' ', '').replace('\n', '') + + assert "int*a,intb,int*c" in ccode_kernel + # check for applied Dereference + assert "(*a)=b;" in ccode_kernel + assert "(*c)=b;" in ccode_kernel + # check for applied Reference + assert "transpile_call_kernel((&a),b,(&arr2[" in ccode_driver + + for path in unlink_paths: + path.unlink() From b2fca8e7ef36d956c4e842dc72835516152dfbba Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Wed, 10 Apr 2024 11:32:33 +0000 Subject: [PATCH 2/2] F2C: 'DeReferenceTransfo': improve readability and modularisation --- loki/transform/fortran_c_transform.py | 101 ++++++++++++++++---------- 1 file changed, 64 insertions(+), 37 deletions(-) diff --git a/loki/transform/fortran_c_transform.py b/loki/transform/fortran_c_transform.py index c66d1d427..4d9a7a8a0 100644 --- a/loki/transform/fortran_c_transform.py +++ b/loki/transform/fortran_c_transform.py @@ -41,6 +41,54 @@ __all__ = ['FortranCTransformation'] +class DeReferenceTrafo(Transformer): + """ + Transformation to apply/insert Dereference = `*` and + Reference/*address-of* = `&` operators. + + Parameters + ---------- + vars2dereference : list + Variables to be dereferenced. Ususally the arguments except + for scalars with `intent=in`. + """ + # pylint: disable=unused-argument + + def __init__(self, vars2dereference): + super().__init__() + self.retriever = ExpressionRetriever(self.is_dereference) + self.vars2dereference = vars2dereference + + @staticmethod + def is_dereference(symbol): + return isinstance(symbol, (DerivedType, Array, Scalar)) and not ( + isinstance(symbol, Array) and symbol.dimensions is not None + and not all(dim == sym.RangeIndex((None, None)) for dim in symbol.dimensions) + ) + + def visit_Expression(self, o, **kwargs): + symbol_map = { + symbol: Dereference(symbol.clone()) for symbol in self.retriever.retrieve(o) + if symbol.name.lower() in self.vars2dereference + } + return SubstituteExpressionsMapper(symbol_map)(o) + + def visit_CallStatement(self, o, **kwargs): + new_args = () + call_arg_map = dict((v,k) for k,v in o.arg_map.items()) + for arg in o.arguments: + if not self.is_dereference(arg) and (isinstance(call_arg_map[arg], Array)\ + or call_arg_map[arg].type.intent.lower() != 'in'): + new_args += (Reference(arg.clone()),) + else: + if isinstance(arg, Scalar) and call_arg_map[arg].type.intent.lower() != 'in': + new_args += (Reference(arg.clone()),) + else: + new_args += (arg,) + o._update(arguments=new_args) + return o + + class FortranCTransformation(Transformation): """ Fortran-to-C transformation that translates the given routine @@ -402,6 +450,19 @@ def generate_c_header(self, module, **kwargs): header_module.rescope_symbols() return header_module + @staticmethod + def apply_de_reference(routine): + """ + Utility method to apply/insert Dereference = `*` and + Reference/*address-of* = `&` operators. + """ + to_be_dereferenced = [] + for arg in routine.arguments: + if not(arg.type.intent.lower() == 'in' and isinstance(arg, Scalar)): + to_be_dereferenced.append(arg.name.lower()) + + routine.body = DeReferenceTrafo(to_be_dereferenced).visit(routine.body) + def generate_c_kernel(self, routine): """ Re-generate the C kernel and insert wrapper-specific peculiarities, @@ -477,8 +538,7 @@ def generate_c_kernel(self, routine): # Force all variables to lower-caps, as C/C++ is case-sensitive convert_to_lower_case(kernel) - # Force pointer on reference-passed arguments - to_be_dereferenced = [] + # Force pointer on reference-passed arguments (and lower case type names for derived types) for arg in kernel.arguments: if not(arg.type.intent.lower() == 'in' and isinstance(arg, Scalar)): _type = arg.type.clone(pointer=True) @@ -486,43 +546,10 @@ def generate_c_kernel(self, routine): # Lower case type names for derived types typedef = _type.dtype.typedef.clone(name=_type.dtype.typedef.name.lower()) _type = _type.clone(dtype=typedef.dtype) - to_be_dereferenced.append(arg.name.lower()) kernel.symbol_attrs[arg.name] = _type - class DeReferenceTrafo(Transformer): - - def __init__(self, vars2dereference): - super().__init__() - self.retriever = ExpressionRetriever(lambda e: isinstance(e, (DerivedType, Array, Scalar))\ - and e.name.lower() in vars2dereference) - - def visit_Expression(self, o, **kwargs): - symbols = self.retriever.retrieve(o) - symbol_map = {} - for symbol in symbols: - if isinstance(symbol, Array) and symbol.dimensions is not None\ - and not all(dim == sym.RangeIndex((None, None)) for dim in symbol.dimensions): - continue - symbol_map[symbol] = Dereference(symbol.clone()) - return SubstituteExpressionsMapper(symbol_map)(o) - - def visit_CallStatement(self, o, **kwargs): - new_args = () - call_arg_map = dict((v,k) for k,v in o.arg_map.items()) - for arg in o.arguments: - if isinstance(arg, Array) and arg.dimensions\ - and all(dim != sym.RangeIndex((None, None)) for dim in arg.dimensions) \ - and (isinstance(call_arg_map[arg], Array) or call_arg_map[arg].type.intent.lower() != 'in'): - new_args += (Reference(arg.clone()),) - else: - if isinstance(arg, Scalar) and call_arg_map[arg].type.intent.lower() != 'in': - new_args += (Reference(arg.clone()),) - else: - new_args += (arg,) - o._update(arguments=new_args) - return o - - kernel.body = DeReferenceTrafo(to_be_dereferenced).visit(kernel.body) + # apply dereference and reference where necessary + self.apply_de_reference(kernel) symbol_map = {'epsilon': 'DBL_EPSILON'} function_map = {'min': 'fmin', 'max': 'fmax', 'abs': 'fabs',