Skip to content

Commit

Permalink
zcbor.py: Some refactoring preparing for unordered maps
Browse files Browse the repository at this point in the history
Refactor checking for map elements.
Add the top_level modifier on member names.
Extend the ptr_result modifier so it can be used for unordered maps.
Some cleanup.

Signed-off-by: Øyvind Rønningstad <[email protected]>
  • Loading branch information
oyvindronningstad committed Nov 13, 2024
1 parent cf32d11 commit 1f1db89
Showing 1 changed file with 59 additions and 44 deletions.
103 changes: 59 additions & 44 deletions zcbor/zcbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def __init__(self, default_max_qty, my_types, my_control_groups, base_name=None,
# Key element. Only for children of "MAP" elements. self.key is of the
# same class as self.
self.key = None
self.is_key = False
# The element specified via.cbor or.cborseq(only for byte
# strings).self.cbor is of the same class as self.
self.cbor = None
Expand Down Expand Up @@ -353,7 +354,7 @@ def generate_base_name(self):
if self.type == "TSTR" and self.value is not None else None)
# Name an integer by its expected value:
or (f"{self.type.lower()}{abs(self.value)}"
if self.type in ["INT", "UINT", "NINT"] and self.value is not None else None)
if self.type in ["UINT", "NINT"] and self.value is not None else None)
# Name a type by its type name
or (next((key for key, value in self.my_types.items() if value == self), None))
# Name a control group by its name
Expand Down Expand Up @@ -510,6 +511,7 @@ def flatten(self, allow_multi=False):
self.value[0].label = self.label
if not self.value[0].key:
self.value[0].key = self.key
self.value[0].is_key = self.is_key
self.value[0].tags.extend(self.tags)
return self.value
elif allow_multi and self.type in ["GROUP"] and self.min_qty == 1 and self.max_qty == 1:
Expand Down Expand Up @@ -699,6 +701,7 @@ def set_key(self, key):
if key.type == "GROUP":
raise TypeError("A key cannot be a group because it might represent more than 1 type.")
self.key = key
key.is_key = True

def set_key_or_label(self, key_or_label):
"""Set the self.label OR self.key of this element.
Expand Down Expand Up @@ -939,16 +942,26 @@ def get_value(self, instr):
# Return the unparsed part of the string.
return instr.strip()

def elem_has_key(self):
def has_key(self):
"""For checking whether this element has a key (i.e. that it is a valid "MAP" child)
This must have some recursion since CDDL allows the key to be hidden
behind layers of indirection.
"""
return self.key is not None\
or (self.type == "OTHER" and self.my_types[self.value].elem_has_key())\
ret = self.key is not None\
or (self.type == "OTHER" and self.my_types[self.value].has_key())\
or (self.type in ["GROUP", "UNION"]
and (self.value and all(child.elem_has_key() for child in self.value)))
and (self.value and all(child.has_key() for child in self.value)))
return ret

def is_valid_map_elem(self):
"""For checking whether this element meets the conditions for being a valid map element.
This can be overridden by subclasses to further validate keys.
"""
if not self.has_key():
return False, f"Missing map key"
return True, ""

def post_validate(self):
"""Function for performing validations that must be done after all parsing is complete.
Expand All @@ -957,14 +970,12 @@ def post_validate(self):
"""
# Validation of this element.
if self.type in ["LIST", "MAP"]:
none_keys = [child for child in self.value if not child.elem_has_key()]
child_keys = [child for child in self.value if child not in none_keys]
if self.type == "MAP" and none_keys:
invalid_elems = [child for child in self.value if not child.is_valid_map_elem()[0]]
if self.type == "MAP" and invalid_elems:
raise TypeError(
"Map member(s) must have key: " + str(none_keys) + " pointing to "
+ str(
[self.my_types[elem.value] for elem in none_keys
if elem.type == "OTHER"]))
"Map member(s) are invalid:\n" + '\n'.join(
[f"{str(c)}: {c.is_valid_map_elem()[1]}" for c in invalid_elems]))
child_keys = [child for child in self.value if child not in invalid_elems]
if self.type == "LIST" and child_keys:
raise TypeError(
str(self) + linesep
Expand Down Expand Up @@ -1064,10 +1075,6 @@ def skip_condition(self):
return True
if self.type in ["LIST", "MAP", "GROUP"]:
return not self.repeated_multi_var_condition()
if self.type == "OTHER":
return ((not self.repeated_multi_var_condition())
and (not self.multi_var_condition())
and (self.single_func_impl_condition() or self in self.my_types.values()))
return False

def set_skipped(self, skipped):
Expand Down Expand Up @@ -1151,12 +1158,14 @@ def var_access(self):
return "NULL"
return self.access_append()

def val_access(self):
def val_access(self, top_level=False):
""""Path" to access this element's actual value variable."""
if self.is_unambiguous_repeated():
ret = "NULL"
elif self.skip_condition() or self.is_delegated_type():
ret = self.var_access()
elif top_level and not (self.type_def_condition() or self.repeated_type_def_condition()):
ret = self.var_access()
else:
ret = self.access_append(self.var_name())
return ret
Expand Down Expand Up @@ -2171,11 +2180,11 @@ def type_def(self):
if self.repeated_type_def_condition():
type_def_list = self.single_var_type(full=False)
if type_def_list:
ret_val.extend([(self.single_var_type(full=False), self.repeated_type_name())])
ret_val.extend([(type_def_list, self.repeated_type_name())])
if self.type_def_condition():
type_def_list = self.single_var_type()
if type_def_list:
ret_val.extend([(self.single_var_type(), self.type_name())])
ret_val.extend([(type_def_list, self.type_name())])
return ret_val

