Replace kernel with mixture kernel
mike-gimelfarb committed Dec 21, 2024
1 parent 1f1de53 commit 5665d5d
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions pyRDDLGym_jax/core/tuning.py
@@ -8,6 +8,7 @@
 import warnings
 warnings.filterwarnings("ignore")
 
+from sklearn.gaussian_process.kernels import Matern, ConstantKernel
 from bayes_opt import BayesianOptimization
 from bayes_opt.acquisition import AcquisitionFunction, UpperConfidenceBound
 import jax
@@ -71,8 +72,7 @@ def __init__(self, env: RDDLEnv,
                  gp_iters: int=25,
                  acquisition: Optional[AcquisitionFunction]=None,
                  gp_init_kwargs: Optional[Kwargs]=None,
-                 gp_params: Optional[Kwargs]=None,
-                 gp_kernel_params: Optional[Kwargs]=None) -> None:
+                 gp_params: Optional[Kwargs]=None) -> None:
         '''Creates a new instance for tuning hyper-parameters for Jax planners
         on the given RDDL domain and instance.
@@ -98,8 +98,6 @@ def __init__(self, env: RDDLEnv,
             during initialization
         :param gp_params: additional parameters to feed to Bayesian optimizer
             after initialization optimization
-        :param gp_kernel_params: additional parameters to feed to the kernel of
-            the Bayesian optimizer
         '''
         # objective parameters
         self.env = env
@@ -122,16 +120,24 @@
             gp_init_kwargs = {}
         self.gp_init_kwargs = gp_init_kwargs
         if gp_params is None:
-            gp_params = {'n_restarts_optimizer': 20}
+            gp_params = {'n_restarts_optimizer': 25,
+                         'kernel': self.make_default_kernel()}
         self.gp_params = gp_params
-        if gp_kernel_params is None:
-            gp_kernel_params = {'length_scale_bounds': (0.2, 20.)}
-        self.gp_kernel_params = gp_kernel_params
         if acquisition is None:
             num_samples = self.gp_iters * self.num_workers
             acquisition = JaxParameterTuning.annealing_acquisition(num_samples)
         self.acquisition = acquisition
 
+    @staticmethod
+    def make_default_kernel():
+        weight1 = ConstantKernel(1.0, (0.01, 100.0))
+        weight2 = ConstantKernel(1.0, (0.01, 100.0))
+        weight3 = ConstantKernel(1.0, (0.01, 100.0))
+        kernel1 = Matern(length_scale=0.5, length_scale_bounds=(0.1, 0.5), nu=2.5)
+        kernel2 = Matern(length_scale=1.0, length_scale_bounds=(0.5, 1.0), nu=2.5)
+        kernel3 = Matern(length_scale=5.0, length_scale_bounds=(1.0, 5.0), nu=2.5)
+        return weight1 * kernel1 + weight2 * kernel2 + weight3 * kernel3
+
     def summarize_hyperparameters(self) -> None:
         hyper_params_table = []
         for (_, param) in self.hyperparams_dict.items():
@@ -141,7 +147,6 @@ def summarize_hyperparameters(self) -> None:
               f' tuned_hyper_parameters =\n{hyper_params_table}\n'
               f' initialization_args ={self.gp_init_kwargs}\n'
               f' gp_params ={self.gp_params}\n'
-              f' gp_kernel_params ={self.gp_kernel_params}\n'
               f' tuning_iterations ={self.gp_iters}\n'
               f' tuning_timeout ={self.timeout_tuning}\n'
               f' tuning_batch_size ={self.num_workers}\n'
@@ -335,8 +340,7 @@ def objective_function(params: ParameterValues,
 
     def tune_optimizer(self, optimizer: BayesianOptimization) -> None:
         '''Tunes the Bayesian optimization algorithm hyper-parameters.'''
-        print('\n' + f'The kernel length_scale was set to '
-              f'{optimizer._gp.kernel_.length_scale}.')
+        print('\n' + f'The current kernel is {repr(optimizer._gp.kernel_)}.')
 
     def tune(self, key: int, log_file: str, show_dashboard: bool=False) -> ParameterValues:
         '''Tunes the hyper-parameters for Jax planner, returns the best found.'''
@@ -409,11 +413,6 @@ def tune(self, key: int, log_file: str, show_dashboard: bool=False) -> ParameterValues:
 
         for it in range(self.gp_iters):
 
-            # set the kernel parameters to user requested
-            optimizer._gp.kernel.set_params(**self.gp_kernel_params)
-            if hasattr(optimizer._gp, 'kernel_'):
-                optimizer._gp.kernel_.set_params(**self.gp_kernel_params)
-
             # check if there is enough time left for another iteration
             elapsed = time.time() - start_time
             if elapsed >= self.timeout_tuning:
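
For context, here is a minimal standalone sketch of the mixture kernel that make_default_kernel constructs, fitted with scikit-learn directly. The toy data, GaussianProcessRegressor settings, and seed are illustrative assumptions and not part of this commit; only the kernel algebra mirrors the diff above.

# Illustrative sketch only: toy data and regressor settings are assumptions;
# the kernel construction mirrors make_default_kernel in the diff above.
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ConstantKernel, Matern

# three weighted Matern(nu=2.5) components covering short, medium and long length scales
kernel = (ConstantKernel(1.0, (0.01, 100.0)) * Matern(length_scale=0.5, length_scale_bounds=(0.1, 0.5), nu=2.5) +
          ConstantKernel(1.0, (0.01, 100.0)) * Matern(length_scale=1.0, length_scale_bounds=(0.5, 1.0), nu=2.5) +
          ConstantKernel(1.0, (0.01, 100.0)) * Matern(length_scale=5.0, length_scale_bounds=(1.0, 5.0), nu=2.5))

# fit on toy data and print the optimized kernel, as the new tune_optimizer does
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 10.0, size=(30, 1))
y = np.sin(X).ravel() + 0.1 * rng.standard_normal(30)
gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=5, normalize_y=True)
gp.fit(X, y)
print(repr(gp.kernel_))  # sum of three weighted Matern terms with fitted hyper-parameters

Mixing components with short, medium, and long length-scale bounds lets the surrogate capture correlations at several scales rather than committing to a single length scale, which appears to be the motivation for replacing the single kernel.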

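Since the gp_kernel_params argument has been removed, a caller who wants something other than the default mixture would presumably pass a complete kernel object through gp_params instead. A hedged sketch of that pattern follows; the single weighted-Matern choice and all values in it are assumptions for illustration, and only the dictionary keys correspond to the defaults set in __init__ above.

# Hypothetical override now that gp_kernel_params is gone: supply a whole kernel
# through gp_params rather than individual kernel parameters.  All values here
# are assumptions for illustration.
from sklearn.gaussian_process.kernels import ConstantKernel, Matern

custom_gp_params = {
    'n_restarts_optimizer': 10,
    'kernel': ConstantKernel(1.0, (0.01, 100.0)) *
              Matern(length_scale=1.0, length_scale_bounds=(0.2, 20.0), nu=2.5),
}
# custom_gp_params would then be passed as the gp_params argument of
# JaxParameterTuning; how the tuner forwards it to its Gaussian process is
# outside the scope of this diff.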