From 6724448db5e41d1965354eebd21affe78d941caa Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Sat, 29 Aug 2020 12:53:34 +0900 Subject: [PATCH 1/5] Add minimum_tree.py --- onlinejudge_template/analyzer/minimum_tree.py | 491 ++++++++++++++++++ onlinejudge_template/analyzer/node_util.py | 92 ++++ onlinejudge_template/analyzer/parser.py | 8 +- .../analyzer/simple_patterns.py | 4 +- onlinejudge_template/analyzer/simplify.py | 16 + tests/analyzer_minimum_tree.py | 166 ++++++ 6 files changed, 771 insertions(+), 6 deletions(-) create mode 100644 onlinejudge_template/analyzer/minimum_tree.py create mode 100644 onlinejudge_template/analyzer/node_util.py create mode 100644 tests/analyzer_minimum_tree.py diff --git a/onlinejudge_template/analyzer/minimum_tree.py b/onlinejudge_template/analyzer/minimum_tree.py new file mode 100644 index 0000000..215b874 --- /dev/null +++ b/onlinejudge_template/analyzer/minimum_tree.py @@ -0,0 +1,491 @@ +""" +the module to find minimum format trees from sample strings + +この module はサンプル文字列から直接 (つまり、フォーマット文字列を用いずに) フォーマット木を推測します。利用可能なサンプル文字列の個数がひとつしかない場合での利用が想定されています。 +フォーマット木に対する評価関数を固定しておき、すべてのサンプル文字列とマッチするフォーマット木の中で最小のものを求めるという形で実装されています。 + +たとえば +:: + + 3 + 1 2 + 3 4 1 2 + 2 4 1 + +および +:: + + 1 + 2 0 8 + +というサンプル文字列から +:: + + sequence([ + item("N"), + newline(), + loop(counter="i", size="N", sequence([ + item("K_i"), + loop(counter="j", size="K_i", + item("A", indices=("i", "j")) + ), + newline(), + ])), + ]) + +のようなフォーマット木 (:any:`FormatNode`) を作ります。 +この例の場合は +:: + + sequence([ + item("N"), + newline(), + loop(counter="i", size="N - 1", sequence([ + item("K_i"), + loop(counter="j", size="K_i - 1", + item("A", indices=("i", "j")) + ), + item("B", indices="i"), + newline(), + ])), + item("L"), + loop(counter="i", size="L", + item("C", indices="i") + ), + newline(), + ]) + +というフォーマット木もこれらのサンプルにマッチしますが、これは木の大きさが最小ではないので作られません。 + +内部のデータ構造は Haskell 風に書くと以下のような感じになります。 +`LoopNode` が持つふたつの `Int` は、ループの回数を表現する変数の de Bruijn index およびその変数を修正するための -1, 0, 1 のいずれかの数です。 +木の一部は構築途中である場合があります。 + +:: haskell + data Token + = IntToken Int + | StrngToken + | NewlineToken + + data Node m + = LoopNode Int Int (m (Node m)) (m (Node m)) + | IntNode (m (Node m)) + | StringNode (m (Node m)) + | NewlineNode (m (Node m)) + | EOFNode + + match :: Node Maybe -> [Token] -> Maybe MatchState + ... + + size :: Node Maybe -> Int + size (LoopNode _ delta body next) = 1 + abs delta + size body + size next + size (IntNode next) = 1 + size next + size (StringNode next) = 1 + size next + size (NewlineNode next) = 1 + size next + size EOFNode = 1 +""" + +import abc +import heapq +import itertools +import string +from typing import * + +import onlinejudge_template.analyzer.node_util as node_util +from onlinejudge_template.types import * + + +class _Token(abc.ABC): + row: int + column: int + + def __init__(self, *, row: int, column: int): + self.row = row + self.column = column + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(L{self.row}C{self.column})" + + +class _IntToken(_Token): + value: int + + def __init__(self, *, value: int, row: int, column: int): + super().__init__(row=row, column=column) + self.value = value + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(L{self.row}C{self.column}, value={self.value})" + + +class _StringToken(_Token): + value: str + + def __init__(self, *, value: str, row: int, column: int): + super().__init__(row=row, column=column) + self.value = value + + +class _NewlineToken(_Token): + pass + + +class _MatchState(NamedTuple): + tokens: List[_Token] + offset: int + env: List[int] + + +class _MatchStop(Exception): + def __init__(self, state: _MatchState): + self.state = state + + +class _Node(abc.ABC): + """_Node is a node similar to FormatNode but is easy to use for optimization. + """ + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + @abc.abstractmethod + def get_tree_size(self) -> int: + raise NotImplementedError + + @abc.abstractmethod + def run_match(self, state: _MatchState) -> Optional[_MatchState]: + """ + :raises _MatchStop: + """ + raise NotImplementedError + + @abc.abstractmethod + def count_placeholder(self) -> int: + raise NotImplementedError + + @abc.abstractmethod + def get_replaced_first_placeholder(self, node: '_Node') -> Optional['_Node']: + raise NotImplementedError + + +class _PlaceholderNode(_Node): + def get_tree_size(self) -> int: + return 1 + + def run_match(self, state: _MatchState) -> Optional[_MatchState]: + raise _MatchStop(state) + + def count_placeholder(self) -> int: + return 1 + + def get_replaced_first_placeholder(self, node: _Node) -> _Node: + return node + + +class _EOFNode(_Node): + def get_tree_size(self) -> int: + return 1 + + def run_match(self, state: _MatchState) -> Optional[_MatchState]: + return state + + def count_placeholder(self) -> int: + return 0 + + def get_replaced_first_placeholder(self, node: _Node) -> Optional[_Node]: + return None + + +class _SimpleNonLeafNode(_Node): + next: _Node + + def __init__(self, *, next: _Node): + self.next = next + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(next={self.next})" + + def get_tree_size(self) -> int: + return 1 + self.next.get_tree_size() + + def count_placeholder(self) -> int: + return self.next.count_placeholder() + + def get_replaced_first_placeholder(self, node: _Node) -> Optional[_Node]: + next = self.next.get_replaced_first_placeholder(node) + if next is None: + return None + else: + return self.__class__(next=next) + + +class _IntNode(_SimpleNonLeafNode): + def run_match(self, state: _MatchState) -> Optional[_MatchState]: + assert 0 <= state.offset <= len(state.tokens) + if state.offset >= len(state.tokens): + return None + token = state.tokens[state.offset] + if not isinstance(token, _IntToken): + return None + state = _MatchState(tokens=state.tokens, offset=state.offset + 1, env=[token.value] + state.env) + return self.next.run_match(state) + + +class _StringNode(_SimpleNonLeafNode): + def run_match(self, state: _MatchState) -> Optional[_MatchState]: + assert 0 <= state.offset <= len(state.tokens) + if state.offset >= len(state.tokens): + return None + # An int is a str. `101` is an int but `1010100101010101010100111111101010101` may be a str. `10.1` is also a str. + if not isinstance(state.tokens[state.offset], _StringToken) and not isinstance(state.tokens[state.offset], _IntToken): + return None + state = _MatchState(tokens=state.tokens, offset=state.offset + 1, env=state.env) + return self.next.run_match(state) + + +class _NewlineNode(_SimpleNonLeafNode): + def run_match(self, state: _MatchState) -> Optional[_MatchState]: + assert 0 <= state.offset <= len(state.tokens) + if state.offset >= len(state.tokens): + return None + if not isinstance(state.tokens[state.offset], _NewlineToken): + return None + state = _MatchState(tokens=state.tokens, offset=state.offset + 1, env=state.env) + return self.next.run_match(state) + + +class _LoopNode(_Node): + index: int # de Bruijn index + delta: int + body: _Node + next: _Node + + def __init__(self, *, index: int, delta: int = 0, body: _Node, next: _Node): + assert delta in (-1, 0, 1) + self.index = index + self.delta = delta + self.body = body + self.next = next + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(index={self.index}, delta={self.delta}, body={self.body}, next={self.next})" + + def get_tree_size(self) -> int: + return 1 + abs(self.delta) + self.body.get_tree_size() + self.next.get_tree_size() + + def run_match(self, state: _MatchState) -> Optional[_MatchState]: + assert 0 <= self.index < len(state.env) + count = state.env[self.index] + self.delta + if count <= 0: + # loops of zero times cause some problems because some placeholders may be skipped + return None + + for _ in range(count): + result = self.body.run_match(state) + if result is None: + return None + state = _MatchState(tokens=state.tokens, offset=result.offset, env=state.env) # reset + return self.next.run_match(state) + + def count_placeholder(self) -> int: + return self.body.count_placeholder() + self.next.count_placeholder() + + def get_replaced_first_placeholder(self, node: _Node) -> Optional[_Node]: + body = self.body.get_replaced_first_placeholder(node) + if body is not None: + return _LoopNode(index=self.index, delta=self.delta, body=body, next=self.next) + else: + next = self.next.get_replaced_first_placeholder(node) + if next is not None: + return _LoopNode(index=self.index, delta=self.delta, body=self.body, next=next) + else: + return None + + +class _PriorityQueue: + _heap: List[Tuple[int, int, _Node]] = [] + _counter = itertools.count() + + def push(self, cost: int, node: _Node) -> None: + heapq.heappush(self._heap, (cost, next(self._counter), node)) + + def pop(self) -> _Node: + """pop() returns the item which has smallest cost value. + :raises IndexError: + """ + + _, _, node = heapq.heappop(self._heap) + return node + + def empty(self) -> bool: + return not self._heap + + +def tokenize_content(content: str) -> Iterator[_Token]: + # The int tokens are tokens which can be used as loop sizes. Only small integers satisfy this condition. + int_max = len(content.split()) + len(content.splitlines()) + 3 + + for y, line in enumerate(content.splitlines(keepends=True)): + words = line.split() + for x, word in enumerate(words): + try: + n = int(word) + except ValueError: + yield _StringToken(value=word, row=y, column=x) + else: + if 0 <= n <= int_max: + yield _IntToken(value=n, row=y, column=x) + else: + yield _StringToken(value=word, row=y, column=x) + if line.endswith('\n'): # including "\r\n" + yield _NewlineToken(row=y, column=len(words)) + + +def list_next_possible_node(states: List[_MatchState]) -> Iterator[_Node]: + # validate a set of states + assert states + for state in states: + assert 0 <= state.offset <= len(state.tokens) + env_size = len(states[0].env) + assert all([len(state.env) == env_size for state in states]) + + # EOF + yield _EOFNode() + if all([state.offset == len(state.tokens) for state in states]): + return + + # when some instances reach EOF but some instances don't + if any([state.offset == len(state.tokens) for state in states]): + return + + # when all next tokens are int tokens + if all([isinstance(state.tokens[state.offset], _IntToken) for state in states]): + yield _IntNode(next=_PlaceholderNode()) + for i in range(env_size): + for delta in (-1, 0, 1): + if all([0 <= state.env[i] + delta for state in states]): + yield _LoopNode(index=i, delta=delta, body=_IntNode(next=_PlaceholderNode()), next=_PlaceholderNode()) + return + + # when all next tokens are string tokens + if all([isinstance(state.tokens[state.offset], _StringToken) or isinstance(state.tokens[state.offset], _IntToken) for state in states]): + yield _StringNode(next=_PlaceholderNode()) + for i in range(env_size): + for delta in (-1, 0, 1): + if all([0 <= state.env[i] + delta for state in states]): + yield _LoopNode(index=i, delta=delta, body=_StringNode(next=_PlaceholderNode()), next=_PlaceholderNode()) + return + + # when all next tokens are newline tokens + if all([isinstance(state.tokens[state.offset], _NewlineToken) for state in states]): + yield _NewlineNode(next=_PlaceholderNode()) + # don't yield loop node here + return + + return + + +def _construct_minimum_input_format_internal_tree(*, instances: List[List[_Token]]) -> Optional[_Node]: + # init + que = _PriorityQueue() + initial_node = _PlaceholderNode() + que.push(initial_node.get_tree_size(), initial_node) + while not que.empty(): + # pop + cur = que.pop() + + # calc + states = [] + for instance in instances: + try: + state = cur.run_match(_MatchState(tokens=instance, offset=0, env=[])) + if state is None: + break + if state.offset != len(state.tokens): + break # matching finished before EOF + except _MatchStop as e: + state = e.state + states.append(state) + if len(states) != len(instances): + continue + if all([state.offset == len(state.tokens) for state in states]) and not cur.count_placeholder(): + return cur + + # push + for delta in list_next_possible_node(states): + nxt = cur.get_replaced_first_placeholder(delta) + assert nxt is not None + que.push(nxt.get_tree_size(), nxt) + + return None + + +class EnvItem(NamedTuple): + name: VarName + is_counter: bool + + +def _convert_to_format_node(node: _Node, *, env: List[EnvItem], used: Set[VarName]) -> FormatNode: + def get_fresh_name() -> VarName: + for var in map(VarName, string.ascii_letters): + if var not in used: + return var + else: + assert False # TODO: improve name assiging + + def list_indices(index: int) -> List[VarName]: + indices = [] + for item in reversed(env[index + 1:]): + if item.is_counter: + indices.append(item.name) + return indices + + if isinstance(node, _EOFNode): + return SequenceNode(items=[]) + + elif isinstance(node, _IntNode) or isinstance(node, _StringNode): + var = get_fresh_name() + delta: List[EnvItem] = [] + if isinstance(node, _IntNode): + delta = [EnvItem(var, False)] + indices = list_indices(-1) + + used.add(var) + return SequenceNode(items=[ + ItemNode(name=var, indices=indices), + _convert_to_format_node(node.next, env=delta + env, used=used), + ]) + + elif isinstance(node, _NewlineNode): + return SequenceNode(items=[ + NewlineNode(), + _convert_to_format_node(node.next, env=env, used=used), + ]) + + elif isinstance(node, _LoopNode): + size = Expr(env[node.index].name) + if list_indices(node.index): + size = Expr(str(size) + '_{' + ','.join(list_indices(node.index)) + '}') + var = get_fresh_name() + + used.add(var) + body = _convert_to_format_node(node.body, env=[EnvItem(var, True)] + env, used=used) + used.remove(var) + return SequenceNode(items=[ + LoopNode(size=size, name=var, body=body), + _convert_to_format_node(node.next, env=env, used=used), + ]) + + elif isinstance(node, _PlaceholderNode): + assert False + else: + assert False + + +def construct_minimum_input_format_tree(*, instances: List[str]) -> Optional[FormatNode]: + tokenized_instances = [list(tokenize_content(instance)) for instance in instances] + node = _construct_minimum_input_format_internal_tree(instances=tokenized_instances) + if node is None: + return None + format_node = _convert_to_format_node(node, env=[], used=set()) + format_node = node_util.rename_variable_nicely(format_node) + return node_util.remove_superfluous_sequence_nodes(format_node) diff --git a/onlinejudge_template/analyzer/node_util.py b/onlinejudge_template/analyzer/node_util.py new file mode 100644 index 0000000..4afba2f --- /dev/null +++ b/onlinejudge_template/analyzer/node_util.py @@ -0,0 +1,92 @@ +import string +from typing import * + +import onlinejudge_template.analyzer.simplify as simplify +from onlinejudge_template.types import * + + +def remove_superfluous_sequence_nodes(node: FormatNode) -> FormatNode: + if isinstance(node, ItemNode): + return node + elif isinstance(node, NewlineNode): + return node + elif isinstance(node, SequenceNode): + items = [] + for item in node.items: + item = remove_superfluous_sequence_nodes(item) + if isinstance(item, SequenceNode): + items.extend(item.items) + else: + items.append(item) + if len(items) == 1: + return items[0] + return SequenceNode(items=items) + elif isinstance(node, LoopNode): + return LoopNode(size=node.size, name=node.name, body=remove_superfluous_sequence_nodes(node.body)) + else: + assert False + + +def _get_nice_variable_name(*, used: Set[VarName]) -> VarName: + for c in map(VarName, 'abcdefgh' + 'mnopqrstuvwxyz'): + if c not in used: + return c + for c1 in string.ascii_uppercase: + for c2 in string.ascii_uppercase: + for c3 in string.ascii_uppercase: + s = VarName('a' + c1 + c2 + c3) + if s not in used: + return s + assert False + + +def _get_nice_counter_name(*, used: Set[VarName]) -> VarName: + for c in map(VarName, 'ijkl'): + if c not in used: + return c + for c1 in string.ascii_uppercase: + for c2 in string.ascii_uppercase: + for c3 in string.ascii_uppercase: + s = VarName('i' + c1 + c2 + c3) + if s not in used: + return s + assert False + + +def _rename_variable_nicely_dfs(node: FormatNode, *, replace: Dict[VarName, VarName], used: Set[VarName]) -> FormatNode: + if isinstance(node, ItemNode): + name = _get_nice_variable_name(used=used) + indices = [simplify.rename_variables_in_expr(index, replace=replace) for index in node.indices] + + assert node.name not in replace + replace[node.name] = name + used.add(name) + return ItemNode(name=name, indices=indices) + + elif isinstance(node, NewlineNode): + return NewlineNode() + + elif isinstance(node, SequenceNode): + items = [] + for item in node.items: + items.append(_rename_variable_nicely_dfs(item, replace=replace, used=used)) + return SequenceNode(items=items) + + elif isinstance(node, LoopNode): + name = _get_nice_counter_name(used=used) + size = simplify.rename_variables_in_expr(node.size, replace=replace) + + assert node.name not in replace + replace[node.name] = name + used.add(name) + body = _rename_variable_nicely_dfs(node.body, replace=replace, used=used) + used.remove(name) + replace.pop(node.name) + return LoopNode(size=size, name=name, body=body) + + else: + assert False + + +def rename_variable_nicely(node: FormatNode) -> FormatNode: + return _rename_variable_nicely_dfs(node, replace={}, used=set()) diff --git a/onlinejudge_template/analyzer/parser.py b/onlinejudge_template/analyzer/parser.py index 57c52e5..a32262c 100644 --- a/onlinejudge_template/analyzer/parser.py +++ b/onlinejudge_template/analyzer/parser.py @@ -336,7 +336,7 @@ def zip_nodes(a: FormatNode, b: FormatNode, *, name: VarName, size: Optional[Exp raise FormatStringParserError("semantics: unmatched dots pair: {} and {}".format(a, b)) -def exnted_loop_node(a: FormatNode, b: FormatNode, *, loop: LoopNode) -> Optional[FormatNode]: +def extend_loop_node(a: FormatNode, b: FormatNode, *, loop: LoopNode) -> Optional[FormatNode]: if isinstance(a, ItemNode) and isinstance(b, ItemNode): if a.name != b.name or len(a.indices) != len(b.indices): return None @@ -357,7 +357,7 @@ def exnted_loop_node(a: FormatNode, b: FormatNode, *, loop: LoopNode) -> Optiona return None items = [] for a_i, b_i in zip(a.items, b.items): - c_i = exnted_loop_node(a_i, b_i, loop=loop) + c_i = extend_loop_node(a_i, b_i, loop=loop) if c_i is None: return None items.append(c_i) @@ -366,7 +366,7 @@ def exnted_loop_node(a: FormatNode, b: FormatNode, *, loop: LoopNode) -> Optiona elif isinstance(a, LoopNode) and isinstance(b, LoopNode): if a.size != b.size or a.name != b.name: return None - c = exnted_loop_node(a.body, b.body, loop=loop) + c = extend_loop_node(a.body, b.body, loop=loop) if c is None: return None return LoopNode(size=a.size, name=a.name, body=c) @@ -405,7 +405,7 @@ def analyze_parsed_node(node: ParserNode) -> FormatNode: else: items_init = items[:-1] items_tail = items[-1] - extended_body = exnted_loop_node(items_tail, item.body, loop=item) + extended_body = extend_loop_node(items_tail, item.body, loop=item) if extended_body is not None: extended_loop: FormatNode = LoopNode(size=simplify(Expr(f"""{item.size} + 1""")), name=item.name, body=extended_body) items = items_init diff --git a/onlinejudge_template/analyzer/simple_patterns.py b/onlinejudge_template/analyzer/simple_patterns.py index 40908a5..ae01f91 100644 --- a/onlinejudge_template/analyzer/simple_patterns.py +++ b/onlinejudge_template/analyzer/simple_patterns.py @@ -1,7 +1,7 @@ """ -the module to guess format trees from sample strings +the module to guess simple format trees from sample strings -この module はサンプル文字列から直接 (つまり、フォーマット文字列を用いずに) フォーマット木を推測します。 +この module はサンプル文字列から直接 (つまり、フォーマット文字列を用いずに) 典型的なフォーマット木を推測します。利用可能なサンプル文字列の個数がひとつしかない場合での利用が想定されています。 単純なフォーマット木を列挙しておき、それらとのパターンマッチをすることによって実装されています。 たとえば diff --git a/onlinejudge_template/analyzer/simplify.py b/onlinejudge_template/analyzer/simplify.py index da12d93..c7187bb 100644 --- a/onlinejudge_template/analyzer/simplify.py +++ b/onlinejudge_template/analyzer/simplify.py @@ -15,6 +15,7 @@ import abc import fractions +import re from logging import getLogger from typing import * @@ -495,3 +496,18 @@ def parse_subscripted_variable(s: str) -> Tuple[str, List[str]]: if not isinstance(expr, _Variable): raise ExprParserError('not a subscripted variable: {}'.format(s)) return expr.name, list(map(_format, expr.args)) + + +def rename_variables_in_expr(expr: Expr, *, replace: Dict[VarName, VarName]) -> Expr: + """ + :raises ExprParserError: + """ + + pattern = r'[A-Za-z]+|[^A-Za-z]+' + s = [] + for c in re.findall(pattern, str(expr)): + if c.isalpha() and VarName(c) in replace: + s.append(str(replace[VarName(c)])) + else: + s.append(c) + return Expr(_format(_parse(''.join(s)))) diff --git a/tests/analyzer_minimum_tree.py b/tests/analyzer_minimum_tree.py new file mode 100644 index 0000000..cbc97d7 --- /dev/null +++ b/tests/analyzer_minimum_tree.py @@ -0,0 +1,166 @@ +import textwrap +import unittest + +import onlinejudge_template.analyzer.minimum_tree as analyzer +from onlinejudge_template.types import * + + +class TestMinimumTree(unittest.TestCase): + def test_simple(self) -> None: + instances = [ + textwrap.dedent("""\ + 3 + 1 2 + 3 4 1 2 + 2 4 1 + """), + textwrap.dedent("""\ + 1 + 2 0 8 + """), + ] + expected = SequenceNode(items=[ + ItemNode(name='a'), + NewlineNode(), + LoopNode(size='a', name='i', body=SequenceNode(items=[ + ItemNode(name='b', indices=['i']), + LoopNode(size='b_i', name='j', body=ItemNode(name='c', indices=['i', 'j'])), + NewlineNode(), + ])), + ]) + + actual = analyzer.construct_minimum_input_format_tree(instances=instances) + self.assertEqual(str(actual), str(expected)) + + def test_codeforces_1406_A(self) -> None: + """It has only one sample input with multiple cases. Each case is simple. + """ + # https://codeforces.com/contest/1406/problem/A + instances = [ + textwrap.dedent("""\ + 4 + 6 + 0 2 1 5 0 1 + 3 + 0 1 2 + 4 + 0 2 0 1 + 6 + 1 2 3 4 5 6 + """), + ] + expected = SequenceNode(items=[ + ItemNode(name='a'), + NewlineNode(), + LoopNode(size='a', name='i', body=SequenceNode(items=[ + ItemNode(name='b', indices=['i']), + NewlineNode(), + LoopNode(size='b_i', name='j', body=ItemNode(name='c', indices=['i', 'j'])), + NewlineNode(), + ])), + ]) + + actual = analyzer.construct_minimum_input_format_tree(instances=instances) + self.assertEqual(str(actual), str(expected)) + + def test_codeforces_1406_D(self) -> None: + """It has many separated sample cases. Each case is complicated. Also they have zeros and negative values. + """ + # https://codeforces.com/contest/1406/problem/D + instances = [ + textwrap.dedent("""\ + 4 + 2 -1 7 3 + 2 + 2 4 -3 + 3 4 2 + """), + textwrap.dedent("""\ + 6 + -9 -10 -9 -6 -5 4 + 3 + 2 6 -9 + 1 2 -10 + 4 6 -3 + """), + textwrap.dedent("""\ + 1 + 0 + 2 + 1 1 -1 + 1 1 -1 + """), + ] + expected = SequenceNode(items=[ + ItemNode(name='a'), + NewlineNode(), + LoopNode(size='a', name='i', body=ItemNode(name='b', indices=['i'])), + NewlineNode(), + ItemNode(name='c'), + NewlineNode(), + LoopNode(size='c', name='i', body=SequenceNode(items=[ + ItemNode(name='d', indices=['i']), + ItemNode(name='e', indices=['i']), + ItemNode(name='f', indices=['i']), + NewlineNode(), + ])), + ]) + + actual = analyzer.construct_minimum_input_format_tree(instances=instances) + self.assertEqual(str(actual), str(expected)) + + def test_atcoder_agc028_f(self) -> None: + """It has many separated sample cases. Each case has non-integers which sometimes looks like integers. + """ + # https://atcoder.jp/contests/agc028/tasks/agc028_f + instances = [ + textwrap.dedent("""\ + 2 + 11 + 11 + """), + textwrap.dedent("""\ + 4 + 1111 + 11#1 + 1#11 + 1111 + """), + textwrap.dedent("""\ + 10 + 76##63##3# + 8445669721 + 75#9542133 + 3#285##445 + 749632##89 + 2458##9515 + 5952578#77 + 1#3#44196# + 4355#99#1# + #298#63587 + """), + textwrap.dedent("""\ + 10 + 4177143673 + 7######### + 5#1716155# + 6#4#####5# + 2#3#597#6# + 6#9#8#3#5# + 5#2#899#9# + 1#6#####6# + 6#5359657# + 5######### + """), + ] + expected = SequenceNode(items=[ + ItemNode(name='a'), + NewlineNode(), + LoopNode(size='a', name='i', body=SequenceNode(items=[ + ItemNode(name='b', indices=['i']), + NewlineNode(), + ])), + ]) + + actual = analyzer.construct_minimum_input_format_tree(instances=instances) + self.assertEqual(str(actual), str(expected)) From b741c5c8b7fe4ff43186c9732515b82857af641e Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Tue, 15 Sep 2020 22:34:15 +0900 Subject: [PATCH 2/5] Add some tests for minimum_tree.py --- onlinejudge_template/analyzer/minimum_tree.py | 7 ++++- tests/analyzer_minimum_tree.py | 26 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/onlinejudge_template/analyzer/minimum_tree.py b/onlinejudge_template/analyzer/minimum_tree.py index 215b874..3d12352 100644 --- a/onlinejudge_template/analyzer/minimum_tree.py +++ b/onlinejudge_template/analyzer/minimum_tree.py @@ -384,7 +384,7 @@ def list_next_possible_node(states: List[_MatchState]) -> Iterator[_Node]: return -def _construct_minimum_input_format_internal_tree(*, instances: List[List[_Token]]) -> Optional[_Node]: +def _construct_minimum_input_format_internal_tree(*, instances: List[List[_Token]], limit: int = 10000) -> Optional[_Node]: # init que = _PriorityQueue() initial_node = _PlaceholderNode() @@ -416,6 +416,11 @@ def _construct_minimum_input_format_internal_tree(*, instances: List[List[_Token assert nxt is not None que.push(nxt.get_tree_size(), nxt) + # timeout. This function doesn't have good time complexity, so may take too long time. + limit -= 1 + if limit < 0: + break + return None diff --git a/tests/analyzer_minimum_tree.py b/tests/analyzer_minimum_tree.py index cbc97d7..33e93a3 100644 --- a/tests/analyzer_minimum_tree.py +++ b/tests/analyzer_minimum_tree.py @@ -32,6 +32,32 @@ def test_simple(self) -> None: actual = analyzer.construct_minimum_input_format_tree(instances=instances) self.assertEqual(str(actual), str(expected)) + def test_failure(self) -> None: + instances = [ + textwrap.dedent("""\ + a + """), + textwrap.dedent("""\ + b + c + """), + textwrap.dedent("""\ + d e + f + """), + ] + expected = None + + actual = analyzer.construct_minimum_input_format_tree(instances=instances) + self.assertEqual(str(actual), str(expected)) + + def test_too_slow(self) -> None: + instances = ['5 5 5 5 5\n' * 100] + expected = None + + actual = analyzer.construct_minimum_input_format_tree(instances=instances) + self.assertEqual(str(actual), str(expected)) + def test_codeforces_1406_A(self) -> None: """It has only one sample input with multiple cases. Each case is simple. """ From 54a8c52378a74457a834a7426726eb3f3e190e18 Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Tue, 15 Sep 2020 22:49:09 +0900 Subject: [PATCH 3/5] Use minimum_tree.py --- onlinejudge_template/analyzer/combined.py | 15 +++-- onlinejudge_template/analyzer/minimum_tree.py | 9 +++ .../analyzer/simple_patterns.py | 59 ++++--------------- onlinejudge_template/analyzer/typing.py | 4 +- 4 files changed, 32 insertions(+), 55 deletions(-) diff --git a/onlinejudge_template/analyzer/combined.py b/onlinejudge_template/analyzer/combined.py index ae4cef4..805cc1a 100644 --- a/onlinejudge_template/analyzer/combined.py +++ b/onlinejudge_template/analyzer/combined.py @@ -3,6 +3,7 @@ import onlinejudge_template.analyzer.constants import onlinejudge_template.analyzer.html +import onlinejudge_template.analyzer.minimum_tree import onlinejudge_template.analyzer.output_types import onlinejudge_template.analyzer.parser import onlinejudge_template.analyzer.simple_patterns @@ -63,8 +64,10 @@ def run(resources: AnalyzerResources) -> AnalyzerResult: elif topcoder_class_definition is not None: input_format = onlinejudge_template.analyzer.topcoder.convert_topcoder_class_definition_to_input_format(topcoder_class_definition) elif resources.sample_cases: - input_samples = [case.input for case in resources.sample_cases] + input_samples = [case.input.decode() for case in resources.sample_cases] input_format = onlinejudge_template.analyzer.simple_patterns.guess_format_with_pattern_matching(instances=input_samples) + if input_format is None: + input_format = onlinejudge_template.analyzer.minimum_tree.construct_minimum_input_format_tree(instances=input_samples) # list the variables for input input_variables: Optional[Dict[VarName, VarDecl]] = None @@ -78,7 +81,7 @@ def run(resources: AnalyzerResources) -> AnalyzerResult: logger.error('input analyzer failed: %s', e) if input_format is not None and input_variables is not None and resources.sample_cases: - input_samples = [case.input for case in resources.sample_cases] + input_samples = [case.input.decode() for case in resources.sample_cases] try: input_types = onlinejudge_template.analyzer.typing.infer_types_from_instances(input_format, variables=input_variables, instances=input_samples) input_variables = onlinejudge_template.analyzer.typing.update_variables_with_types(variables=input_variables, types=input_types) @@ -99,9 +102,13 @@ def run(resources: AnalyzerResources) -> AnalyzerResult: elif resources.sample_cases: if input_format is not None and input_variables is not None: output_format = onlinejudge_template.analyzer.simple_patterns.guess_output_format_with_pattern_matching_using_input_format(instances=resources.sample_cases, input_format=input_format, input_variables=input_variables) + if output_format is None: + output_format = onlinejudge_template.analyzer.minimum_tree.construct_minimum_output_format_tree_using_input_format(instances=resources.sample_cases, input_format=input_format, input_variables=input_variables) else: - output_samples = [case.output for case in resources.sample_cases] + output_samples = [case.output.decode() for case in resources.sample_cases] output_format = onlinejudge_template.analyzer.simple_patterns.guess_format_with_pattern_matching(instances=output_samples) + if output_format is None: + output_format = onlinejudge_template.analyzer.minimum_tree.construct_minimum_output_format_tree(instances=output_samples) # list the variables for output output_variables: Optional[Dict[VarName, VarDecl]] = None @@ -115,7 +122,7 @@ def run(resources: AnalyzerResources) -> AnalyzerResult: logger.error('output analyzer failed: %s', e) if output_format is not None and output_variables is not None and resources.sample_cases: - output_samples = [case.output for case in resources.sample_cases] + output_samples = [case.output.decode() for case in resources.sample_cases] try: output_types = onlinejudge_template.analyzer.typing.infer_types_from_instances(output_format, variables=output_variables, instances=output_samples) output_variables = onlinejudge_template.analyzer.typing.update_variables_with_types(variables=output_variables, types=output_types) diff --git a/onlinejudge_template/analyzer/minimum_tree.py b/onlinejudge_template/analyzer/minimum_tree.py index 3d12352..1a2650f 100644 --- a/onlinejudge_template/analyzer/minimum_tree.py +++ b/onlinejudge_template/analyzer/minimum_tree.py @@ -494,3 +494,12 @@ def construct_minimum_input_format_tree(*, instances: List[str]) -> Optional[For format_node = _convert_to_format_node(node, env=[], used=set()) format_node = node_util.rename_variable_nicely(format_node) return node_util.remove_superfluous_sequence_nodes(format_node) + + +def construct_minimum_output_format_tree(*, instances: List[str]) -> Optional[FormatNode]: + return construct_minimum_input_format_tree(instances=instances) + + +def construct_minimum_output_format_tree_using_input_format(*, instances: List[SampleCase], input_format: FormatNode, input_variables: Dict[VarName, VarDecl]) -> Optional[FormatNode]: + output_samples = [case.output.decode() for case in instances] + return construct_minimum_output_format_tree(instances=output_samples) diff --git a/onlinejudge_template/analyzer/simple_patterns.py b/onlinejudge_template/analyzer/simple_patterns.py index ae01f91..f3167f0 100644 --- a/onlinejudge_template/analyzer/simple_patterns.py +++ b/onlinejudge_template/analyzer/simple_patterns.py @@ -201,60 +201,21 @@ class SimplePatternMatchingError(AnalyzerError): _length_and_vertical_two_vector_pattern, ] - -def _make_tree_pattern_dfs(node: FormatNode) -> Tuple[FormatNode, bool]: - if isinstance(node, ItemNode): - return node, False - - elif isinstance(node, NewlineNode): - return node, False - - elif isinstance(node, SequenceNode): - items: List[FormatNode] = [] - any_replaced = False - for item in node.items: - item, replaced = _make_tree_pattern_dfs(item) - if replaced: - any_replaced = True - return SequenceNode(items=items), any_replaced - - elif isinstance(node, LoopNode): - assert node.size == 'n' - body, _ = _make_tree_pattern_dfs(node.body) - return LoopNode(name=node.name, size='n - 1', body=body), True - - else: - assert False - - -def _make_tree_patterns(patterns: List[FormatNode]) -> List[FormatNode]: - """_make_tree_patterns detects patterns which have the variable `n` and arrays with lentgh `n`, and replaces the length of arrays with `n - 1`. - """ - - tree_patterns = [] - for pattern in patterns: - pattern, replaced = _make_tree_pattern_dfs(pattern) - if replaced: - tree_patterns.append(pattern) - return tree_patterns +_all_patterns: List[FormatNode] = [ + *_simple_patterns, + *_vertical_simple_patterns, + *_one_vector_patterns, + *_one_vector_with_data_patterns, + *_two_vectors_patterns, +] @functools.lru_cache(maxsize=None) def list_all_patterns() -> List[Tuple[FormatNode, Dict[VarName, VarDecl]]]: """list_all_patterns lists all pre-defined petterns. """ - - patterns: List[FormatNode] = [ - *_simple_patterns, - *_vertical_simple_patterns, - *_one_vector_patterns, - *_one_vector_with_data_patterns, - *_two_vectors_patterns, - ] - all_patterns = patterns + _make_tree_patterns(patterns) - results: List[Tuple[FormatNode, Dict[VarName, VarDecl]]] = [] - for pattern in all_patterns: + for pattern in _all_patterns: try: variables = onlinejudge_template.analyzer.variables.list_declared_variables(pattern) results.append((pattern, variables)) @@ -322,7 +283,7 @@ def rename_variables_if_conflicts(node: FormatNode, *, env: Dict[VarName, VarDec return _rename_variables_if_conflicts_dfs(node, mapping={}, env=env) -def guess_format_with_pattern_matching(*, instances: List[bytes]) -> Optional[FormatNode]: +def guess_format_with_pattern_matching(*, instances: List[str]) -> Optional[FormatNode]: """guess_format_with_pattern_matching guesses a format tree from the strings which match with the format tree, i.e. sample cases. :param instances: are sample cases. @@ -335,7 +296,7 @@ def guess_format_with_pattern_matching(*, instances: List[bytes]) -> Optional[Fo pattern = rename_variables_if_conflicts(pattern, env={}) try: for data in instances: - match_format(pattern, data.decode(), variables=variables) + match_format(pattern, data, variables=variables) except FormatMatchError: pass else: diff --git a/onlinejudge_template/analyzer/typing.py b/onlinejudge_template/analyzer/typing.py index 3947943..aee07f7 100644 --- a/onlinejudge_template/analyzer/typing.py +++ b/onlinejudge_template/analyzer/typing.py @@ -113,7 +113,7 @@ def unify_var_types(t1: Dict[VarName, VarType], t2: Dict[VarName, VarType]) -> D return t3 -def infer_types_from_instances(node: FormatNode, *, variables: Dict[VarName, VarDecl], instances: List[bytes]) -> Dict[VarName, VarType]: +def infer_types_from_instances(node: FormatNode, *, variables: Dict[VarName, VarDecl], instances: List[str]) -> Dict[VarName, VarType]: """ :raises FormatMatchError: :raises TypingError: @@ -122,7 +122,7 @@ def infer_types_from_instances(node: FormatNode, *, variables: Dict[VarName, Var assert instances types: Optional[Dict[VarName, VarType]] = None for i, data in enumerate(instances): - values = match_format(node, data.decode(), variables=variables) + values = match_format(node, data, variables=variables) logger.debug("match result for %d-th data: %s", i, values) types2 = get_var_types_from_match_result(values, variables=variables) if types is None: From 6a7b65455bb1647984e8b2b642fbdca86fc1db14 Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Tue, 15 Sep 2020 23:22:53 +0900 Subject: [PATCH 4/5] Update typing.py to reduce errors --- onlinejudge_template/analyzer/typing.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/onlinejudge_template/analyzer/typing.py b/onlinejudge_template/analyzer/typing.py index aee07f7..ed398af 100644 --- a/onlinejudge_template/analyzer/typing.py +++ b/onlinejudge_template/analyzer/typing.py @@ -62,12 +62,20 @@ def get_var_type(value: Union[int, float, str]) -> VarType: assert False -def unify_types(t1: VarType, t2: VarType) -> Optional[VarType]: +def unify_types(t1: VarType, t2: VarType) -> VarType: if t1 == t2: return t1 - if set([t1, t2]) == set([VarType.Char, VarType.String]): + if t1 == VarType.String or t2 == VarType.String: return VarType.String - return None + if t1 == VarType.Char or t2 == VarType.Char: + return VarType.String + if set([t1, t2]) == set([VarType.IndexInt, VarType.ValueInt]): + return VarType.ValueInt + if set([t1, t2]) == set([VarType.IndexInt, VarType.Float]): + return VarType.Float + if set([t1, t2]) == set([VarType.ValueInt, VarType.Float]): + return VarType.Float + assert False def get_var_types_from_match_result(values: Dict[VarName, Dict[Tuple[int, ...], Union[int, float, str]]], *, variables: Dict[VarName, VarDecl]) -> Dict[VarName, VarType]: @@ -82,8 +90,6 @@ def get_var_types_from_match_result(values: Dict[VarName, Dict[Tuple[int, ...], t1 = ts.pop() t2 = ts.pop() t3 = unify_types(t1, t2) - if t3 is None: - raise TypingError(f"""failed to unify types: {t1} and {t2} for variable {name}""") ts.add(t3) if not ts: raise TypingError(f"""failed to infer type: {name} has no candidate types""") @@ -96,8 +102,6 @@ def get_var_types_from_match_result(values: Dict[VarName, Dict[Tuple[int, ...], for name, decl in variables.items(): if decl.type is not None: t = unify_types(types[name], decl.type) - if t is None: - raise TypingError(f"""failed to unify types: {types[name]} and {decl.type} for variable {name}""") types[name] = t return types @@ -107,8 +111,6 @@ def unify_var_types(t1: Dict[VarName, VarType], t2: Dict[VarName, VarType]) -> D t3: Dict[VarName, VarType] = {} for name in t1.keys(): t = unify_types(t1[name], t2[name]) - if t is None: - raise TypingError(f"""failed to unify types: {t1[name]} and {t2[name]} for variable {name}""") t3[name] = t return t3 @@ -145,8 +147,6 @@ def update_variables_with_types(*, variables: Dict[VarName, VarDecl], types: Dic t = types[name] else: t1 = unify_types(types[name], decl.type) - if t1 is None: - raise TypingError(f"""failed to unify types: {types[name]} and {decl.type} for variable {name}""") t = t1 updated[name] = VarDecl( type=t, From adc1b2944e9ab9146183a5b773c34fd0a6448e5b Mon Sep 17 00:00:00 2001 From: Kimiyuki Onaka Date: Tue, 15 Sep 2020 23:31:07 +0900 Subject: [PATCH 5/5] Use information of the input format to minimize the output format tree --- onlinejudge_template/analyzer/minimum_tree.py | 51 +++++++++++++++++-- onlinejudge_template/analyzer/node_util.py | 4 +- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/onlinejudge_template/analyzer/minimum_tree.py b/onlinejudge_template/analyzer/minimum_tree.py index 1a2650f..c82dc00 100644 --- a/onlinejudge_template/analyzer/minimum_tree.py +++ b/onlinejudge_template/analyzer/minimum_tree.py @@ -92,6 +92,7 @@ from typing import * import onlinejudge_template.analyzer.node_util as node_util +from onlinejudge_template.analyzer.match import FormatMatchError, match_format from onlinejudge_template.types import * @@ -384,7 +385,7 @@ def list_next_possible_node(states: List[_MatchState]) -> Iterator[_Node]: return -def _construct_minimum_input_format_internal_tree(*, instances: List[List[_Token]], limit: int = 10000) -> Optional[_Node]: +def _construct_minimum_input_format_internal_tree(*, instances: List[List[_Token]], initial_env: Optional[List[List[int]]] = None, limit: int = 10000) -> Optional[_Node]: # init que = _PriorityQueue() initial_node = _PlaceholderNode() @@ -395,9 +396,13 @@ def _construct_minimum_input_format_internal_tree(*, instances: List[List[_Token # calc states = [] - for instance in instances: + for i, instance in enumerate(instances): + if initial_env is not None: + env = initial_env[i] + else: + env = [] try: - state = cur.run_match(_MatchState(tokens=instance, offset=0, env=[])) + state = cur.run_match(_MatchState(tokens=instance, offset=0, env=env)) if state is None: break if state.offset != len(state.tokens): @@ -501,5 +506,41 @@ def construct_minimum_output_format_tree(*, instances: List[str]) -> Optional[Fo def construct_minimum_output_format_tree_using_input_format(*, instances: List[SampleCase], input_format: FormatNode, input_variables: Dict[VarName, VarDecl]) -> Optional[FormatNode]: - output_samples = [case.output.decode() for case in instances] - return construct_minimum_output_format_tree(instances=output_samples) + # prepare environments + minimizer_env: List[List[int]] = [] + converter_env: List[EnvItem] = [] + converter_used: Set[VarName] = set() + try: + for i, data in enumerate(instances): + minimizer_env.append([]) + input_values = match_format(input_format, data.input.decode(), variables=input_variables) + for name in sorted(input_variables.keys()): + decl = input_variables[name] + if (decl.type == VarType.IndexInt or decl.type == VarType.ValueInt) and not decl.dims: + value = input_values[name][()] + assert isinstance(value, int) + minimizer_env[i].append(value) + if i == 0: + converter_env.append(EnvItem(name, False)) + if i == 0: + converter_used.add(name) + except FormatMatchError: + output_samples = [case.output.decode() for case in instances] + return construct_minimum_output_format_tree(instances=output_samples) + print('construct_minimum_output_format_tree_using_input_format') + print(input_variables) + print(minimizer_env) + print(converter_env) + print(converter_used) + + # construct the tree + tokenized_instances = [list(tokenize_content(instance.output.decode())) for instance in instances] + node = _construct_minimum_input_format_internal_tree(instances=tokenized_instances, initial_env=minimizer_env) + if node is None: + return None + + # make format node + format_node = _convert_to_format_node(node, env=converter_env, used=converter_used) + format_node = node_util.rename_variable_nicely(format_node, used=converter_used) + print(format_node) + return node_util.remove_superfluous_sequence_nodes(format_node) diff --git a/onlinejudge_template/analyzer/node_util.py b/onlinejudge_template/analyzer/node_util.py index 4afba2f..3115d2a 100644 --- a/onlinejudge_template/analyzer/node_util.py +++ b/onlinejudge_template/analyzer/node_util.py @@ -88,5 +88,5 @@ def _rename_variable_nicely_dfs(node: FormatNode, *, replace: Dict[VarName, VarN assert False -def rename_variable_nicely(node: FormatNode) -> FormatNode: - return _rename_variable_nicely_dfs(node, replace={}, used=set()) +def rename_variable_nicely(node: FormatNode, *, used: Optional[Set[VarName]] = None) -> FormatNode: + return _rename_variable_nicely_dfs(node, replace={}, used=used or set())