Skip to content

Commit

Permalink
Merge pull request #419 from ecmwf-ifs/naml-expression-mapper-tests
Browse files Browse the repository at this point in the history
Expression: Expression cloning and mapper tests
  • Loading branch information
reuterbal authored Nov 7, 2024
2 parents 13aba08 + 193742c commit f731b0b
Show file tree
Hide file tree
Showing 12 changed files with 262 additions and 69 deletions.
67 changes: 23 additions & 44 deletions loki/expression/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Mappers for traversing and transforming the
:ref:`internal_representation:Expression tree`.
"""
from copy import deepcopy

import re
from itertools import zip_longest
import pymbolic.primitives as pmbl
Expand Down Expand Up @@ -515,37 +515,23 @@ class LokiIdentityMapper(IdentityMapper):
This can serve as basis for any transformation mappers
that apply changes to the expression tree. Expression nodes that
are unchanged are returned as is.
Parameters
----------
invalidate_source : bool, optional
By default the :attr:`source` property of nodes is discarded
when rebuilding the node, setting this to `False` allows to
retain that information
"""

def __init__(self, invalidate_source=True):
super().__init__()
self.invalidate_source = invalidate_source
@staticmethod
def _rebuild(expr):
""" Utility to safely rebuild any symbol """
if hasattr(expr, 'clone'):
return expr.clone()

# Re-create symbol Pymbolic-style
cargs = dict(zip(expr.init_arg_names, expr.__getinitargs__()))
return type(expr)(**cargs)

def __call__(self, expr, *args, **kwargs):
if expr is None:
return None
kwargs.setdefault('recurse_to_declaration_attributes', False)
new_expr = super().__call__(expr, *args, **kwargs)
if getattr(expr, 'source', None):
if isinstance(new_expr, tuple):
for e in new_expr:
if self.invalidate_source:
e.source = None
else:
e.source = deepcopy(expr.source)
else:
if self.invalidate_source:
new_expr.source = None
else:
new_expr.source = deepcopy(expr.source)
return new_expr
return super().__call__(expr, *args, **kwargs)

rec = __call__

Expand All @@ -557,7 +543,7 @@ def __call__(self, expr, *args, **kwargs):
def map_int_literal(self, expr, *args, **kwargs):
kind = self.rec(expr.kind, *args, **kwargs)
if kind is expr.kind:
return expr
return self._rebuild(expr)
return expr.__class__(expr.value, kind=kind)

map_float_literal = map_int_literal
Expand Down Expand Up @@ -615,11 +601,11 @@ def map_variable_symbol(self, expr, *args, **kwargs):
parent = self.rec(expr.parent, *args, **kwargs)
if expr.scope is None:
if parent is expr.parent and not is_type_changed:
return expr
return self._rebuild(expr)
return expr.clone(parent=parent, type=new_type)

if parent is expr.parent:
return expr
return self._rebuild(expr)
return expr.clone(parent=parent)

map_deferred_type_symbol = map_variable_symbol
Expand All @@ -631,7 +617,7 @@ def map_meta_symbol(self, expr, *args, **kwargs):
# but with no rebuilt it may return VariableSymbol. Therefore we need to return the
# original expression if the underlying symbol is unchanged
if symbol is expr._symbol:
return expr
return self._rebuild(expr)
return symbol

map_scalar = map_meta_symbol
Expand Down Expand Up @@ -659,7 +645,7 @@ def map_array(self, expr, *args, **kwargs):
if (getattr(symbol, 'symbol', symbol) is expr.symbol and
all(d is orig_d for d, orig_d in zip_longest(dimensions or (), expr.dimensions or ())) and
all(d is orig_d for d, orig_d in zip_longest(shape or (), symbol.type.shape or ()))):
return expr
return self._rebuild(expr)
return symbol.clone(dimensions=dimensions, type=symbol.type.clone(shape=shape), parent=parent)

def map_array_subscript(self, expr, *args, **kwargs):
Expand All @@ -678,14 +664,14 @@ def map_cast(self, expr, *args, **kwargs):
kind = self.rec(expr.kind, *args, **kwargs)
if (function is expr.function and kind is expr.kind and
all(p is orig_p for p, orig_p in zip_longest(parameters, expr.parameters))):
return expr
return self._rebuild(expr)
return expr.__class__(function, parameters, kind=kind)

def map_sum(self, expr, *args, **kwargs):
# Need to re-implement to avoid application of flattened_sum/flattened_product
children = self.rec(expr.children, *args, **kwargs)
if all(c is orig_c for c, orig_c in zip_longest(children, expr.children)):
return expr
return self._rebuild(expr)
return expr.__class__(children)

