From 45a04de90a0ae38f886cc16e2efe552a1c63052a Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Tue, 8 Aug 2023 11:50:38 -0400 Subject: [PATCH] Properly escape interpolation-like strings in resolved configs Fixes #1112 Fixes #1081 --- omegaconf/_impl.py | 5 +++-- omegaconf/_utils.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/omegaconf/_impl.py b/omegaconf/_impl.py index 49be30329..a4e57ebb4 100644 --- a/omegaconf/_impl.py +++ b/omegaconf/_impl.py @@ -10,6 +10,7 @@ _get_value, is_primitive_container, is_structured_config, + maybe_escape, ) @@ -33,7 +34,7 @@ def _resolve_container_value(cfg: Container, key: Any) -> None: if isinstance(resolved, Container) and isinstance(node, ValueNode): cfg[key] = resolved else: - node._set_value(_get_value(resolved)) + node._set_value(maybe_escape(_get_value(resolved))) else: _resolve(node) @@ -46,7 +47,7 @@ def _resolve(cfg: Node) -> Node: except InterpolationToMissingValueError: cfg._set_value(MISSING) else: - cfg._set_value(resolved._value()) + cfg._set_value(maybe_escape(resolved._value())) if isinstance(cfg, DictConfig): for k in cfg.keys(): diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 3452f48ca..388f8cfaa 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -683,6 +683,43 @@ def is_primitive_container(obj: Any) -> bool: return is_primitive_list(obj) or is_primitive_dict(obj) +def maybe_escape(value: Any) -> Any: + """Escape interpolation strings and return other values unchanged. + + When the input value is an interpolation string, the returned value is such that + it yields the original input string when resolved. + """ + if not isinstance(value, str) or not _is_interpolation_string( + value, strict_interpolation_validation=False + ): + return value + start = 0 + tokens = [] + while True: + # Find next ${ that needs escaping. + first_inter = value.find("${", start) + if first_inter < 0: + tokens.append(value[start:]) # ensure we keep the end of the string + break + # Any backslash that comes before ${ will need to be escaped as well. + count_esc = 0 + while ( + first_inter - count_esc - 1 >= 0 + and value[first_inter - count_esc - 1] == "\\" + ): + count_esc += 1 + tokens += [ + # Characters that need not be changed. + value[start : first_inter - count_esc], + # Escaped backslashes before the interpolation. + "\\" * (count_esc * 2), + # Escaped interpolation. + "\\${", + ] + start = first_inter + 2 + return "".join(tokens) + + def get_list_element_type(ref_type: Optional[Type[Any]]) -> Any: args = getattr(ref_type, "__args__", None) if ref_type is not List and args is not None and args[0]: