From 8c3091d6aa8ee44425a2ebcd6572c678a73298dd Mon Sep 17 00:00:00 2001 From: Ed Kellett Date: Fri, 6 May 2016 18:09:24 +0100 Subject: [PATCH] Add (very debug) selector hijacking API --- fireplace/dsl/__init__.py | 1 + fireplace/dsl/hijack.py | 36 ++++++++++++++++++++++++++++++++++++ tests/test_dsl.py | 10 ++++++++++ 3 files changed, 47 insertions(+) create mode 100644 fireplace/dsl/hijack.py diff --git a/fireplace/dsl/__init__.py b/fireplace/dsl/__init__.py index 16f815641..a7c43c37a 100644 --- a/fireplace/dsl/__init__.py +++ b/fireplace/dsl/__init__.py @@ -1,5 +1,6 @@ from .copy import * from .evaluator import * +from .hijack import * from .lazynum import * from .random_picker import * from .selector import * diff --git a/fireplace/dsl/hijack.py b/fireplace/dsl/hijack.py new file mode 100644 index 000000000..6dfbaac8d --- /dev/null +++ b/fireplace/dsl/hijack.py @@ -0,0 +1,36 @@ +from contextlib import contextmanager +from .selector import Selector + + +class HijackedSelector(Selector): + def __init__(self, *a, **kw): + raise NotImplementedError + + def eval(self, entities, source): + return self._hijack_.eval(entities, source) + + +def hijack(victim, replace): + if victim.__class__ is not HijackedSelector: + victim._truth_ = victim.__class__ + victim.__class__ = HijackedSelector + victim._hijack_ = replace + + +def unhijack(victim): + try: + victim.__class__ = victim._truth_ + except AttributeError as e: + raise ValueError("not a hijacked selector") from e + + +@contextmanager +def hijacked(victim, replace): + if not isinstance(victim, Selector): + raise TypeError("not a selector: %r" % victim) + prev = victim.__class__ + try: + hijack(victim, replace) + yield + finally: + victim.__class__ = prev diff --git a/tests/test_dsl.py b/tests/test_dsl.py index 5bdd20b51..72e7be5cd 100644 --- a/tests/test_dsl.py +++ b/tests/test_dsl.py @@ -1,6 +1,8 @@ #!/usr/bin/env python +import pytest from utils import * from fireplace.dsl import * +from fireplace.exceptions import * from fireplace.card import Card @@ -189,3 +191,11 @@ def test_positional_selectors(): assert len(adjacent) == 2 assert adjacent[0] is wisp2 assert adjacent[1] is wisp3 + + +def test_hijack(): + game = prepare_game() + vial = game.player1.give("LOEA16_8") + with hijacked(RANDOM_ENEMY_MINION, FRIENDLY_HERO): + with pytest.raises(GameOver): + vial.play()