Fixed variable naming consistency in planner
mike-gimelfarb committed Dec 19, 2024
1 parent 353bfdf commit 1c5edde
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions pyRDDLGym_jax/core/planner.py
@@ -1161,8 +1161,8 @@ def __init__(self, rddl: RDDLLiftedModel,
                  optimizer: Callable[..., optax.GradientTransformation]=optax.rmsprop,
                  optimizer_kwargs: Optional[Kwargs]=None,
                  clip_grad: Optional[float]=None,
-                 line_search_params: Optional[Kwargs]=None,
-                 noise_params: Optional[Kwargs]=None,
+                 line_search_kwargs: Optional[Kwargs]=None,
+                 noise_kwargs: Optional[Kwargs]=None,
                  logic: Logic=FuzzyLogic(),
                  use_symlog_reward: bool=False,
                  utility: Union[Callable[[jnp.ndarray], float], str]='mean',
@@ -1190,9 +1190,9 @@ def __init__(self, rddl: RDDLLiftedModel,
         :param optimizer_kwargs: a dictionary of parameters to pass to the SGD
             factory (e.g. which parameters are controllable externally)
         :param clip_grad: maximum magnitude of gradient updates
-        :param line_search_params: parameters to pass to optional line search
+        :param line_search_kwargs: parameters to pass to optional line search
             method to scale learning rate
-        :param noise_params: parameters of optional gradient noise
+        :param noise_kwargs: parameters of optional gradient noise
         :param logic: a subclass of Logic for mapping exact mathematical
             operations to their differentiable counterparts
         :param use_symlog_reward: whether to use the symlog transform on the
@@ -1224,13 +1224,13 @@ def __init__(self, rddl: RDDLLiftedModel,
             action_bounds = {}
         self._action_bounds = action_bounds
         self.use64bit = use64bit
-        self._optimizer_name = optimizer
+        self.optimizer_name = optimizer
         if optimizer_kwargs is None:
             optimizer_kwargs = {'learning_rate': 0.1}
-        self._optimizer_kwargs = optimizer_kwargs
+        self.optimizer_kwargs = optimizer_kwargs
         self.clip_grad = clip_grad
-        self.ls_params = line_search_params
-        self.noise_params = noise_params
+        self.line_search_kwargs = line_search_kwargs
+        self.noise_kwargs = noise_kwargs
 
         # set optimizer
         try:
@@ -1246,12 +1246,11 @@ def __init__(self, rddl: RDDLLiftedModel,
         pipeline = []
         if clip_grad is not None:
             pipeline.append(optax.clip(clip_grad))
-        if noise_params is not None:
-            pipeline.append(optax.add_noise(**noise_params))
+        if noise_kwargs is not None:
+            pipeline.append(optax.add_noise(**noise_kwargs))
         pipeline.append(optimizer)
-        self._use_ls = line_search_params is not None
-        if self._use_ls:
-            pipeline.append(optax.scale_by_zoom_linesearch(**line_search_params))
+        if line_search_kwargs is not None:
+            pipeline.append(optax.scale_by_zoom_linesearch(**line_search_kwargs))
         self.optimizer = optax.chain(*pipeline)
 
         # set utility
@@ -1330,11 +1329,11 @@ def __str__(self) -> str:
             f'    cpfs_no_gradient  ={self.cpfs_without_grad}\n'
             f'optimizer hyper-parameters:\n'
             f'    use_64_bit        ={self.use64bit}\n'
-            f'    optimizer         ={self._optimizer_name}\n'
-            f'    optimizer args    ={self._optimizer_kwargs}\n'
+            f'    optimizer         ={self.optimizer_name}\n'
+            f'    optimizer args    ={self.optimizer_kwargs}\n'
             f'    clip_gradient     ={self.clip_grad}\n'
-            f'    line_search_params={self.ls_params}\n'
-            f'    noise_params      ={self.noise_params}\n'
+            f'    line_search_kwargs={self.line_search_kwargs}\n'
+            f'    noise_kwargs      ={self.noise_kwargs}\n'
             f'    batch_size_train  ={self.batch_size_train}\n'
             f'    batch_size_test   ={self.batch_size_test}')
         result += '\n' + str(self.plan)
@@ -1469,6 +1468,7 @@ def _jax_wrapped_init_policy(key, policy_hyperparams, subs):
     def _jax_update(self, loss):
         optimizer = self.optimizer
         projection = self.plan.projection
+        use_ls = self.line_search_kwargs is not None
 
         # calculate the plan gradient w.r.t. return loss and update optimizer
         # also perform a projection step to satisfy constraints on actions
@@ -1481,7 +1481,7 @@ def _jax_wrapped_plan_update(key, policy_params, policy_hyperparams,
             grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)
             (loss_val, (log, model_params)), grad = grad_fn(
                 key, policy_params, policy_hyperparams, subs, model_params)
-            if self._use_ls:
+            if use_ls:
                 updates, opt_state = optimizer.update(
                     grad, opt_state, params=policy_params,
                     value=loss_val, grad=grad, value_fn=_jax_wrapped_loss_swapped,
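
For downstream code, only the keyword names change. A minimal migration sketch under stated assumptions: the diff never shows the enclosing class, so the names JaxBackpropPlanner and JaxStraightLinePlan are assumed, `model` stands for an RDDLLiftedModel loaded elsewhere, and all hyper-parameter values are illustrative rather than library defaults.

    from pyRDDLGym_jax.core.planner import JaxBackpropPlanner, JaxStraightLinePlan  # assumed names

    # 'model' is assumed to be an RDDLLiftedModel instance loaded elsewhere
    planner = JaxBackpropPlanner(
        rddl=model,
        plan=JaxStraightLinePlan(),
        clip_grad=1.0,
        optimizer_kwargs={'learning_rate': 0.1},
        noise_kwargs={'eta': 0.01, 'gamma': 0.55, 'seed': 42},    # was noise_params
        line_search_kwargs={'max_linesearch_steps': 15})          # was line_search_params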

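The last two hunks are the behavioral part of the change: the cached self._use_ls flag is dropped, and use_ls is now derived from line_search_kwargs inside _jax_update. Below is a self-contained sketch, not the library's code, of the equivalent optax chain and one update step, again with illustrative hyper-parameter values. The point to note is that once scale_by_zoom_linesearch is in the chain, update() must also receive the loss value, the gradient, and a loss callable, as in the patched _jax_wrapped_plan_update.

    import jax
    import jax.numpy as jnp
    import optax

    # Illustrative hyper-parameters (not library defaults)
    clip_grad = 1.0
    noise_kwargs = {'eta': 0.01, 'gamma': 0.55, 'seed': 42}
    line_search_kwargs = {'max_linesearch_steps': 15}

    # mirror of the pipeline assembled in __init__ after this commit
    pipeline = []
    if clip_grad is not None:
        pipeline.append(optax.clip(clip_grad))
    if noise_kwargs is not None:
        pipeline.append(optax.add_noise(**noise_kwargs))
    pipeline.append(optax.rmsprop(learning_rate=0.1))
    if line_search_kwargs is not None:
        pipeline.append(optax.scale_by_zoom_linesearch(**line_search_kwargs))
    optimizer = optax.chain(*pipeline)

    # one update step: the zoom line search needs value, grad and value_fn
    def loss_fn(params):
        return jnp.sum(params ** 2)

    params = jnp.arange(3.0)
    opt_state = optimizer.init(params)
    loss_val, grad = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(
        grad, opt_state, params=params,
        value=loss_val, grad=grad, value_fn=loss_fn)
    params = optax.apply_updates(params, updates)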