Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

F2C: DeReferenceTrafo #273

Merged
merged 4 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion loki/backend/cgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def visit_CallStatement(self, o, **kwargs):
"""
args = self.visit_all(o.arguments, **kwargs)
assert not o.kwarguments
return self.format_line(o.name, '(', self.join_items(args), ');')
return self.format_line(str(o.name), '(', self.join_items(args), ');')

def visit_SymbolAttributes(self, o, **kwargs): # pylint: disable=unused-argument
if isinstance(o.dtype, DerivedType):
Expand Down
72 changes: 67 additions & 5 deletions loki/transform/fortran_c_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from loki.module import Module
from loki.expression import (
Variable, InlineCall, RangeIndex, Scalar, Array,
ProcedureSymbol, SubstituteExpressions, Dereference,
ProcedureSymbol, SubstituteExpressions, Dereference, Reference,
ExpressionRetriever, SubstituteExpressionsMapper,
)
from loki.expression import symbols as sym
from loki.tools import as_tuple, flatten
Expand All @@ -40,6 +41,54 @@
__all__ = ['FortranCTransformation']


class DeReferenceTrafo(Transformer):
"""
Transformation to apply/insert Dereference = `*` and
Reference/*address-of* = `&` operators.

Parameters
----------
vars2dereference : list
Variables to be dereferenced. Ususally the arguments except
for scalars with `intent=in`.
"""
# pylint: disable=unused-argument

def __init__(self, vars2dereference):
super().__init__()
self.retriever = ExpressionRetriever(self.is_dereference)
self.vars2dereference = vars2dereference

@staticmethod
def is_dereference(symbol):
return isinstance(symbol, (DerivedType, Array, Scalar)) and not (
isinstance(symbol, Array) and symbol.dimensions is not None
and not all(dim == sym.RangeIndex((None, None)) for dim in symbol.dimensions)
)

def visit_Expression(self, o, **kwargs):
symbol_map = {
symbol: Dereference(symbol.clone()) for symbol in self.retriever.retrieve(o)
if symbol.name.lower() in self.vars2dereference
}
return SubstituteExpressionsMapper(symbol_map)(o)

def visit_CallStatement(self, o, **kwargs):
new_args = ()
call_arg_map = dict((v,k) for k,v in o.arg_map.items())
for arg in o.arguments:
if not self.is_dereference(arg) and (isinstance(call_arg_map[arg], Array)\
or call_arg_map[arg].type.intent.lower() != 'in'):
new_args += (Reference(arg.clone()),)
else:
if isinstance(arg, Scalar) and call_arg_map[arg].type.intent.lower() != 'in':
new_args += (Reference(arg.clone()),)
else:
new_args += (arg,)
o._update(arguments=new_args)
return o


class FortranCTransformation(Transformation):
"""
Fortran-to-C transformation that translates the given routine
Expand Down Expand Up @@ -401,6 +450,19 @@ def generate_c_header(self, module, **kwargs):
header_module.rescope_symbols()
return header_module

@staticmethod
def apply_de_reference(routine):
"""
Utility method to apply/insert Dereference = `*` and
Reference/*address-of* = `&` operators.
"""
to_be_dereferenced = []
for arg in routine.arguments:
if not(arg.type.intent.lower() == 'in' and isinstance(arg, Scalar)):
to_be_dereferenced.append(arg.name.lower())

routine.body = DeReferenceTrafo(to_be_dereferenced).visit(routine.body)

def generate_c_kernel(self, routine):
"""
Re-generate the C kernel and insert wrapper-specific peculiarities,
Expand Down Expand Up @@ -476,18 +538,18 @@ def generate_c_kernel(self, routine):
# Force all variables to lower-caps, as C/C++ is case-sensitive
convert_to_lower_case(kernel)

# Force pointer on reference-passed arguments
var_map = {}
# Force pointer on reference-passed arguments (and lower case type names for derived types)
for arg in kernel.arguments:
if not(arg.type.intent.lower() == 'in' and isinstance(arg, Scalar)):
_type = arg.type.clone(pointer=True)
if isinstance(arg.type.dtype, DerivedType):
# Lower case type names for derived types
typedef = _type.dtype.typedef.clone(name=_type.dtype.typedef.name.lower())
_type = _type.clone(dtype=typedef.dtype)
var_map[arg] = Dereference(arg)
kernel.symbol_attrs[arg.name] = _type
kernel.body = SubstituteExpressions(var_map).visit(kernel.body)

# apply dereference and reference where necessary
self.apply_de_reference(kernel)

symbol_map = {'epsilon': 'DBL_EPSILON'}
function_map = {'min': 'fmin', 'max': 'fmax', 'abs': 'fabs',
Expand Down
54 changes: 54 additions & 0 deletions tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,60 @@ def test_transpile_expressions(here, builder, frontend, use_c_ptr):
f2c.wrapperpath.unlink()
f2c.c_path.unlink()


@pytest.mark.parametrize('use_c_ptr', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_call(here, frontend, use_c_ptr):
fcode_module = """
module transpile_call_kernel_mod
implicit none
contains

subroutine transpile_call_kernel(a, b, c, arr1, len)
integer, intent(inout) :: a, c
integer, intent(in) :: b
integer, intent(in) :: len
integer, intent(inout) :: arr1(len, len)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my education: What is the intended outcome if we had (another) argument arr(:,:) here? Would this pre-empt the F2C transformation until we figure out the actual shape, or would it simply not be passed as a reference?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we need the actual shape information for flattening (see flatten_arrays() - TypeError: Resolve shapes being of type RangeIndex, e.g., ":" before flattening!, raised if isinstance(shape[-2], sym.RangeIndex))arrays having more than one dimension, the DeReferenceTrafo wouldn't be executed.

However,

subroutine transpile_call_driver(a)
  use transpile_call_kernel_mod, only: transpile_call_kernel
    integer, parameter :: len = 5
    integer, intent(inout) :: arr1(len)
    integer, intent(inout) :: arr2(len)
    call transpile_call_kernel(arr1, arr2, len)
end subroutine transpile_call_driver

  subroutine transpile_call_kernel(arr1, arr2, len)
    integer, intent(in) :: len
    integer, intent(inout) :: arr1(len)
    integer, intent(inout) :: arr2(:)

    arr1(1) = 1
    arr2(1) = 1
  end subroutine transpile_call_kernel

is transformed/transpiled to:

int transpile_call_driver_c() {
  int len = 5;
  transpile_call_kernel(arr1, arr2, len);
  return 0;
}

int transpile_call_kernel_c(int * restrict arr1, int * restrict arr2, int len) {

  int arr1[len];
  int arr2[len];
  arr1[1 - 1] = 1;
  arr2[1 - 1] = 1;
  return 0;
}

if that answers your question?!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does, many thanks!

a = b
c = b
end subroutine transpile_call_kernel
end module transpile_call_kernel_mod
"""

fcode = """
subroutine transpile_call_driver(a)
use transpile_call_kernel_mod, only: transpile_call_kernel
integer, intent(inout) :: a
integer, parameter :: len = 5
integer :: arr1(len, len)
integer :: arr2(len, len)
integer :: b
b = 2 * len
call transpile_call_kernel(a, b, arr2(1, 1), arr1, len)
end subroutine transpile_call_driver
"""
unlink_paths = []
module = Module.from_source(fcode_module, frontend=frontend)
routine = Subroutine.from_source(fcode, frontend=frontend, definitions=module)
f2c = FortranCTransformation(use_c_ptr=use_c_ptr, path=here)
f2c.apply(source=module.subroutine_map['transpile_call_kernel'], path=here, role='kernel')
unlink_paths.extend([f2c.wrapperpath, f2c.c_path])
ccode_kernel = f2c.c_path.read_text().replace(' ', '').replace('\n', '')
f2c.apply(source=routine, path=here, role='kernel')
unlink_paths.extend([f2c.wrapperpath, f2c.c_path])
ccode_driver = f2c.c_path.read_text().replace(' ', '').replace('\n', '')

assert "int*a,intb,int*c" in ccode_kernel
# check for applied Dereference
assert "(*a)=b;" in ccode_kernel
assert "(*c)=b;" in ccode_kernel
# check for applied Reference
assert "transpile_call_kernel((&a),b,(&arr2[" in ccode_driver

for path in unlink_paths:
path.unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_multiconditional(here, builder, frontend):
"""
Expand Down
Loading