def type_def_bits(self):
Expand Down Expand Up @@ -2268,7 +2277,7 @@ def single_func_prim(self, access, union_int=None, ptr_result=False):
return (None, None)

if self.type == "OTHER":
return self.my_types[self.value].single_func(access, union_int)
return self.my_types[self.value].single_func(access, union_int, ptr_result=ptr_result)

func_name = self.single_func_prim_name(union_int, ptr_result=ptr_result)
if func_name is None:
Expand All @@ -2288,12 +2297,13 @@ def single_func_prim(self, access, union_int=None, ptr_result=False):

return (func_name, arg)

def single_func(self, access=None, union_int=None):
def single_func(self, access=None, union_int=None, ptr_result=False):
"""Return the function name and arguments to call to encode/decode this element."""
if self.single_func_impl_condition():
return (self.xcode_func_name(), deref_if_not_null(access or self.var_access()))
else:
return self.single_func_prim(access or self.val_access(), union_int)
return self.single_func_prim(access or self.val_access(), union_int,
ptr_result=ptr_result)

def repeated_single_func(self, ptr_result=False):
"""Return the function name and arguments to call to encode/decode the repeated
Expand All @@ -2308,6 +2318,7 @@ def has_backup(self):
return (self.cbor_var_condition() or self.type in ["LIST", "MAP", "UNION"])

def num_backups(self):
"""Calculate the number of state var backups needed for this element and all descendants."""
total = 0
if self.key:
total += self.key.num_backups()
Expand Down Expand Up @@ -2342,9 +2353,9 @@ def depends_on(self):

return max(ret_vals)

def xcode_single_func_prim(self, union_int=None):
def xcode_single_func_prim(self, union_int=None, top_level=False):
"""Make a string from the list returned by single_func_prim()"""
return xcode_statement(*self.single_func_prim(self.val_access(), union_int))
return xcode_statement(*self.single_func_prim(self.val_access(top_level), union_int))

def list_counts(self):
"""Recursively sum the total minimum and maximum element count for this element."""
Expand Down Expand Up @@ -2437,10 +2448,9 @@ def xcode_union(self):
and self.value[i - 1].simple_func_condition()):
child_values[i] = f"(zcbor_union_elem_code(state) && {child_values[i]})"

return "(%s && (int_res = (%s), %s, int_res))" \
% ("zcbor_union_start_code(state)",
f"{newl_ind}|| ".join(child_values),
"zcbor_union_end_code(state)")
child_code = f"{newl_ind}|| ".join(child_values)
return f"(zcbor_union_start_code(state) "\
+ f"&& (int_res = ({child_code}), zcbor_union_end_code(state), int_res))"
else:
return ternary_if_chain(
self.choice_var_access(),
Expand Down Expand Up @@ -2531,29 +2541,32 @@ def range_checks(self, access):

return range_checks

def repeated_xcode(self, union_int=None):
def repeated_xcode(self, union_int=None, top_level=False):
"""Return the full code needed to encode/decode this element.
Including children, key and cbor, excluding repetitions.
"""
val_union_int = union_int if not self.key else None # In maps, only pass union_int to key.
range_checks = self.range_checks(self.val_access())
range_checks = self.range_checks(self.val_access(top_level))

def do_xcode_single_func_prim(inner_union_int=None):
return self.xcode_single_func_prim(union_int=inner_union_int, top_level=top_level)
xcoder = {
"INT": self.xcode_single_func_prim,
"UINT": lambda: self.xcode_single_func_prim(val_union_int),
"NINT": lambda: self.xcode_single_func_prim(val_union_int),
"FLOAT": self.xcode_single_func_prim,
"INT": do_xcode_single_func_prim,
"UINT": lambda: do_xcode_single_func_prim(val_union_int),
"NINT": lambda: do_xcode_single_func_prim(val_union_int),
"FLOAT": do_xcode_single_func_prim,
"BSTR": self.xcode_bstr,
"TSTR": self.xcode_single_func_prim,
"BOOL": self.xcode_single_func_prim,
"NIL": self.xcode_single_func_prim,
"UNDEF": self.xcode_single_func_prim,
"ANY": self.xcode_single_func_prim,
"TSTR": do_xcode_single_func_prim,
"BOOL": do_xcode_single_func_prim,
"NIL": do_xcode_single_func_prim,
"UNDEF": do_xcode_single_func_prim,
"ANY": do_xcode_single_func_prim,
"LIST": self.xcode_list,
"MAP": self.xcode_list,
"GROUP": lambda: self.xcode_group(val_union_int),
"UNION": self.xcode_union,
"OTHER": lambda: self.xcode_single_func_prim(val_union_int),
"OTHER": lambda: do_xcode_single_func_prim(val_union_int),
}[self.type]
xcoders = []
if self.key:
Expand All @@ -2579,7 +2592,7 @@ def result_len(self):
else:
return "sizeof(%s)" % self.repeated_type_name()

def full_xcode(self, union_int=None):
def full_xcode(self, union_int=None, top_level=False):
"""Return the full code needed to encode/decode this element.
Including children, key, cbor, and repetitions.
Expand Down Expand Up @@ -2612,11 +2625,11 @@ def full_xcode(self, union_int=None):
xcode_args("*" + arg if arg != "NULL" and self.result_len() != "0" else arg),
self.result_len()))
else:
return self.repeated_xcode(union_int)
return self.repeated_xcode(union_int=union_int, top_level=top_level)

def xcode(self):
"""Return the body of the encoder/decoder function for this element."""
return self.full_xcode()
return self.full_xcode(top_level=True)

def xcoders(self):
"""Recursively return a list of the bodies of the encoder/decoder functions for
Expand All @@ -2637,7 +2650,9 @@ def xcoders(self):
yield xcoder
if self.repeated_single_func_impl_condition():
yield XcoderTuple(
self.repeated_xcode(), self.repeated_xcode_func_name(), self.repeated_type_name())
self.repeated_xcode(top_level=True),
self.repeated_xcode_func_name(),
self.repeated_type_name())
if (self.single_func_impl_condition()):
xcode_body = self.xcode()
yield XcoderTuple(xcode_body, self.xcode_func_name(), self.type_name())
Expand Down

0 comments on commit 1f1db89

Please sign in to comment.