def map_quotient(self, expr, *args, **kwargs):
Expand All @@ -707,7 +693,7 @@ def map_literal_list(self, expr, *args, **kwargs):
values = tuple(v if isinstance(v, str) else self.rec(v, *args, **kwargs)
for v in expr.elements)
if all(v is orig_v for v, orig_v in zip_longest(values, expr.elements)):
return expr
return self._rebuild(expr)
return expr.__class__(values, dtype=expr.dtype)

def map_inline_do(self, expr, *args, **kwargs):
Expand Down Expand Up @@ -750,15 +736,11 @@ class SubstituteExpressionsMapper(LokiIdentityMapper):
----------
expr_map : dict
Expression mapping to apply to the expression tree.
invalidate_source : bool, optional
By default the :attr:`source` property of nodes is discarded
when rebuilding the node, setting this to `False` allows to
retain that information
"""
# pylint: disable=abstract-method

def __init__(self, expr_map, invalidate_source=True):
super().__init__(invalidate_source=invalidate_source)
def __init__(self, expr_map):
super().__init__()

self.expr_map = expr_map
for expr in self.expr_map.keys():
Expand All @@ -770,7 +752,7 @@ def map_from_expr_map(self, expr, *args, **kwargs):
otherwise continue tree traversal
"""
if expr in self.expr_map:
return self.expr_map[expr]
return self._rebuild(self.expr_map[expr])
map_fn = getattr(super(), expr.mapper_method)
return map_fn(expr, *args, **kwargs)

Expand All @@ -789,7 +771,7 @@ class AttachScopesMapper(LokiIdentityMapper):
"""

def __init__(self, fail=False):
super().__init__(invalidate_source=False)
super().__init__()
self.fail = fail

def _update_symbol_scope(self, expr, scope):
Expand Down Expand Up @@ -847,9 +829,6 @@ class DetachScopesMapper(LokiIdentityMapper):
analysis passes.
"""

def __init__(self):
super().__init__(invalidate_source=False)

def map_variable_symbol(self, expr, *args, **kwargs):
new_expr = super().map_variable_symbol(expr, *args, **kwargs)
new_expr = new_expr.clone(scope=None)
Expand Down
4 changes: 2 additions & 2 deletions loki/expression/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,8 @@ class SimplifyMapper(LokiIdentityMapper):
"""
# pylint: disable=abstract-method

def __init__(self, enabled_simplifications=Simplification.ALL, invalidate_source=True):
super().__init__(invalidate_source=invalidate_source)
def __init__(self, enabled_simplifications=Simplification.ALL):
super().__init__()

self.enabled_simplifications = enabled_simplifications

Expand Down
11 changes: 11 additions & 0 deletions loki/expression/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,8 @@ class LiteralList(pmbl.AlgebraicLeaf):
A list of constant literals, e.g., as used in Array Initialization Lists.
"""

init_arg_names = ('values', 'dtype')

def __init__(self, values, dtype=None, **kwargs):
self.elements = values
self.dtype = dtype
Expand Down Expand Up @@ -1424,17 +1426,26 @@ class Cast(pmbl.Call):
Internal representation of a data type cast.
"""

init_arg_names = ('name', 'expression', 'kind')

def __init__(self, name, expression, kind=None, **kwargs):
assert kind is None or isinstance(kind, pmbl.Expression)
self.kind = kind
super().__init__(pmbl.make_variable(name), as_tuple(expression), **kwargs)

def __getinitargs__(self):
return (self.name, self.expression, self.kind)

mapper_method = intern('map_cast')

@property
def name(self):
return self.function.name

@property
def expression(self):
return self.parameters


class Range(StrCompareMixin, pmbl.Slice):
"""
Expand Down
1 change: 0 additions & 1 deletion loki/expression/tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import numpy as np

import pymbolic.primitives as pmbl
import pymbolic.mapper as pmbl_mapper

