Skip to content

Commit

Permalink
Angry linting gods
Browse files Browse the repository at this point in the history
  • Loading branch information
awnawab committed May 11, 2024
1 parent 8cf2cce commit cd8eed3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion loki/program_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def enrich(self, definitions, recurse=False):
if decl_map:
# DataType's are not stored directly in the SymbolTable, thus updating the
# imported symbols is not enough to update the type.dtype of expression nodes.
# Therefore we must also update the variable declaration with the updated symbol
# Therefore we must also update the variable declaration with the updated symbol
for decl, symbol_map in decl_map.items():
_symbols = SubstituteExpressions(symbol_map).visit(decl.symbols)
decl._update(symbols=as_tuple(_symbols))
Expand Down
11 changes: 6 additions & 5 deletions loki/transformations/block_index_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from loki.tools import as_tuple
from loki.types import SymbolAttributes, BasicType
from loki.expression import Variable, Array, RangeIndex, FindVariables, SubstituteExpressions
from loki.transformations import resolve_associates, recursive_expression_map_update
from loki.transformations.single_column import SCCBaseTransformation
from loki.transformations.sanitise import resolve_associates
from loki.transformations.utilities import recursive_expression_map_update
from loki.transformations.single_column.base import SCCBaseTransformation

__all__ = ['BlockViewToFieldViewTransformation', 'InjectBlockIndexTransformation']

Expand Down Expand Up @@ -193,7 +194,7 @@ def build_ydvars_global_gfl_ptr(self, var):
return var.clone(name=var.name.upper().replace('GFL_PTR', 'GFL_PTR_G'),
parent=parent, type=_type)

def process_body(self, body, symbol_map, definitions, successors, targets, exclude_arrays):
def process_body(self, body, definitions, successors, targets, exclude_arrays):

# build list of type-bound array access using the horizontal index
_vars = [var for var in FindVariables().visit(body)
Expand Down Expand Up @@ -242,7 +243,7 @@ def process_kernel(self, routine, item, successors, targets, exclude_arrays):
SCCBaseTransformation.resolve_vector_dimension(routine, loop_variable=v_index, bounds=bounds)

# for kernels we process the entire body
routine.body = self.process_body(routine.body, routine.symbol_map, item.trafo_data[self._key]['definitions'],
routine.body = self.process_body(routine.body, item.trafo_data[self._key]['definitions'],
successors, targets, exclude_arrays)


Expand Down Expand Up @@ -394,7 +395,7 @@ def process_body(self, body, block_index, targets, exclude_arrays):
# have been parsed
if getattr(var, 'shape', None):
decl_rank = len(var.shape)

if local_rank == decl_rank - 1:
dimensions = getattr(var, 'dimensions', None) or ((RangeIndex((None, None)),) * (decl_rank - 1))
vmap.update({var: var.clone(dimensions=dimensions + as_tuple(block_index))})
Expand Down

0 comments on commit cd8eed3

Please sign in to comment.