diff --git a/skdecide/hub/solver/ars/ars.py b/skdecide/hub/solver/ars/ars.py
index fcca08b002..c781340360 100755
--- a/skdecide/hub/solver/ars/ars.py
+++ b/skdecide/hub/solver/ars/ars.py
@@ -74,6 +74,8 @@ def flatten(c):
 
 
 class AugmentedRandomSearch(Solver, Policies, Restorable):
+    """Augmented Random Search solver."""
+
     T_domain = D
 
     def __init__(
@@ -85,7 +87,22 @@ def __init__(
         learning_rate=0.02,
         policy_noise=0.03,
         reward_maximization=True,
+        callback: Callable[[AugmentedRandomSearch], bool] = lambda solver: False,
     ) -> None:
+        """
+
+        # Parameters
+        n_epochs
+        epoch_size
+        directions
+        top_directions
+        learning_rate
+        policy_noise
+        reward_maximization
+        callback: function called at each solver epoch. If it returns True, the solve process stops.
+
+        """
+        self.callback = callback
         self.env = None
         self.n_epochs = n_epochs
         self.learning_rate = learning_rate
@@ -216,10 +233,16 @@ def _solve(self, domain_factory: Callable[[], D]) -> None:
             self.update_policy(rollouts, sigma_r)
 
             # Printing the final reward of the policy after the update
-            reward_evaluation = self.explore(normalizer)
-            print("Step:", step, "Reward:", reward_evaluation, "Policy", self.policy)
+            self.reward_evaluation = self.explore(normalizer)
+            print(
+                "Step:", step, "Reward:", self.reward_evaluation, "Policy", self.policy
+            )
+
+            # Stopping because of user's callback?
+            if self.callback(self):
+                break
 
-        print("Final Reward:", reward_evaluation, "Policy", self.policy)
+        print("Final Reward:", self.reward_evaluation, "Policy", self.policy)
 
     def _sample_action(
         self, observation: D.T_agent[D.T_observation]
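
Typical usage of the new `callback` parameter: the solver passes itself to the callback at the end of each epoch and stops as soon as the callback returns True. A minimal sketch for AugmentedRandomSearch (the reward threshold and solver settings are illustrative assumptions; `reward_evaluation` is the attribute stored on the solver by this patch):

from skdecide.hub.solver.ars import AugmentedRandomSearch

def stop_when_good_enough(solver: AugmentedRandomSearch) -> bool:
    # reward_evaluation is refreshed every epoch before the callback fires
    return solver.reward_evaluation >= 200.0  # hypothetical target reward

solver = AugmentedRandomSearch(n_epochs=1000, callback=stop_when_good_enough)
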
+ + """ + self.callback = callback if library is None: library = self._get_default_function_lib() @@ -296,21 +315,23 @@ def _solve(self, domain_factory: Callable[[], D]) -> None: print(cgpFather.genome) es = CGPES( - self._nb_ind, - self._mutation_rate_nodes, - self._mutation_rate_outputs, - cgpFather, - evaluator, - self._folder_name, - self._n_cpus, + num_offsprings=self._nb_ind, + mutation_rate_nodes=self._mutation_rate_nodes, + mutation_rate_outputs=self._mutation_rate_outputs, + father=cgpFather, + evaluator=evaluator, + folder=self._folder_name, + num_cpus=self._n_cpus, verbose=self._verbose, + callback=self.callback, + cgpwrapper=self, ) - es.run(self._n_it) - self._domain = domain self._es = es self._evaluator = evaluator + es.run(self._n_it) + def _get_next_action( self, observation: D.T_agent[D.T_observation] ) -> D.T_agent[D.T_concurrency[D.T_event]]: diff --git a/skdecide/hub/solver/cgp/pycgp/cgpes.py b/skdecide/hub/solver/cgp/pycgp/cgpes.py index a8c3312314..f5a0a1d96b 100644 --- a/skdecide/hub/solver/cgp/pycgp/cgpes.py +++ b/skdecide/hub/solver/cgp/pycgp/cgpes.py @@ -2,7 +2,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations # allow using CGPWrapper in annotations + import os +from typing import TYPE_CHECKING, Callable import numpy as np from joblib import Parallel, delayed @@ -10,6 +13,9 @@ from .cgp import CGP from .evaluator import Evaluator +if TYPE_CHECKING: # avoids circular imports + from ..cgp import CGPWrapper + class CGPES: def __init__( @@ -19,10 +25,14 @@ def __init__( mutation_rate_outputs, father, evaluator, + cgpwrapper: CGPWrapper, + callback: Callable[[CGPWrapper], bool], folder="genomes", num_cpus=1, verbose=True, ): + self.callback = callback + self.cgpwrapper = cgpwrapper self.num_offsprings = num_offsprings self.mutation_rate_nodes = mutation_rate_nodes self.mutation_rate_outputs = mutation_rate_outputs @@ -116,3 +126,6 @@ def offspring_eval_task(offspring_id): + str(self.current_fitness) + ".txt" ) + # Stopping because of user's callback? 
diff --git a/skdecide/hub/solver/do_solver/do_solver_scheduling.py b/skdecide/hub/solver/do_solver/do_solver_scheduling.py
index 553a71ba7f..4168d49313 100644
--- a/skdecide/hub/solver/do_solver/do_solver_scheduling.py
+++ b/skdecide/hub/solver/do_solver/do_solver_scheduling.py
@@ -166,7 +166,7 @@ def __init__(
         policy_method_params: PolicyMethodParams,
         method: SolvingMethod = SolvingMethod.PILE,
         dict_params: Optional[Dict[Any, Any]] = None,
-        callback: Optional[Callable[[DOSolver], bool]] = None,
+        callback: Callable[[DOSolver], bool] = lambda solver: False,
     ):
         self.callback = callback
         self.method = method
@@ -206,10 +206,7 @@ def _solve(self, domain_factory: Callable[[], D]) -> None:
             self.dict_params[k] = params[k]
 
         # callbacks
-        if self.callback is None:
-            callbacks = []
-        else:
-            callbacks = [_DOCallback(callback=self.callback, solver=self)]
+        callbacks = [_DOCallback(callback=self.callback, solver=self)]
         copy_dict_params = deepcopy(self.dict_params)
         if "callbacks" in copy_dict_params:
             callbacks = callbacks + copy_dict_params.pop("callbacks")
diff --git a/skdecide/hub/solver/lazy_astar/lazy_astar.py b/skdecide/hub/solver/lazy_astar/lazy_astar.py
index 13a8d9c637..b53a0f2ce7 100644
--- a/skdecide/hub/solver/lazy_astar/lazy_astar.py
+++ b/skdecide/hub/solver/lazy_astar/lazy_astar.py
@@ -38,6 +38,8 @@ class D(
 
 
 class LazyAstar(Solver, DeterministicPolicies, Utilities, FromAnyState):
+    """Lazy A* solver."""
+
     T_domain = D
 
     def __init__(
@@ -48,8 +50,19 @@ def __init__(
         weight: float = 1.0,
         verbose: bool = False,
         render: bool = False,
+        callback: Callable[[LazyAstar], bool] = lambda solver: False,
     ) -> None:
+        """
+
+        # Parameters
+        heuristic
+        weight
+        verbose
+        render
+        callback: function called at each solver iteration. If it returns True, the solve process stops.
+        """
+        self.callback = callback
         self._heuristic = (
             (lambda _, __: Value(cost=0.0)) if heuristic is None else heuristic
         )
@@ -136,18 +149,19 @@ def extender(node, label, explored):
         }
         # enqueued = {source: min([(0, self._weight * self._heuristic(source, target, initial_label[source]).cost)
         #             for target in targets], key=lambda x: x[1]) for source in sources}
-        queue = [
+        self.queue = [
             (enqueued[source][1], next(c), source, 0, None, initial_label[source])
             for source in sources
         ]
         # The explored dict is the CLOSED list.
         # It maps explored nodes to a pair of parent closest to the source and label of transition from parent.
-        explored = {}
+        self.explored = {}
         path = []
         estim_total = 0.0
-        while queue:
+        while self.queue and not self.callback(self):
             # Pop the smallest item from queue, i.e. with smallest f-value
-            estim_total, __, curnode, dist, parent, label = pop(queue)
+            estim_total, __, curnode, dist, parent, label = pop(self.queue)
+
             if self._render:
                 self._domain.render(curnode)
             if self._verbose:
@@ -159,16 +173,16 @@ def extender(node, label, explored):
                 path = [(parent, label), (curnode, None)]
                 node = parent
                 while node is not None:
-                    (parent, label) = explored[node]
+                    (parent, label) = self.explored[node]
                     if parent is not None:
                         path.insert(0, (parent, label))
                     node = parent
                 break
             # return path, dist, enqueued[curnode][0], len(enqueued)
-            if curnode in explored:
+            if curnode in self.explored:
                 continue
-            explored[curnode] = (parent, label)
-            for neighbor, cost, lbl in extender(curnode, label, explored):
-                if neighbor in explored:
+            self.explored[curnode] = (parent, label)
+            for neighbor, cost, lbl in extender(curnode, label, self.explored):
+                if neighbor in self.explored:
                     continue
                 ncost = dist + cost
                 if neighbor in enqueued:
@@ -184,7 +198,7 @@ def extender(node, label, explored):
                 h = self._heuristic(self._domain, neighbor).cost
                 enqueued[neighbor] = ncost, h
                 push(
-                    queue,
+                    self.queue,
                     (
                         ncost + (self._weight * h),
                         next(c),
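
Since `queue` (the OPEN list) and `explored` (the CLOSED list) become solver attributes above, a callback can inspect the search state while it runs; for instance, a sketch of a node-budget cut-off (the budget value is an arbitrary assumption):

from skdecide.hub.solver.lazy_astar import LazyAstar

def node_budget_exceeded(solver: LazyAstar) -> bool:
    # self.explored is the CLOSED list exposed by this patch
    return len(solver.explored) > 10_000  # hypothetical node budget

solver = LazyAstar(callback=node_budget_exceeded)
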
diff --git a/skdecide/hub/solver/lrtastar/lrtastar.py b/skdecide/hub/solver/lrtastar/lrtastar.py
index 86b2d93b81..5a4692a7ea 100644
--- a/skdecide/hub/solver/lrtastar/lrtastar.py
+++ b/skdecide/hub/solver/lrtastar/lrtastar.py
@@ -36,6 +36,8 @@ class D(
 
 
 class LRTAstar(Solver, DeterministicPolicies, Utilities, FromAnyState):
+    """Learning Real-Time A* solver."""
+
     T_domain = D
 
     def _get_next_action(
@@ -60,7 +62,20 @@ def __init__(
         verbose: bool = False,
         max_iter=5000,
         max_depth=200,
+        callback: Callable[[LRTAstar], bool] = lambda solver: False,
     ) -> None:
+        """
+
+        # Parameters
+        heuristic
+        weight
+        verbose
+        max_iter
+        max_depth
+        callback: function called at each solver iteration. If it returns True, the solve process stops.
+
+        """
+        self.callback = callback
         self._heuristic = (
             (lambda _, __: Value(cost=0.0)) if heuristic is None else heuristic
         )
@@ -113,7 +128,7 @@ def _solve_from(
         iteration = 0
         best_cost = float("inf")
         # best_path = None
-        while True:
+        while not self.callback(self):
             print(memory)
             dead_end, cumulated_cost, current_roll, list_action = self.doTrial(memory)
             if self._verbose:
diff --git a/skdecide/hub/solver/maxent_irl/maxent_irl.py b/skdecide/hub/solver/maxent_irl/maxent_irl.py
index c9555aab4b..6550d74933 100755
--- a/skdecide/hub/solver/maxent_irl/maxent_irl.py
+++ b/skdecide/hub/solver/maxent_irl/maxent_irl.py
@@ -22,6 +22,8 @@ class D(RLDomain):
 
 
 class MaxentIRL(Solver, Policies, Restorable):
+    """Maximum Entropy Inverse Reinforcement Learning solver."""
+
     T_domain = D
 
     def __init__(
@@ -34,7 +36,23 @@ def __init__(
         theta_learning_rate=0.05,
         n_epochs=20000,
         expert_trajectories="maxent_expert_demo.npy",
+        callback: Callable[[MaxentIRL], bool] = lambda solver: False,
     ) -> None:
+        """
+
+        # Parameters
+        n_states
+        n_actions
+        one_feature
+        gamma
+        q_learning_rate
+        theta_learning_rate
+        n_epochs
+        expert_trajectories
+        callback: function called at each solver epoch. If it returns True, the solve process stops.
+
+        """
+        self.callback = callback
         self.n_states = n_states
         self.feature_matrix = np.eye(self.n_states)
         self.n_actions = n_actions
@@ -227,6 +245,10 @@ def _solve(self, domain_factory: Callable[[], D]) -> None:
                 arr=self.q_table,
             )
 
+            # Stopping because of user's callback?
+            if self.callback(self):
+                break
+
         self.q_table = np.load(
             file=self.expert_trajectories[:-4] + "_maxent_q_table.npy"
         )
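
Callbacks can also be stateful objects, as the test at the end of this diff illustrates with an iteration counter; another common criterion is a wall-clock budget. A sketch that works with any of the solvers touched here (the 30 s default is an arbitrary assumption):

import time

class TimeBudget:
    def __init__(self, seconds: float = 30.0):
        # simplification: the clock starts at construction, not at solve() time
        self.deadline = time.perf_counter() + seconds

    def __call__(self, solver) -> bool:
        # returning True asks the solver to stop
        return time.perf_counter() > self.deadline
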
diff --git a/skdecide/hub/solver/pomcp/pomcp.py b/skdecide/hub/solver/pomcp/pomcp.py
index b73f0c5ff0..34ec16972c 100644
--- a/skdecide/hub/solver/pomcp/pomcp.py
+++ b/skdecide/hub/solver/pomcp/pomcp.py
@@ -36,9 +36,27 @@ class D(
 
 
 class POMCP(Solver, DeterministicPolicies):
+    """Partially-Observable Monte Carlo Planning solver."""
+
     T_domain = D
 
-    def __init__(self, max_iterations=5000, max_depth=50, n_samples=5000) -> None:
+    def __init__(
+        self,
+        max_iterations=5000,
+        max_depth=50,
+        n_samples=5000,
+        callback: Callable[[POMCP], bool] = lambda solver: False,
+    ) -> None:
+        """
+
+        # Parameters
+        max_iterations
+        max_depth
+        n_samples
+        callback: function called at each solver iteration. If it returns True, the solve process stops.
+
+        """
+        self.callback = callback
         self._max_iterations = max_iterations
         self._max_depth = max_depth
         self._n_samples = n_samples
@@ -86,7 +104,9 @@ def _get_next_action(
         # Now, we can make a decision from the new belief state:
 
         iterations = 0
-        while iterations < self._max_iterations:  # or some other cut-off
+        while iterations < self._max_iterations and not self.callback(
+            self
+        ):  # or some other cut-off
             # sample a state from the current belief
             state = random.choice(self._belief)
             self._tree_search(state, self._act_history, self._obs_history, 0)
diff --git a/skdecide/hub/solver/stable_baselines/stable_baselines.py b/skdecide/hub/solver/stable_baselines/stable_baselines.py
index 2a42145898..a61ec6ee7a 100644
--- a/skdecide/hub/solver/stable_baselines/stable_baselines.py
+++ b/skdecide/hub/solver/stable_baselines/stable_baselines.py
@@ -7,7 +7,9 @@
 from typing import Any, Callable, Dict, Optional, Type, Union
 
 from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.callbacks import BaseCallback, ConvertCallback
 from stable_baselines3.common.policies import BasePolicy
+from stable_baselines3.common.type_aliases import MaybeCallback
 from stable_baselines3.common.vec_env import DummyVecEnv
 
 from skdecide import Domain, Solver
@@ -40,6 +42,7 @@ def __init__(
         algo_class: Type[BaseAlgorithm],
         baselines_policy: Union[str, Type[BasePolicy]],
         learn_config: Optional[Dict[str, Any]] = None,
+        callback: Callable[[StableBaseline], bool] = lambda solver: False,
         **kwargs: Any,
     ) -> None:
         """Initialize StableBaselines.
@@ -47,11 +50,15 @@ def __init__(
         # Parameters
         algo_class: The class of Baselines solver (stable_baselines3) to wrap.
         baselines_policy: The class of Baselines policy network (stable_baselines3.common.policies or str) to use.
+        learn_config: The kwargs passed to the sb3 algo's `learn()` method.
+        callback: function called at each solver iteration. If it returns True, the solve process stops.
+
         """
         self._algo_class = algo_class
         self._baselines_policy = baselines_policy
         self._learn_config = learn_config if learn_config is not None else {}
         self._algo_kwargs = kwargs
+        self.callback = callback
 
     @classmethod
     def _check_domain_additional(cls, domain: Domain) -> bool:
@@ -74,7 +81,20 @@ def _solve(self, domain_factory: Callable[[], D]) -> None:
             self._baselines_policy, env, **self._algo_kwargs
         )
         self._init_algo(domain)
-        self._algo.learn(**self._learn_config)
+
+        # Add user callback to list of callbacks in learn_config
+        learn_config = dict(self._learn_config)
+        callbacks_list: MaybeCallback = learn_config.get("callback", [])
+        if callbacks_list is None:
+            callbacks_list = []
+        if isinstance(callbacks_list, BaseCallback):
+            callbacks_list = [callbacks_list]
+        elif not isinstance(callbacks_list, list):
+            callbacks_list = [ConvertCallback(callbacks_list)]
+        callbacks_list.append(Sb3Callback(callback=self.callback, solver=self))
+        learn_config["callback"] = callbacks_list
+
+        self._algo.learn(**learn_config)
 
     def _sample_action(
         self, observation: D.T_agent[D.T_observation]
@@ -105,3 +125,16 @@ def _init_algo(self, domain: D):
     def get_policy(self) -> BasePolicy:
         """Return the computed policy."""
         return self._algo.policy
+
+
+class Sb3Callback(BaseCallback):
+    def __init__(
+        self, callback: Callable[[StableBaseline], bool], solver: StableBaseline
+    ):
+        super().__init__()
+        self.solver = solver
+        self.callback = callback
+
+    def _on_step(self) -> bool:
+        # sb3 keeps training while _on_step() returns True, hence the negation
+        return not self.callback(self.solver)
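
With the merging logic above, the scikit-decide callback coexists with native sb3 callbacks passed through `learn_config`: a `BaseCallback` instance is wrapped in a list, a bare callable goes through `ConvertCallback`, and `Sb3Callback` is appended in all cases. A usage sketch (algorithm, policy, timesteps and checkpoint settings are illustrative):

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback

from skdecide.hub.solver.stable_baselines import StableBaseline

solver = StableBaseline(
    algo_class=PPO,
    baselines_policy="MlpPolicy",
    learn_config={
        "total_timesteps": 10_000,
        # native sb3 callback, merged with the solver-level one below
        "callback": CheckpointCallback(save_freq=1_000, save_path="./ckpt"),
    },
    callback=lambda slv: False,  # scikit-decide callback, wrapped in Sb3Callback
)
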
+ """ self._algo_class = algo_class self._baselines_policy = baselines_policy self._learn_config = learn_config if learn_config is not None else {} self._algo_kwargs = kwargs + self.callback = callback @classmethod def _check_domain_additional(cls, domain: Domain) -> bool: @@ -74,7 +81,20 @@ def _solve(self, domain_factory: Callable[[], D]) -> None: self._baselines_policy, env, **self._algo_kwargs ) self._init_algo(domain) - self._algo.learn(**self._learn_config) + + # Add user callback to list of callbacks in learn_config + learn_config = dict(self._learn_config) + callbacks_list: MaybeCallback = learn_config.get("callback", []) + if callbacks_list is None: + callbacks_list = [] + if isinstance(callbacks_list, BaseCallback): + callbacks_list = [callbacks_list] + elif not isinstance(callbacks_list, list): + callbacks_list = [ConvertCallback(callbacks_list)] + callbacks_list.append(Sb3Callback(callback=self.callback, solver=self)) + learn_config["callback"] = callbacks_list + + self._algo.learn(**learn_config) def _sample_action( self, observation: D.T_agent[D.T_observation] @@ -105,3 +125,15 @@ def _init_algo(self, domain: D): def get_policy(self) -> BasePolicy: """Return the computed policy.""" return self._algo.policy + + +class Sb3Callback(BaseCallback): + def __init__( + self, callback: Callable[[StableBaseline], bool], solver: StableBaseline + ): + super().__init__() + self.solver = solver + self.callback = callback + + def _on_step(self) -> bool: + return not self.callback(self.solver) diff --git a/tests/solvers/python/test_python_solvers.py b/tests/solvers/python/test_python_solvers.py index 509b77a958..f4f6c9ceba 100644 --- a/tests/solvers/python/test_python_solvers.py +++ b/tests/solvers/python/test_python_solvers.py @@ -4,6 +4,8 @@ from __future__ import annotations +import inspect +import logging from copy import deepcopy from enum import Enum from typing import NamedTuple, Optional @@ -17,6 +19,7 @@ from skdecide.hub.space.gym import EnumSpace, MultiDiscreteSpace from skdecide.utils import load_registered_solver +logger = logging.getLogger(__name__) # Must be defined outside the grid_domain() fixture # so that parallel domains can pickle it @@ -170,3 +173,50 @@ def test_solve_python(solver_python): assert solver_type.check_domain(dom) and ( (not solver_python["optimal"]) or (cost == 18 and len(plan) == 18) ) + + +class MyCallback: + """Callback for testing. + + - displays iteration number + - stops after max iteration reached + - check classes of domain and solver + + """ + + def __init__(self, solver_cls, max_iter=2): + self.solver_cls = solver_cls + self.max_iter = max_iter + self.iter = 0 + + def __call__(self, solver, *args): + self.iter += 1 + logger.warning(f"End of iteration #{self.iter}.") + assert isinstance(solver, self.solver_cls) + stopping = self.iter >= self.max_iter + return stopping + + +def test_solve_python_with_cb(solver_python, caplog): + solver_type = load_registered_solver(solver_python["entry"]) + if "callback" not in inspect.signature(solver_type.__init__).parameters: + pytest.skip( + f"Solver {solver_python['entry']} is not yet implementing callbacks." 
+ ) + + dom = GridDomain() + + solver_args = deepcopy(solver_python["config"]) + if solver_python["entry"] == "StableBaseline": + solver_args["algo_class"] = PPO + elif solver_python["entry"] == "RayRLlib": + solver_args["algo_class"] = DQN + # Adding the callback + solver_args["callback"] = MyCallback(solver_cls=solver_type) + + with solver_type(**solver_args) as slv: + GridDomain.solve_with(slv) + + # Check that 2 iterations only were done and messages logged by callback + assert "End of iteration #2" in caplog.text + assert "End of iteration #3" not in caplog.text
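
Because a callback is just a callable taking the solver and returning a bool, stopping criteria compose naturally; a small sketch combining several criteria with OR semantics (names reuse the illustrative helpers from the sketches above):

def any_of(*criteria):
    # composite callback: stop as soon as any criterion asks to stop
    def composed(solver) -> bool:
        return any(criterion(solver) for criterion in criteria)
    return composed

# e.g.: LazyAstar(callback=any_of(TimeBudget(60.0), node_budget_exceeded))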