Skip to content

Commit

Permalink
[F2C] started refactoring 'cgen' and introducing c-like codegen ('cpp…
Browse files Browse the repository at this point in the history
…gen', 'cudagen')
  • Loading branch information
MichaelSt98 committed Apr 10, 2024
1 parent 5772340 commit 962528e
Show file tree
Hide file tree
Showing 7 changed files with 356 additions and 44 deletions.
2 changes: 2 additions & 0 deletions loki/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from loki.backend.fgen import * # noqa
from loki.backend.cgen import * # noqa
from loki.backend.cppgen import * # noqa
from loki.backend.cudagen import * # noqa
from loki.backend.maxgen import * # noqa
from loki.backend.pygen import * # noqa
from loki.backend.dacegen import * # noqa
Expand Down
138 changes: 100 additions & 38 deletions loki/backend/cgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,52 @@
symbols as sym
)
from loki.types import BasicType, SymbolAttributes, DerivedType
__all__ = ['cgen', 'CCodegen', 'CCodeMapper']
__all__ = ['cgen', 'CCodegen', 'CCodeMapper', 'IntrinsicTypeC']


def c_intrinsic_type(_type):
if _type.dtype == BasicType.LOGICAL:
return 'int'
if _type.dtype == BasicType.INTEGER:
return 'int'
if _type.dtype == BasicType.REAL:
if str(_type.kind) in ['real32']:
return 'float'
return 'double'
raise ValueError(str(_type))
class IntrinsicTypeC:
# pylint: disable=abstract-method, unused-argument

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def __call__(self, _type, *args, **kwargs):
return self.c_intrinsic_type(_type, *args, **kwargs)

def c_intrinsic_type(self, _type, *args, **kwargs):
if _type.dtype == BasicType.LOGICAL:
return 'int'
if _type.dtype == BasicType.INTEGER:
return 'int'
if _type.dtype == BasicType.REAL:
if str(_type.kind) in ['real32']:
return 'float'
return 'double'
raise ValueError(str(_type))

# pylint: disable=redefined-outer-name
c_intrinsic_type = IntrinsicTypeC()

class CCodeMapper(LokiStringifyMapper):
# pylint: disable=abstract-method, unused-argument

def __init__(self, c_intrinsic_type, *args, **kwargs):
super().__init__()
self.c_intrinsic_type = c_intrinsic_type

def map_logic_literal(self, expr, enclosing_prec, *args, **kwargs):
return super().map_logic_literal(expr, enclosing_prec, *args, **kwargs).lower()

def map_float_literal(self, expr, enclosing_prec, *args, **kwargs):
if expr.kind is not None:
_type = SymbolAttributes(BasicType.REAL, kind=expr.kind)
return f'({c_intrinsic_type(_type)}) {str(expr.value)}'
return f'({self.c_intrinsic_type(_type)}) {str(expr.value)}'
return str(expr.value)

def map_int_literal(self, expr, enclosing_prec, *args, **kwargs):
if expr.kind is not None:
_type = SymbolAttributes(BasicType.INTEGER, kind=expr.kind)
return f'({c_intrinsic_type(_type)}) {str(expr.value)}'
return f'({self.c_intrinsic_type(_type)}) {str(expr.value)}'
return str(expr.value)

def map_string_literal(self, expr, enclosing_prec, *args, **kwargs):
Expand All @@ -59,7 +74,7 @@ def map_cast(self, expr, enclosing_prec, *args, **kwargs):
self.join_rec('', expr.parameters, PREC_NONE, *args, **kwargs),
PREC_CALL, PREC_NONE)
return self.parenthesize_if_needed(
self.format('(%s) %s', c_intrinsic_type(_type), expression), enclosing_prec, PREC_CALL)
self.format('(%s) %s', self.c_intrinsic_type(_type), expression), enclosing_prec, PREC_CALL)

def map_variable_symbol(self, expr, enclosing_prec, *args, **kwargs):
if expr.parent is not None:
Expand Down Expand Up @@ -122,10 +137,15 @@ class CCodegen(Stringifier):
"""
Tree visitor to generate standardized C code from IR.
"""
# pylint: disable=abstract-method, unused-argument

