Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve parse_expr and use in process_dimension_pragmas #292

Merged
merged 11 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion loki/expression/expr_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
"""
from pymbolic.primitives import Expression

from loki.ir import Node, Visitor, Transformer
from loki.ir.nodes import Node
from loki.ir.visitor import Visitor
from loki.ir.transformer import Transformer
from loki.tools import flatten, as_tuple
from loki.expression.mappers import (
SubstituteExpressionsMapper, ExpressionRetriever, AttachScopesMapper
Expand Down
115 changes: 109 additions & 6 deletions loki/expression/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
import re
import math
import pytools.lex
import numpy as np
from pymbolic.parser import Parser as ParserBase # , FinalizedTuple
from pymbolic.mapper import Mapper
import pymbolic.primitives as pmbl
from pymbolic.mapper.evaluator import EvaluationMapper
from pymbolic.parser import (
_openpar, _closepar, _minus, FinalizedTuple, _PREC_UNARY,
_PREC_TIMES, _PREC_PLUS, _times, _plus
_PREC_TIMES, _PREC_PLUS, _PREC_CALL, _times, _plus
)
try:
from fparser.two.Fortran2003 import Intrinsic_Name
Expand Down Expand Up @@ -109,14 +110,19 @@ def map_meta_symbol(self, expr, *args, **kwargs):
map_array = map_meta_symbol

def map_slice(self, expr, *args, **kwargs):
return sym.RangeIndex(tuple(self.rec(child, *args, **kwargs) for child in expr.children))
children = tuple(self.rec(child, *args, **kwargs) if child is not None else child for child in expr.children)
if len(children) == 1 and children[0] is None:
# this corresponds to ':' (sym.RangeIndex((None, None)))
children = (None, None)
return sym.RangeIndex(children)

map_range = map_slice
map_range_index = map_slice
map_loop_range = map_slice

def map_variable(self, expr, *args, **kwargs):
return sym.Variable(name=expr.name)
parent = kwargs.pop('parent', None)
return sym.Variable(name=expr.name, parent=parent)

def map_algebraic_leaf(self, expr, *args, **kwargs):
if str(expr).isnumeric():
Expand All @@ -127,7 +133,12 @@ def map_algebraic_leaf(self, expr, *args, **kwargs):
if expr.function.name.upper() in FORTRAN_INTRINSIC_PROCEDURES:
return sym.InlineCall(function=sym.Variable(name=expr.function.name),
parameters=tuple(self.rec(param, *args, **kwargs) for param in expr.parameters))
return sym.Variable(name=expr.function.name,
parent = kwargs.pop('parent', None)
dimensions = tuple(self.rec(param, *args, **kwargs) for param in expr.parameters)
if not dimensions:
return sym.InlineCall(function=sym.Variable(name=expr.function.name, parent=parent),
parameters=dimensions)
return sym.Variable(name=expr.function.name, parent=parent,
dimensions=tuple(self.rec(param, *args, **kwargs) for param in expr.parameters))
try:
return self.map_variable(expr, *args, **kwargs)
Expand All @@ -151,6 +162,16 @@ def map_tuple(self, expr, *args, **kwargs):
def map_list(self, expr, *args, **kwargs):
return sym.LiteralList([self.rec(elem, *args, **kwargs) for elem in expr])

def map_remainder(self, expr, *args, **kwargs):
# this should never happen as '%' is overwritten to represent derived types
raise NotImplementedError

def map_lookup(self, expr, *args, **kwargs):
# construct derived type(s) variables
parent = kwargs.pop('parent', None)
parent = self.rec(expr.aggregate, parent=parent)
return self.rec(expr.name, parent=parent)


class LokiEvaluationMapper(EvaluationMapper):
"""
Expand All @@ -163,6 +184,16 @@ class LokiEvaluationMapper(EvaluationMapper):
Raise exception for unknown symbols/expressions (default: `False`).
"""

@staticmethod
def case_insensitive_getattr(obj, attr):
"""
Case-insensitive version of `getattr`.
"""
for elem in dir(obj):
if elem.lower() == attr.lower():
return getattr(obj, elem)
return getattr(obj, attr)

def __init__(self, strict=False, **kwargs):
self.strict = strict
super().__init__(**kwargs)
Expand All @@ -183,6 +214,15 @@ def map_variable(self, expr):
return super().map_variable(expr)
return expr

@staticmethod
def _evaluate_array(arr, dims):
"""
Evaluate arrays by converting to numpy array and
adapting the dimensions corresponding to the different
starting index.
"""
return np.array(arr, order='F').item(*[dim-1 for dim in dims])

def map_call(self, expr):
if expr.function.name.lower() == 'min':
return min(self.rec(par) for par in expr.parameters)
Expand All @@ -201,8 +241,55 @@ def map_call(self, expr):
return math.sqrt(float([self.rec(par) for par in expr.parameters][0]))
if expr.function.name.lower() == 'exp':
return math.exp(float([self.rec(par) for par in expr.parameters][0]))
if expr.function.name in self.context and not callable(self.context[expr.function.name]):
return self._evaluate_array(self.context[expr.function.name],
[self.rec(par) for par in expr.parameters])
return super().map_call(expr)

def map_call_with_kwargs(self, expr):
args = [self.rec(par) for par in expr.parameters]
kwargs = {
k: self.rec(v)
for k, v in expr.kw_parameters.items()}
kwargs = CaseInsensitiveDict(kwargs)
return self.rec(expr.function)(*args, **kwargs)

def map_lookup(self, expr):

def rec_lookup(expr, obj, name):
return expr.name, self.case_insensitive_getattr(obj, name)

try:
current_expr = expr
obj = self.rec(expr.aggregate)
while isinstance(current_expr.name, pmbl.Lookup):
current_expr, obj = rec_lookup(current_expr, obj, current_expr.name.aggregate.name)
if isinstance(current_expr.name, pmbl.Variable):
_, obj = rec_lookup(current_expr, obj, current_expr.name.name)
return obj
if isinstance(current_expr.name, pmbl.Call):
name = current_expr.name.function.name
_, obj = rec_lookup(current_expr, obj, name)
if callable(obj):
return obj(*[self.rec(par) for par in current_expr.name.parameters])
return self._evaluate_array(obj, [self.rec(par) for par in current_expr.name.parameters])
if isinstance(current_expr.name, pmbl.CallWithKwargs):
name = current_expr.name.function.name
_, obj = rec_lookup(current_expr, obj, name)
args = [self.rec(par) for par in current_expr.name.parameters]
kwargs = CaseInsensitiveDict(
(k, self.rec(v))
for k, v in current_expr.name.kw_parameters.items()
)
return obj(*args, **kwargs)
except Exception as e:
reuterbal marked this conversation as resolved.
Show resolved Hide resolved
if self.strict:
raise e
return expr
if self.strict:
raise NotImplementedError
Comment on lines +289 to +290
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand the control-flow here. What's not implemented when we run through cleanly without an exception?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can discuss about that. However, I thought if strict = True one wants to have an evaluated expression which in this case is not true as it would return the unchanged expression itself?! What do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, you're right of course. I missed the fact that we won't reach this point if it's one of the known "lookup" nodes. Thanks for clarifying!

return expr


class ExpressionParser(ParserBase):
"""
Expand Down Expand Up @@ -273,6 +360,7 @@ class ExpressionParser(ParserBase):
_f_string = intern("f_string")
_f_openbracket = intern("openbracket")
_f_closebracket = intern("closebracket")
_f_derived_type = intern("dot")

lex_table = [
(_f_true, pytools.lex.RE(r"\.true\.", re.IGNORECASE)),
Expand All @@ -292,6 +380,7 @@ class ExpressionParser(ParserBase):
pytools.lex.RE(r"\'.*\'", re.IGNORECASE))),
(_f_openbracket, pytools.lex.RE(r"\(/")),
(_f_closebracket, pytools.lex.RE(r"/\)")),
(_f_derived_type, pytools.lex.RE(r"\%")),
] + ParserBase.lex_table
"""
Extend :any:`pymbolic.parser.Parser.lex_table` to accomodate for Fortran specifix syntax/expressions.
Expand Down Expand Up @@ -357,7 +446,12 @@ def parse_prefix(self, pstate):
def parse_postfix(self, pstate, min_precedence, left_exp):

did_something = False
if pstate.is_next(_times) and _PREC_TIMES > min_precedence:
if pstate.is_next(self._f_derived_type) and _PREC_CALL > min_precedence:
pstate.advance()
right_exp = self.parse_expression(pstate, _PREC_PLUS)
left_exp = pmbl.Lookup(left_exp, right_exp)
did_something = True
elif pstate.is_next(_times) and _PREC_TIMES > min_precedence:
pstate.advance()
right_exp = self.parse_expression(pstate, _PREC_PLUS)
# NECESSARY to ensure correct ordering!
Expand Down Expand Up @@ -433,6 +527,15 @@ def __call__(self, expr_str, scope=None, evaluate=False, strict=False, context=N
ir = PymbolicMapper()(result)
return AttachScopes().visit(ir, scope=scope or Scope())

def parse_float(self, s):
"""
Parse float literals.

Do not cast to float via 'float()' in order to keep the original
notation, e.g., do not convert 1E-3 to 0.003.
"""
return sym.FloatLiteral(value=s.replace("d", "e").replace("D", "e"))

def parse_f_float(self, s):
"""
Parse "Fortran-style" float literals.
Expand All @@ -441,7 +544,7 @@ def parse_f_float(self, s):
"""
stripped = s.split('_', 1)
if len(stripped) == 2:
return sym.Literal(value=self.parse_float(stripped[0]), kind=sym.Variable(name=stripped[1].lower()))
return sym.FloatLiteral(value=self.parse_float(stripped[0]), kind=sym.Variable(name=stripped[1].lower()))
return self.parse_float(stripped[0])

def parse_f_int(self, s):
Expand Down
Loading
Loading