from loki import (
Sourcefile, Subroutine, Module, Scope, BasicType,
Expand Down
127 changes: 127 additions & 0 deletions loki/expression/tests/test_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Subroutine, Scope
from loki.expression import symbols as sym, parse_expr
from loki.expression.mappers import (
ExpressionRetriever, LokiIdentityMapper, SubstituteExpressionsMapper
)
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes


@pytest.mark.parametrize('frontend', available_frontends())
def test_expression_retriever(frontend):
""" Test for :any:`ExpressionRetriever` (a :any:`LokiWalkMapper`) """

fcode = """
subroutine test_expr_retriever(n, a, b, c)
integer, intent(inout) :: n, a, b(n), c
a = 5 * a + 4 * b(c) + a
end subroutine test_expr_retriever
"""
routine = Subroutine.from_source(fcode, frontend=frontend)
expr = FindNodes(ir.Assignment).visit(routine.body)[0].rhs

def q_symbol(n):
return isinstance(n, sym.TypedSymbol)

def q_array(n):
return isinstance(n, sym.Array)

def q_scalar(n):
return isinstance(n, sym.Scalar)

def q_deferred(n):
return isinstance(n, sym.DeferredTypeSymbol)

def q_literal(n):
return isinstance(n, sym.IntLiteral)

assert ExpressionRetriever(q_symbol).retrieve(expr) == ['a', 'b', 'c', 'a']
assert ExpressionRetriever(q_array).retrieve(expr) == ['b(c)']
assert ExpressionRetriever(q_scalar).retrieve(expr) == ['a', 'c', 'a']
assert ExpressionRetriever(q_literal).retrieve(expr) == [5, 4]

scope = Scope()
expr = parse_expr('5 * a + 4 * b(c) + a', scope=scope)

assert ExpressionRetriever(q_symbol).retrieve(expr) == ['a', 'b', 'c', 'a']
assert ExpressionRetriever(q_array).retrieve(expr) == ['b(c)']
# Cannot determine Scalar without declarations, so check for deferred
assert ExpressionRetriever(q_deferred).retrieve(expr) == ['a', 'c', 'a']
assert ExpressionRetriever(q_literal).retrieve(expr) == [5, 4]


@pytest.mark.parametrize('frontend', available_frontends())
def test_identity_mapper(frontend):
"""
Test for :any:`LokiIdentityMapper`, in particular deep-copying
expression nodes.
"""

fcode = """
subroutine test_expr_retriever(n, a, b, c)
integer, intent(inout) :: n, a, b(n), c
a = 5 * a + 4 * b(c) + a
end subroutine test_expr_retriever
"""
routine = Subroutine.from_source(fcode, frontend=frontend)
expr = FindNodes(ir.Assignment).visit(routine.body)[0].rhs

# Run the identity mapper over the expression
new_expr = LokiIdentityMapper()(expr)

# Check that symbols and literals are equivalent, but distinct objects!
get_symbols = ExpressionRetriever(lambda e: isinstance(e, sym.TypedSymbol)).retrieve
get_literals = ExpressionRetriever(lambda e: isinstance(e, sym.IntLiteral)).retrieve

for old, new in zip(get_symbols(expr), get_symbols(new_expr)):
assert old == new
assert not old is new

for old, new in zip(get_literals(expr), get_literals(new_expr)):
assert old == new
assert not old is new


@pytest.mark.parametrize('frontend', available_frontends())
def test_substitute_expression_mapper(frontend):
"""
Test for :any:`SubstituteExpressionsMapper`.
"""

fcode = """
subroutine test_expr_retriever(n, a, b, c, d)
integer, intent(inout) :: n, a, b(n), c, d
a = 5 * a + 4 * b(c) + a
end subroutine test_expr_retriever
"""
routine = Subroutine.from_source(fcode, frontend=frontend)
expr = FindNodes(ir.Assignment).visit(routine.body)[0].rhs

retriever = ExpressionRetriever(lambda e: isinstance(e, sym.TypedSymbol))
symbols = retriever.retrieve(expr)
assert symbols == ['a', 'b', 'c', 'a']
assert symbols[0] == symbols[3]
assert not symbols[0] is symbols[3]
a = symbols[0]
d = routine.variable_map['d']

new_expr = SubstituteExpressionsMapper(expr_map={a: d})(expr)

assert new_expr == '5*d + 4*b(c) + d'
new_symbols = retriever.retrieve(new_expr)
assert new_symbols == ['d', 'b', 'c', 'd']
assert new_symbols[0] == new_symbols[3]
# Ensure multiple inserted symbols are still unique
assert not new_symbols[0] is new_symbols[3]
2 changes: 1 addition & 1 deletion loki/expression/tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from loki import Subroutine, Module, Scope
from loki.expression import symbols as sym, parse_expr
from loki.frontend import (
available_frontends, OMNI, HAVE_FP, parse_fparser_expression
available_frontends, HAVE_FP, parse_fparser_expression
)
from loki.ir import FindVariables

Expand Down
Loading

0 comments on commit f731b0b

Please sign in to comment.