Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct symbols after derived type enrichment #450

Merged
merged 8 commits into from
Nov 29, 2024
4 changes: 4 additions & 0 deletions loki/expression/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def map_variable_symbol(self, expr, enclosing_prec, *args, **kwargs):

map_deferred_type_symbol = map_variable_symbol
map_procedure_symbol = map_variable_symbol
map_derived_type_symbol = map_variable_symbol

def map_meta_symbol(self, expr, enclosing_prec, *args, **kwargs):
return self.rec(expr._symbol, enclosing_prec, *args, **kwargs)
Expand Down Expand Up @@ -234,6 +235,7 @@ def map_variable_symbol(self, expr, *args, **kwargs):

map_deferred_type_symbol = map_variable_symbol
map_procedure_symbol = map_variable_symbol
map_derived_type_symbol = map_variable_symbol

def map_meta_symbol(self, expr, *args, **kwargs):
if not self.visit(expr):
Expand Down Expand Up @@ -611,6 +613,7 @@ def map_variable_symbol(self, expr, *args, **kwargs):

map_deferred_type_symbol = map_variable_symbol
map_procedure_symbol = map_variable_symbol
map_derived_type_symbol = map_variable_symbol

def map_meta_symbol(self, expr, *args, **kwargs):
symbol = self.rec(expr._symbol, *args, **kwargs)
Expand Down Expand Up @@ -823,6 +826,7 @@ def map_procedure_symbol(self, expr, *args, **kwargs):
return expr.clone(scope=kwargs['scope'])
return self.map_variable_symbol(expr, *args, **kwargs)


class DetachScopesMapper(LokiIdentityMapper):
"""
A Pymbolic expression mapper (i.e., a visitor for the expression tree)
Expand Down
39 changes: 34 additions & 5 deletions loki/expression/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# Mix-ins
'StrCompareMixin',
# Typed leaf nodes
'TypedSymbol', 'DeferredTypeSymbol', 'VariableSymbol', 'ProcedureSymbol',
'TypedSymbol', 'DeferredTypeSymbol', 'VariableSymbol', 'ProcedureSymbol', 'DerivedTypeSymbol',
'MetaSymbol', 'Scalar', 'Array', 'Variable',
# Non-typed leaf nodes
'FloatLiteral', 'IntLiteral', 'LogicLiteral', 'StringLiteral',
Expand Down Expand Up @@ -481,6 +481,34 @@ def __init__(self, name, scope=None, type=None, **kwargs):
mapper_method = intern('map_procedure_symbol')


class DerivedTypeSymbol(StrCompareMixin, TypedSymbol, _FunctionSymbol):
"""
Internal representation of a symbol that represents a named
derived type.

This is used to represent the derived type symbolically in
:any:`Import` statements and when defining derived types.

Parameters
----------
name : str
The name of the symbol.
scope : :any:`Scope`
The scope in which the symbol is declared.
type : optional
The type of that symbol. Defaults to :any:`BasicType.DEFERRED`.
"""

def __init__(self, name, scope=None, type=None, **kwargs):
# pylint: disable=redefined-builtin
assert type is None or isinstance(type.dtype, DerivedType)
if type is not None:
assert name.lower() == type.dtype.name.lower()
super().__init__(name=name, scope=scope, type=type, **kwargs)

mapper_method = intern('map_derived_type_symbol')


class MetaSymbol(StrCompareMixin, pmbl.AlgebraicLeaf):
"""
Base class for meta symbols to encapsulate a symbol node with optional
Expand Down Expand Up @@ -868,9 +896,8 @@ def __new__(cls, **kwargs):
return ProcedureSymbol(**kwargs)

if _type and isinstance(_type.dtype, DerivedType) and name.lower() == _type.dtype.name.lower():
# This is a constructor call (or a type imported in an ``IMPORT`` statement, in which
# case this is classified wrong...)
return ProcedureSymbol(**kwargs)
# This the name of a derived type, as found in USE import statements
return DerivedTypeSymbol(**kwargs)

if 'dimensions' in kwargs and kwargs['dimensions'] is None:
# Convenience: This way we can construct Scalar variables with `dimensions=None`
Expand Down Expand Up @@ -1315,7 +1342,9 @@ def __init__(self, function, parameters=None, kw_parameters=None, **kwargs):
# Unfortunately, have to accept MetaSymbol here for the time being as
# rescoping before injecting statement functions may create InlineCalls
# with Scalar/Variable function names.
assert isinstance(function, (ProcedureSymbol, DeferredTypeSymbol, MetaSymbol))
assert isinstance(function, (
ProcedureSymbol, DerivedTypeSymbol, DeferredTypeSymbol, MetaSymbol
))
parameters = parameters or ()
kw_parameters = kw_parameters or {}

Expand Down
40 changes: 20 additions & 20 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1780,6 +1780,26 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
# symbols in the spec part to make them coherent with the symbol table
spec = AttachScopes().visit(spec, scope=routine, recurse_to_declaration_attributes=True)

# To simplify things, we always declare the result-type of a function with
# a declaration in the spec as this can capture every possible situation.
# Therefore, if it has been declared as a prefix in the subroutine statement,
# we now have to inject a declaration instead. To ensure we do this in the
# right place in the spec to not violate the intrinsic order Fortran mandates,
# we search for the first occurence of any VariableDeclaration or
# ProcedureDeclaration and inject it before that one
if return_type is not None:
routine.symbol_attrs[routine.name] = return_type
return_var = sym.Variable(name=routine.name, scope=routine)
decl_source = self.get_source(subroutine_stmt, source=None)
return_var_decl = ir.VariableDeclaration(symbols=(return_var,), source=decl_source)

decls = FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration)).visit(spec)
if not decls:
# No other declarations: add it to the end
spec.append(return_var_decl)
else:
spec.insert(spec.body.index(decls[0]), return_var_decl)

# Now all declarations are well-defined and we can parse the member routines
if contains_ast is not None:
contains = self.visit(contains_ast, **kwargs)
Expand Down Expand Up @@ -1821,26 +1841,6 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
comment_map[node] = None
spec = Transformer(comment_map, invalidate_source=False).visit(spec)

# To simplify things, we always declare the result-type of a function with
# a declaration in the spec as this can capture every possible situation.
# Therefore, if it has been declared as a prefix in the subroutine statement,
# we now have to inject a declaration instead. To ensure we do this in the
# right place in the spec to not violate the intrinsic order Fortran mandates,
# we search for the first occurence of any VariableDeclaration or
# ProcedureDeclaration and inject it before that one
if return_type is not None:
routine.symbol_attrs[routine.name] = return_type
return_var = sym.Variable(name=routine.name, scope=routine)
decl_source = self.get_source(subroutine_stmt, source=None)
return_var_decl = ir.VariableDeclaration(symbols=(return_var,), source=decl_source)

decls = FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration)).visit(spec)
if not decls:
# No other declarations: add it to the end
spec.append(return_var_decl)
else:
spec.insert(spec.body.index(decls[0]), return_var_decl)

# Finally, call the subroutine constructor on the object again to register all
# bits and pieces in place and rescope all symbols
# pylint: disable=unnecessary-dunder-call
Expand Down
Loading
Loading