Skip to content

Commit

Permalink
Properly escape interpolation-like strings in resolved configs
Browse files Browse the repository at this point in the history
Fixes #1112
Fixes #1081
  • Loading branch information
odelalleau committed Aug 8, 2023
1 parent b155a60 commit 45a04de
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
5 changes: 3 additions & 2 deletions omegaconf/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_get_value,
is_primitive_container,
is_structured_config,
maybe_escape,
)


Expand All @@ -33,7 +34,7 @@ def _resolve_container_value(cfg: Container, key: Any) -> None:
if isinstance(resolved, Container) and isinstance(node, ValueNode):
cfg[key] = resolved
else:
node._set_value(_get_value(resolved))
node._set_value(maybe_escape(_get_value(resolved)))
else:
_resolve(node)

Expand All @@ -46,7 +47,7 @@ def _resolve(cfg: Node) -> Node:
except InterpolationToMissingValueError:
cfg._set_value(MISSING)
else:
cfg._set_value(resolved._value())
cfg._set_value(maybe_escape(resolved._value()))

if isinstance(cfg, DictConfig):
for k in cfg.keys():
Expand Down
37 changes: 37 additions & 0 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,43 @@ def is_primitive_container(obj: Any) -> bool:
return is_primitive_list(obj) or is_primitive_dict(obj)


def maybe_escape(value: Any) -> Any:
"""Escape interpolation strings and return other values unchanged.
When the input value is an interpolation string, the returned value is such that
it yields the original input string when resolved.
"""
if not isinstance(value, str) or not _is_interpolation_string(
value, strict_interpolation_validation=False
):
return value
start = 0
tokens = []
while True:
# Find next ${ that needs escaping.
first_inter = value.find("${", start)
if first_inter < 0:
tokens.append(value[start:]) # ensure we keep the end of the string
break
# Any backslash that comes before ${ will need to be escaped as well.
count_esc = 0
while (
first_inter - count_esc - 1 >= 0
and value[first_inter - count_esc - 1] == "\\"
):
count_esc += 1
tokens += [
# Characters that need not be changed.
value[start : first_inter - count_esc],
# Escaped backslashes before the interpolation.
"\\" * (count_esc * 2),
# Escaped interpolation.
"\\${",
]
start = first_inter + 2
return "".join(tokens)


def get_list_element_type(ref_type: Optional[Type[Any]]) -> Any:
args = getattr(ref_type, "__args__", None)
if ref_type is not List and args is not None and args[0]:
Expand Down

0 comments on commit 45a04de

Please sign in to comment.