Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid variable_map in TypedSymbol.get_derived_type_member and verify type information is derived correctly #285

Merged
merged 2 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions loki/expression/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,14 +361,10 @@ def get_derived_type_member(self, name_str):
"""
Resolve type-bound variables of arbitrary nested depth.
"""

name_parts = name_str.split('%', maxsplit=1)
if not (declared_var := self.variable_map.get(name_parts[0], None)):
declared_var = Variable(name=name_parts[0], scope=self.scope, parent=self)

declared_var = Variable(name=f'{self.name}%{name_parts[0]}', scope=self.scope, parent=self)
if len(name_parts) > 1:
declared_var = declared_var.get_derived_type_member(name_parts[1])

return declared_var.get_derived_type_member(name_parts[1]) # pylint:disable=no-member
return declared_var


Expand Down
75 changes: 75 additions & 0 deletions tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,3 +1570,78 @@ def test_typebound_resolution(expr):

assert var == expr
assert var.scope == scope


@pytest.mark.parametrize('frontend', available_frontends(
skip={OMNI: "OMNI fails on missing module"}
))
def test_typebound_resolution_type_info(frontend):
fcode = """
module typebound_resolution_type_info_mod
use some_mod, only: tt
implicit none
type t_a
logical :: a
end type t_a

type t_b
type(t_a) :: b_a
integer :: b
end type t_b

type t_c
type(t_b) :: c_b
real :: c
end type t_c
contains
subroutine sub ()
type(t_c) :: var_c
type(tt) :: var_tt
end subroutine sub
end module typebound_resolution_type_info_mod
""".strip()

module = Module.from_source(fcode, frontend=frontend)

sub = module['sub']
var_c = sub.variable_map['var_c']
var_tt = sub.variable_map['var_tt']

t_a = module['t_a']
t_b = module['t_b']

var_c_to_try = {
'c': BasicType.REAL,
'c_b': t_b.dtype,
'c_b%b': BasicType.INTEGER,
'c_b%b_a': t_a.dtype,
'c_b%b_a%a': BasicType.LOGICAL,
}

var_tt_to_try = {
'some': BasicType.DEFERRED,
'some%member': BasicType.DEFERRED
}

# Make sure none of the derived type members exist
# in the symbol table initially
for var_name in var_c_to_try:
assert f'var_c%{var_name}' not in sub.symbol_attrs

for var_name in var_tt_to_try:
assert f'var_tt%{var_name}' not in sub.symbol_attrs

# Create each derived type member and verify its type
for var_name, dtype in var_c_to_try.items():
var = var_c.get_derived_type_member(var_name)
assert var == f'var_c%{var_name}'
assert var.scope is sub
assert isinstance(var, symbols.Scalar)
assert var.type.dtype == dtype

for var_name, dtype in var_tt_to_try.items():
var = var_tt.get_derived_type_member(var_name)
assert var == f'var_tt%{var_name}'
assert var.scope is sub
assert isinstance(var, symbols.DeferredTypeSymbol)
assert var.type.dtype == dtype
Loading