diff --git a/pyRDDLGym_jax/core/tuning.py b/pyRDDLGym_jax/core/tuning.py
index 5f3fab3..5ceb77a 100644
--- a/pyRDDLGym_jax/core/tuning.py
+++ b/pyRDDLGym_jax/core/tuning.py
@@ -8,6 +8,7 @@
 import warnings
 warnings.filterwarnings("ignore")
 
+from sklearn.gaussian_process.kernels import Matern, ConstantKernel
 from bayes_opt import BayesianOptimization
 from bayes_opt.acquisition import AcquisitionFunction, UpperConfidenceBound
 import jax
@@ -71,8 +72,7 @@ def __init__(self, env: RDDLEnv,
                  gp_iters: int=25,
                  acquisition: Optional[AcquisitionFunction]=None,
                  gp_init_kwargs: Optional[Kwargs]=None,
-                 gp_params: Optional[Kwargs]=None,
-                 gp_kernel_params: Optional[Kwargs]=None) -> None:
+                 gp_params: Optional[Kwargs]=None) -> None:
         '''Creates a new instance for tuning hyper-parameters for Jax planners
         on the given RDDL domain and instance.
 
@@ -98,8 +98,6 @@ def __init__(self, env: RDDLEnv,
         during initialization
         :param gp_params: additional parameters to feed to Bayesian optimizer
         after initialization optimization
-        :param gp_kernel_params: additional parameters to feed to the kernel of
-        the Bayesian optimizer
         '''
         # objective parameters
         self.env = env
@@ -122,16 +120,24 @@ def __init__(self, env: RDDLEnv,
             gp_init_kwargs = {}
         self.gp_init_kwargs = gp_init_kwargs
         if gp_params is None:
-            gp_params = {'n_restarts_optimizer': 20}
+            gp_params = {'n_restarts_optimizer': 25,
+                         'kernel': self.make_default_kernel()}
         self.gp_params = gp_params
-        if gp_kernel_params is None:
-            gp_kernel_params = {'length_scale_bounds': (0.2, 20.)}
-        self.gp_kernel_params = gp_kernel_params
         if acquisition is None:
             num_samples = self.gp_iters * self.num_workers
             acquisition = JaxParameterTuning.annealing_acquisition(num_samples)
         self.acquisition = acquisition
 
+    @staticmethod
+    def make_default_kernel():
+        weight1 = ConstantKernel(1.0, (0.01, 100.0))
+        weight2 = ConstantKernel(1.0, (0.01, 100.0))
+        weight3 = ConstantKernel(1.0, (0.01, 100.0))
+        kernel1 = Matern(length_scale=0.5, length_scale_bounds=(0.1, 0.5), nu=2.5)
+        kernel2 = Matern(length_scale=1.0, length_scale_bounds=(0.5, 1.0), nu=2.5)
+        kernel3 = Matern(length_scale=5.0, length_scale_bounds=(1.0, 5.0), nu=2.5)
+        return weight1 * kernel1 + weight2 * kernel2 + weight3 * kernel3
+
     def summarize_hyperparameters(self) -> None:
         hyper_params_table = []
         for (_, param) in self.hyperparams_dict.items():
@@ -141,7 +147,6 @@ def summarize_hyperparameters(self) -> None:
               f' tuned_hyper_parameters =\n{hyper_params_table}\n'
               f' initialization_args ={self.gp_init_kwargs}\n'
               f' gp_params ={self.gp_params}\n'
-              f' gp_kernel_params ={self.gp_kernel_params}\n'
               f' tuning_iterations ={self.gp_iters}\n'
              f' tuning_timeout ={self.timeout_tuning}\n'
               f' tuning_batch_size ={self.num_workers}\n'
@@ -335,8 +340,7 @@ def objective_function(params: ParameterValues,
 
     def tune_optimizer(self, optimizer: BayesianOptimization) -> None:
         '''Tunes the Bayesian optimization algorithm hyper-parameters.'''
-        print('\n' + f'The kernel length_scale was set to '
-              f'{optimizer._gp.kernel_.length_scale}.')
+        print('\n' + f'The current kernel is {repr(optimizer._gp.kernel_)}.')
 
     def tune(self, key: int, log_file: str, show_dashboard: bool=False) -> ParameterValues:
         '''Tunes the hyper-parameters for Jax planner, returns the best found.'''
@@ -409,11 +413,6 @@ def tune(self, key: int, log_file: str, show_dashboard: bool=False) -> Parameter
 
         for it in range(self.gp_iters):
 
-            # set the kernel parameters to user requested
-            optimizer._gp.kernel.set_params(**self.gp_kernel_params)
-            if hasattr(optimizer._gp, 'kernel_'):
-                optimizer._gp.kernel_.set_params(**self.gp_kernel_params)
-
             # check if there is enough time left for another iteration
             elapsed = time.time() - start_time
             if elapsed >= self.timeout_tuning:
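Note on the change (a standalone sketch, not part of the diff): the fixed user-facing gp_kernel_params option is dropped, and the new default kernel is instead a weighted sum of three Matern(nu=2.5) components covering short, medium, and long length scales, handed to the optimizer through gp_params alongside n_restarts_optimizer. The snippet below illustrates how that composite kernel behaves when fitted by scikit-learn's GaussianProcessRegressor (the model class that bayes_opt keeps in optimizer._gp); the toy data, objective, and normalize_y flag are illustrative assumptions only.

# Illustrative sketch: composite kernel equivalent to make_default_kernel(),
# fitted on toy data with scikit-learn's GaussianProcessRegressor.
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ConstantKernel, Matern

def make_default_kernel():
    # weighted sum of short-, medium-, and long-range Matern(nu=2.5) components
    weights = [ConstantKernel(1.0, (0.01, 100.0)) for _ in range(3)]
    kernels = [Matern(length_scale=0.5, length_scale_bounds=(0.1, 0.5), nu=2.5),
               Matern(length_scale=1.0, length_scale_bounds=(0.5, 1.0), nu=2.5),
               Matern(length_scale=5.0, length_scale_bounds=(1.0, 5.0), nu=2.5)]
    return (weights[0] * kernels[0]
            + weights[1] * kernels[1]
            + weights[2] * kernels[2])

# toy 1-D "hyper-parameter vs. score" data standing in for tuning observations
X = np.linspace(0.0, 10.0, 25).reshape(-1, 1)
y = np.sin(X).ravel() + 0.1 * X.ravel()

gp = GaussianProcessRegressor(kernel=make_default_kernel(),
                              n_restarts_optimizer=25,
                              normalize_y=True)
gp.fit(X, y)

# gp.kernel_ holds the fitted weights and length scales; this is what the
# revised tune_optimizer() now prints after each tuning iteration
print(repr(gp.kernel_))
mean, std = gp.predict(np.array([[2.5]]), return_std=True)
print(f'prediction at 2.5: mean={mean[0]:.3f}, std={std[0]:.3f}')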