Skip to content

Commit

Permalink
Merge pull request #433 from ecmwf-ifs/naml-sanitise-refactoring
Browse files Browse the repository at this point in the history
Sanitise: New transformation sub-package and some refactoring
  • Loading branch information
mlange05 authored Nov 12, 2024
2 parents 15d4a3f + 2bdbe34 commit 358fb60
Show file tree
Hide file tree
Showing 15 changed files with 710 additions and 265 deletions.
4 changes: 2 additions & 2 deletions lint_rules/lint_rules/debug_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import operator as _op
from loki import (
FindNodes, CallStatement, Assignment, Scalar, RangeIndex, resolve_associates,
FindNodes, CallStatement, Assignment, Scalar, RangeIndex, do_resolve_associates,
simplify, Sum, Product, IntLiteral, as_tuple, SubstituteExpressions, Array,
symbolic_op, StringLiteral, is_constant, LogicLiteral, VariableDeclaration, flatten,
FindInlineCalls, Conditional, FindExpressions, Comparison
Expand Down Expand Up @@ -113,7 +113,7 @@ def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
max_indirections = config['max_indirections']

# first resolve associates
resolve_associates(subroutine)
do_resolve_associates(subroutine)

assign_map = {a.lhs: a.rhs for a in FindNodes(Assignment).visit(subroutine.body)}
decl_symbols = flatten([decl.symbols for decl in FindNodes(VariableDeclaration).visit(subroutine.spec)])
Expand Down
4 changes: 2 additions & 2 deletions loki/transformations/block_index_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from loki.expression import (
symbols as sym, Variable, Array, RangeIndex
)
from loki.transformations.sanitise import resolve_associates
from loki.transformations.sanitise import do_resolve_associates
from loki.transformations.utilities import (
recursive_expression_map_update, get_integer_variable,
get_loop_bounds, check_routine_sequential
Expand Down Expand Up @@ -242,7 +242,7 @@ def process_body(self, body, definitions, successors, targets, exclude_arrays):
def process_kernel(self, routine, item, successors, targets, exclude_arrays):

# Sanitize the subroutine
resolve_associates(routine)
do_resolve_associates(routine)
v_index = get_integer_variable(routine, name=self.horizontal.index)
SCCBaseTransformation.resolve_masked_stmts(routine, loop_variable=v_index)

Expand Down
16 changes: 9 additions & 7 deletions loki/transformations/inline/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from loki.logging import error
from loki.subroutine import Subroutine

from loki.transformations.sanitise import transform_sequence_association_append_map
from loki.transformations.sanitise import SequenceAssociationTransformer
from loki.transformations.utilities import (
single_variable_declaration, recursive_expression_map_update
)
Expand All @@ -37,9 +37,9 @@ def resolve_sequence_association_for_inlined_calls(routine, inline_internals, in
or in calls to procedures that have been marked with an inline pragma (if ``inline_marked = True``).
If both ``inline_internals`` and ``inline_marked`` are ``False``, no processing is done.
"""
call_map = {}
with pragmas_attached(routine, node_type=CallStatement):
for call in FindNodes(CallStatement).visit(routine.body):
class SequenceAssociationForInlineCallsTransformer(SequenceAssociationTransformer):

def visit_CallStatement(self, call, **kwargs):
condition = (
(inline_marked and is_loki_pragma(call.pragma, starts_with='inline')) or
(inline_internals and call.routine in routine.routines)
Expand All @@ -56,9 +56,11 @@ def resolve_sequence_association_for_inlined_calls(routine, inline_internals, in
"the source code of the procedure. " +
"If running in batch processing mode, please recheck Scheduler configuration."
)
transform_sequence_association_append_map(call_map, call)
if call_map:
routine.body = Transformer(call_map).visit(routine.body)

return super().visit_CallStatement(call, **kwargs)

with pragmas_attached(routine, node_type=CallStatement):
routine.body = SequenceAssociationForInlineCallsTransformer(inplace=True).visit(routine.body)


def map_call_to_procedure_body(call, caller, callee=None):
Expand Down
95 changes: 95 additions & 0 deletions loki/transformations/sanitise/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
"""
Sub-package with assorted utility :any:`Transformation` classes to
harmonize the look-and-feel of input source code.
"""
from functools import partial

from loki.batch import Transformation, Pipeline

from loki.transformations.sanitise.associates import * # noqa
from loki.transformations.sanitise.sequence_associations import * # noqa
from loki.transformations.sanitise.substitute import * # noqa


"""
:any:`Pipeline` class that provides combined access to the features
provided by the following :any:`Transformation` classes, in sequence:
1. :any:`SubstituteExpressionTransformation` - String-based generic
expression substitution.
2. :any:`AssociatesTransformation` - Full or partial resolution of
nested :any:`Associate` nodes, including optional merging of
independent association pairs.
3. :any:`SequenceAssociationTransformation` - Resolves sequence
association patterns in the call signature of :any:`CallStatement`
nodes.
Parameters
----------
substitute_expressions : bool
Flag to trigger or suppress expression substitution
expression_map : dict of str to str
A string-to-string map detailing the substitutions to apply.
substitute_spec : bool
Flag to trigger or suppress expression substitution in specs.
substitute_body : bool
Flag to trigger or suppress expression substitution in bodies.
resolve_associates : bool, default: True
Enable full or partial resolution of only :any:`Associate`
scopes.
merge_associates : bool, default: False
Enable merging :any:`Associate` to the outermost possible
scope in nested associate blocks.
start_depth : int, optional
Starting depth for partial resolution of :any:`Associate`
after merging.
max_parents : int, optional
Maximum number of parent symbols for valid selector to have
when merging :any:`Associate` nodes.
resolve_sequence_associations : bool
Flag to trigger or suppress resolution of sequence associations
"""
SanitisePipeline = partial(
Pipeline, classes=(
SubstituteExpressionTransformation,
AssociatesTransformation,
SequenceAssociationTransformation,
)
)


class SanitiseTransformation(Transformation):
"""
:any:`Transformation` object to apply several code sanitisation
steps when batch-processing large source trees via the :any:`Scheduler`.
Parameters
----------
resolve_associate_mappings : bool
Resolve ASSOCIATE mappings in body of processed subroutines; default: True.
resolve_sequence_association : bool
Replace scalars that are passed to array arguments with array
ranges; default: False.
"""

def __init__(
self, resolve_associate_mappings=True, resolve_sequence_association=False
):
self.resolve_associate_mappings = resolve_associate_mappings
self.resolve_sequence_association = resolve_sequence_association

def transform_subroutine(self, routine, **kwargs):

# Associates at the highest level, so they don't interfere
# with the sections we need to do for detecting subroutine calls
if self.resolve_associate_mappings:
do_resolve_associates(routine)

# Transform arrays passed with scalar syntax to array syntax
if self.resolve_sequence_association:
do_resolve_sequence_association(routine)
Original file line number Diff line number Diff line change
Expand Up @@ -13,53 +13,69 @@
"""

from loki.batch import Transformation
from loki.expression import Array, RangeIndex, LokiIdentityMapper
from loki.ir import nodes as ir, FindNodes, Transformer, NestedTransformer
from loki.expression import LokiIdentityMapper
from loki.ir import nodes as ir, Transformer, NestedTransformer
from loki.scope import SymbolTable
from loki.tools import as_tuple, dict_override
from loki.types import BasicType
from loki.tools import dict_override


__all__ = [
'SanitiseTransformation', 'resolve_associates', 'merge_associates',
'ResolveAssociatesTransformer', 'transform_sequence_association',
'transform_sequence_association_append_map'
'AssociatesTransformation', 'do_resolve_associates',
'ResolveAssociatesTransformer', 'do_merge_associates'
]


class SanitiseTransformation(Transformation):
class AssociatesTransformation(Transformation):
"""
:any:`Transformation` object to apply several code sanitisation
steps when batch-processing large source trees via the :any:`Scheduler`.
:any:`Transformation` object to apply code sanitisation steps
specific to :any:`Associate` nodes.
It allows merging in nested :any:`Associate` scopes to move
independent assocation pairs to the outermost scope, optionally
restricted by a number of ``max_parents`` symbols.
It also provides partial or full resolution of :any:`Associate`
nodes by replacing ``identifier`` symbols with the corresponding
``selector`` in the node's body.
Parameters
----------
resolve_associate_mappings : bool
Resolve ASSOCIATE mappings in body of processed subroutines; default: True.
resolve_sequence_association : bool
Replace scalars that are passed to array arguments with array
ranges; default: False.
resolve_associates : bool, default: True
Enable full or partial resolution of only :any:`Associate`
scopes.
merge_associates : bool, default: False
Enable merging :any:`Associate` to the outermost possible
scope in nested associate blocks.
start_depth : int, optional
Starting depth for partial resolution of :any:`Associate`
after merging.
max_parents : int, optional
Maximum number of parent symbols for valid selector to have
when merging :any:`Associate` nodes.
"""

def __init__(
self, resolve_associate_mappings=True, resolve_sequence_association=False
self, resolve_associates=True, merge_associates=False,
start_depth=0, max_parents=None
):
self.resolve_associate_mappings = resolve_associate_mappings
self.resolve_sequence_association = resolve_sequence_association
self.resolve_associates = resolve_associates
self.merge_associates = merge_associates

self.start_depth = start_depth
self.max_parents = max_parents

def transform_subroutine(self, routine, **kwargs):

# Associates at the highest level, so they don't interfere
# with the sections we need to do for detecting subroutine calls
if self.resolve_associate_mappings:
resolve_associates(routine)
# Merge associates first so that remainig ones can be resolved
if self.merge_associates:
do_merge_associates(routine, max_parents=self.max_parents)

# Transform arrays passed with scalar syntax to array syntax
if self.resolve_sequence_association:
transform_sequence_association(routine)
# Resolve remaining associates depending on start_depth
if self.resolve_associates:
do_resolve_associates(routine, start_depth=self.start_depth)


def resolve_associates(routine, start_depth=0):
def do_resolve_associates(routine, start_depth=0):
"""
Resolve :any:`Associate` mappings in the body of a given routine.
Expand Down Expand Up @@ -149,8 +165,8 @@ class ResolveAssociatesTransformer(Transformer):
:any:`Transformer` class to resolve :any:`Associate` nodes in IR trees.
This will replace each :any:`Associate` node with its own body,
where all `identifier` symbols have been replaced with the
corresponding `selector` expression defined in ``associations``.
where all ``identifier`` symbols have been replaced with the
corresponding ``selector`` expression defined in ``associations``.
Importantly, this :any:`Transformer` can also be applied over partial
bodies of :any:`Associate` bodies.
Expand Down Expand Up @@ -195,7 +211,7 @@ def visit_CallStatement(self, o, **kwargs):
return o._rebuild(arguments=arguments, kwarguments=kwarguments)


def merge_associates(routine, max_parents=None):
def do_merge_associates(routine, max_parents=None):
"""
Moves associate mappings in :any:`Associate` within a
:any:`Subroutine` to the outermost parent scope.
Expand Down Expand Up @@ -280,95 +296,3 @@ def visit_Associate(self, o, **kwargs):
# that moved associations get the correct defining scope
o._derive_local_symbol_types(parent_scope=o.parent)
return o


def check_if_scalar_syntax(arg, dummy):
"""
Check if an array argument, arg,
is passed to an array dummy argument, dummy,
using scalar syntax. i.e. arg(1,1) -> d(m,n)
Parameters
----------
arg: variable
dummy: variable
"""
if isinstance(arg, Array) and isinstance(dummy, Array):
if arg.dimensions:
if not any(isinstance(d, RangeIndex) for d in arg.dimensions):
return True
return False


def transform_sequence_association(routine):
"""
Housekeeping routine to replace scalar syntax when passing arrays as arguments
For example, a call like
.. code-block::
real :: a(m,n)
call myroutine(a(i,j))
where myroutine looks like
.. code-block::
subroutine myroutine(a)
real :: a(5)
end subroutine myroutine
should be changed to
.. code-block::
call myroutine(a(i:m,j)
Parameters
----------
routine : :any:`Subroutine`
The subroutine where calls will be changed
"""

#List calls in routine, but make sure we have the called routine definition
calls = (c for c in FindNodes(ir.CallStatement).visit(routine.body) if not c.procedure_type is BasicType.DEFERRED)
call_map = {}

# Check all calls and record changes to `call_map` if necessary.
for call in calls:
transform_sequence_association_append_map(call_map, call)

# Fix sequence association in all calls in one go.
if call_map:
routine.body = Transformer(call_map).visit(routine.body)

def transform_sequence_association_append_map(call_map, call):
"""
Check if `call` contains the sequence association pattern in one of the arguments,
and if so, add the necessary transform data to `call_map`.
"""
new_args = []
found_scalar = False
for dummy, arg in call.arg_map.items():
if check_if_scalar_syntax(arg, dummy):
found_scalar = True

n_dims = len(dummy.shape)
new_dims = []
for s, lower in zip(arg.shape[:n_dims], arg.dimensions[:n_dims]):

if isinstance(s, RangeIndex):
new_dims += [RangeIndex((lower, s.stop))]
else:
new_dims += [RangeIndex((lower, s))]

if len(arg.dimensions) > n_dims:
new_dims += arg.dimensions[len(dummy.shape):]
new_args += [arg.clone(dimensions=as_tuple(new_dims)),]
else:
new_args += [arg,]

if found_scalar:
call_map[call] = call.clone(arguments = as_tuple(new_args))
Loading

0 comments on commit 358fb60

Please sign in to comment.