def __init__(self, depth=0, indent=' ', linewidth=90):
standard_imports = ['stdio.h', 'stdbool.h', 'float.h', 'math.h']

def __init__(self, depth=0, indent=' ', linewidth=90, **kwargs):
symgen = kwargs.get('symgen', CCodeMapper(c_intrinsic_type))
line_cont = kwargs.get('line_cont', '\n{} '.format)
super().__init__(depth=depth, indent=indent, linewidth=linewidth,
line_cont='\n{} '.format, symgen=CCodeMapper())
line_cont=line_cont, symgen=symgen)

# Handler for outer objects

Expand All @@ -143,25 +163,32 @@ def visit_Module(self, o, **kwargs):
routines = self.visit(o.routines, **kwargs)
return self.join_lines(spec, routines)

def visit_Subroutine(self, o, **kwargs):
"""
Format as:
...imports...
int <name>(<args>) {
...spec without imports and argument declarations...
...body...
}
"""
def _subroutine_header(self, o, **kwargs):
# Some boilerplate imports...
standard_imports = ['stdio.h', 'stdbool.h', 'float.h', 'math.h']
header = [self.format_line('#include <', name, '>') for name in standard_imports]

header = [self.format_line('#include <', name, '>') for name in self.standard_imports]
# ...and imports from the spec
spec_imports = FindNodes(Import).visit(o.spec)
header += [self.visit(spec_imports, **kwargs)]
return header

def _subroutine_arguments(self, o, **kwargs):
var_keywords = []
pass_by = []
for a in o.arguments:
var_keywords += ['']
if isinstance(a, Array) > 0:
pass_by += ['* restrict ']
elif isinstance(a.type.dtype, DerivedType):
pass_by += ['*']
elif a.type.pointer:
pass_by += ['*']
else:
pass_by += ['']
return pass_by, var_keywords

def _subroutine_declaration(self, o, **kwargs):
# Generate header with argument signature
"""
aptr = []
for a in o.arguments:
if isinstance(a, Array) > 0:
Expand All @@ -172,12 +199,19 @@ def visit_Subroutine(self, o, **kwargs):
aptr += ['*']
else:
aptr += ['']
arguments = [f'{self.visit(a.type, **kwargs)} {p}{a.name.lower()}'
for a, p in zip(o.arguments, aptr)]
header += [self.format_line('int ', o.name, '(', self.join_items(arguments), ') {')]

"""
pass_by, var_keywords = self._subroutine_arguments(o, **kwargs)
arguments = [f'{k}{self.visit(a.type, **kwargs)} {p}{a.name.lower()}'
for a, p, k in zip(o.arguments, pass_by, var_keywords)]
opt_header = kwargs.get('header', False)
end = ' {' if not opt_header else ';'
declaration = [self.format_line('int ', o.name, '(', self.join_items(arguments), ')', end)]
return declaration

def _subroutine_body(self, o, **kwargs):
self.depth += 1

# body = ['{']
# ...and generate the spec without imports and argument declarations
body = [self.visit(o.spec, skip_imports=True, skip_argument_declarations=True, **kwargs)]

Expand All @@ -187,9 +221,37 @@ def visit_Subroutine(self, o, **kwargs):

# Close everything off
self.depth -= 1
# footer = [self.format_line('}')]
return body

def _subroutine_footer(self, o, **kwargs):
footer = [self.format_line('}')]
return footer

def visit_Subroutine(self, o, **kwargs):
"""
Format as:
...imports...
int <name>(<args>) {
...spec without imports and argument declarations...
...body...
}
"""
opt_header = kwargs.get('header', False)
opt_guards = kwargs.get('guards', False)

header = self._subroutine_header(o, **kwargs)
declaration = self._subroutine_declaration(o, **kwargs)
body = self._subroutine_body(o, **kwargs) if not opt_header else []
footer = self._subroutine_footer(o, **kwargs) if not opt_header else []

if opt_guards:
guard_name = f'{o.name.upper()}_H'
header = [self.format_line(f'#ifndef {guard_name}'), self.format_line(f'#define {guard_name}\n\n')] + header
footer += ['\n#endif']

return self.join_lines(*header, *body, *footer)
return self.join_lines(*header, '\n', *declaration, *body, *footer)

# Handler for AST base nodes

Expand Down Expand Up @@ -352,12 +414,12 @@ def visit_CallStatement(self, o, **kwargs):
"""
args = self.visit_all(o.arguments, **kwargs)
assert not o.kwarguments
return self.format_line(o.name, '(', self.join_items(args), ');')
return self.format_line(str(o.name), '(', self.join_items(args), ');')

def visit_SymbolAttributes(self, o, **kwargs): # pylint: disable=unused-argument
if isinstance(o.dtype, DerivedType):
return f'struct {o.dtype.name}'
return c_intrinsic_type(o)
return self.symgen.c_intrinsic_type(o)

def visit_TypeDef(self, o, **kwargs):
header = self.format_line('struct ', o.name.lower(), ' {')
Expand Down Expand Up @@ -402,8 +464,8 @@ def visit_MultiConditional(self, o, **kwargs):
return self.join_lines(header, *branches, footer)


def cgen(ir):
def cgen(ir, **kwargs):
"""
Generate standardized C code from one or many IR objects/trees.
"""
return CCodegen().visit(ir)
return CCodegen().visit(ir, **kwargs)
102 changes: 102 additions & 0 deletions loki/backend/cppgen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.expression import Array
from loki.types import BasicType, DerivedType
from loki.backend.cgen import CCodegen, CCodeMapper, IntrinsicTypeC

__all__ = ['cppgen', 'CppCodegen', 'CppCodeMapper', 'IntrinsicTypeCpp']


class IntrinsicTypeCpp(IntrinsicTypeC):

def c_intrinsic_type(self, _type, *args, **kwargs):
if _type.dtype == BasicType.INTEGER:
if _type.parameter:
return 'const int'
return 'int'
return super().c_intrinsic_type(_type, *args, **kwargs)

cpp_intrinsic_type = IntrinsicTypeCpp()


class CppCodeMapper(CCodeMapper):
# pylint: disable=abstract-method, unused-argument
pass


class CppCodegen(CCodegen):
"""
...
"""
standard_imports = ['stdio.h', 'stdbool.h', 'float.h', 'math.h']

def __init__(self, depth=0, indent=' ', linewidth=90, **kwargs):
symgen = kwargs.get('symgen', CppCodeMapper(cpp_intrinsic_type))
line_cont = kwargs.get('line_cont', '\n{} '.format)

super().__init__(depth=depth, indent=indent, linewidth=linewidth,
line_cont=line_cont, symgen=symgen)

def _subroutine_header(self, o, **kwargs):
header = super()._subroutine_header(o, **kwargs)
return header

def _subroutine_arguments(self, o, **kwargs):
# opt_extern = kwargs.get('extern', False)
# if opt_extern:
# return super()._subroutine_arguments(o, **kwargs)
var_keywords = []
pass_by = []
for a in o.arguments:
if isinstance(a, Array) > 0 and a.type.intent.lower() == "in":
var_keywords += ['const ']
else:
var_keywords += ['']
if isinstance(a, Array) > 0:
pass_by += ['* restrict ']
elif isinstance(a.type.dtype, DerivedType):
pass_by += ['*']
elif a.type.pointer:
pass_by += ['*']
else:
pass_by += ['']
return pass_by, var_keywords

def _subroutine_declaration(self, o, **kwargs):
opt_extern = kwargs.get('extern', False)
declaration = [self.format_line('extern "C" {\n')] if opt_extern else []
declaration += super()._subroutine_declaration(o, **kwargs)
return declaration

def _subroutine_body(self, o, **kwargs):
body = super()._subroutine_body(o, **kwargs)
return body

def _subroutine_footer(self, o, **kwargs):
opt_extern = kwargs.get('extern', False)
footer = super()._subroutine_footer(o, **kwargs)
footer += [self.format_line('\n} // extern')] if opt_extern else []
return footer

# def visit_Subroutine(self, o, **kwargs):
# """
# Format as:
# ...imports...
# int <name>(<args>) {
# ...spec without imports and argument declarations...
# ...body...
# }
# """
# return super().visit_Subroutine(o, **kwargs)


def cppgen(ir, **kwargs):
"""
...
"""
return CppCodegen().visit(ir, **kwargs)
Loading

0 comments on commit 962528e

Please sign in to comment.