[RLlib] RLModule: Add TargetNetworkAPI and implement for APPO and SAC. (#46656)
sven1977 authored Jul 19, 2024
1 parent ea452c9 commit 2a8b122
Showing 18 changed files with 316 additions and 384 deletions.
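
For orientation, the methods this commit overrides with @override(TargetNetworkAPI) — make_target_networks(), get_target_network_pairs(), and forward_target() — suggest roughly the interface sketched below. This is a hedged reconstruction inferred from the diff, not the actual definition in ray.rllib.core.rl_module.apis.target_network_api; exact signatures, decorators, and docstrings may differ.

# Hedged sketch of the TargetNetworkAPI surface implied by the overrides in
# this diff; the real class may differ in detail.
import abc
from typing import Any, Dict, List, Tuple

NetworkType = Any  # stand-in for ray.rllib.utils.typing.NetworkType


class TargetNetworkAPI(abc.ABC):
    @abc.abstractmethod
    def make_target_networks(self) -> None:
        """Creates the target network(s) as copies of the main network(s)."""

    @abc.abstractmethod
    def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
        """Returns a list of (main_net, target_net) pairs to keep in sync."""

    @abc.abstractmethod
    def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Performs a forward pass through the target network(s) only."""
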
46 changes: 39 additions & 7 deletions rllib/algorithms/appo/appo_learner.py
@@ -1,9 +1,13 @@
import abc
from typing import Any, Dict
from typing import Any, Dict, Optional

from ray.rllib.algorithms.appo.appo import APPOConfig
from ray.rllib.algorithms.impala.impala_learner import IMPALALearner
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.learner.utils import update_target_network
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.utils.annotations import override
from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
from ray.rllib.utils.metrics import (
@@ -13,7 +17,7 @@
NUM_TARGET_UPDATES,
)
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.typing import ModuleID
from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn


class APPOLearner(IMPALALearner):
@@ -26,9 +30,13 @@ class APPOLearner(IMPALALearner):
def build(self):
super().build()

# Initially sync target networks (w/ tau=1.0 -> full overwrite).
# Make target networks.
self.module.foreach_module(
lambda mid, module: module.sync_target_networks(tau=1.0)
lambda mid, mod: (
mod.make_target_networks()
if isinstance(mod, TargetNetworkAPI)
else None
)
)

# The current kl coefficients per module as (framework specific) tensor
@@ -41,10 +49,26 @@ def build(self):
)
)

@override(Learner)
def add_module(
self,
*,
module_id: ModuleID,
module_spec: SingleAgentRLModuleSpec,
config_overrides: Optional[Dict] = None,
new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
) -> MultiAgentRLModuleSpec:
marl_spec = super().add_module(module_id=module_id)
# Create target networks for added Module, if applicable.
if isinstance(self.module[module_id].unwrapped(), TargetNetworkAPI):
self.module[module_id].unwrapped().make_target_networks()
return marl_spec

@override(IMPALALearner)
def remove_module(self, module_id: str):
super().remove_module(module_id)
def remove_module(self, module_id: str) -> MultiAgentRLModuleSpec:
marl_spec = super().remove_module(module_id)
self.curr_kl_coeffs_per_module.pop(module_id)
return marl_spec

@override(Learner)
def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
@@ -77,7 +101,15 @@ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
timestep - self.metrics.peek(last_update_ts_key, default=0)
>= config.target_update_frequency
):
module.sync_target_networks(tau=config.tau)
for (
main_net,
target_net,
) in module.unwrapped().get_target_network_pairs():
update_target_network(
main_net=main_net,
target_net=target_net,
tau=config.tau,
)
# Increase lifetime target network update counter by one.
self.metrics.log_value((module_id, NUM_TARGET_UPDATES), 1, reduce="sum")
# Update the (single-value -> window=1) last updated timestep metric.
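
The learner now delegates the soft update to update_target_network from ray.rllib.core.learner.utils. The TF logic removed further down in this diff computed updated_var = tau * current_var + (1.0 - tau) * old_var, so the helper presumably performs a Polyak-style update; a minimal torch sketch of that formula, under this assumption:

# Hedged sketch of a Polyak ("soft") target update with mixing factor tau.
# The function name and torch-only scope are assumptions; the actual helper is
# ray.rllib.core.learner.utils.update_target_network.
import torch


def soft_update(main_net: torch.nn.Module, target_net: torch.nn.Module, tau: float) -> None:
    with torch.no_grad():
        for main_param, target_param in zip(
            main_net.parameters(), target_net.parameters()
        ):
            # target <- tau * main + (1 - tau) * target (tau=1.0 -> full overwrite)
            target_param.mul_(1.0 - tau).add_(tau * main_param)
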
47 changes: 25 additions & 22 deletions rllib/algorithms/appo/appo_rl_module.py
@@ -1,29 +1,32 @@
"""
This file holds framework-agnostic components for APPO's RLModules.
"""

import abc
from typing import Any, Dict, List, Tuple

from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
from ray.rllib.core.rl_module.rl_module_with_target_networks_interface import (
RLModuleWithTargetNetworksInterface,
)
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.algorithms.appo.appo import OLD_ACTION_DIST_LOGITS_KEY
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.models.base import ACTOR
from ray.rllib.core.models.tf.encoder import ENCODER_OUT
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.utils.typing import NetworkType

from ray.rllib.utils.annotations import override

# TODO (simon): Write a light-weight version of this class for the `TFRLModule`

class APPORLModule(PPORLModule, TargetNetworkAPI, abc.ABC):
@override(TargetNetworkAPI)
def make_target_networks(self):
self._old_encoder = make_target_network(self.encoder)
self._old_pi = make_target_network(self.pi)

@ExperimentalAPI
class APPORLModule(PPORLModule, RLModuleWithTargetNetworksInterface, abc.ABC):
def setup(self):
super().setup()
@override(TargetNetworkAPI)
def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
return [
(self.encoder, self._old_encoder),
(self.pi, self._old_pi),
]

# If the module is not for inference only, create the target networks.
if not self.config.inference_only:
catalog = self.config.get_catalog()
# Old pi and old encoder are the "target networks" that are used for
# the stabilization of the updates of the current pi and encoder.
self.old_pi = catalog.build_pi_head(framework=self.framework)
self.old_encoder = catalog.build_actor_critic_encoder(
framework=self.framework
)
@override(TargetNetworkAPI)
def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
old_pi_inputs_encoded = self._old_encoder(batch)[ENCODER_OUT][ACTOR]
old_action_dist_logits = self._old_pi(old_pi_inputs_encoded)
return {OLD_ACTION_DIST_LOGITS_KEY: old_action_dist_logits}
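
The pattern above — build frozen copies of the encoder and pi head, expose (main, target) pairs, and run the target pass separately — can be illustrated with a toy, framework-only torch module. This is not an RLModule subclass; the class name, layer sizes, and the "obs" / output keys are hypothetical.

# Toy illustration of the make/get/forward-target pattern (all names hypothetical).
import copy
from typing import Dict, List, Tuple

import torch
from torch import nn


class TinyPolicyWithTarget(nn.Module):
    def __init__(self, obs_dim: int = 4, num_actions: int = 2):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, 16)
        self.pi = nn.Linear(16, num_actions)

    def make_target_networks(self) -> None:
        # Deep-copy and freeze, analogous to make_target_network() in the diff.
        self._old_encoder = copy.deepcopy(self.encoder).requires_grad_(False)
        self._old_pi = copy.deepcopy(self.pi).requires_grad_(False)

    def get_target_network_pairs(self) -> List[Tuple[nn.Module, nn.Module]]:
        return [(self.encoder, self._old_encoder), (self.pi, self._old_pi)]

    def forward_target(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Target-only forward pass; the frozen copies receive no gradient updates.
        return {"old_action_dist_logits": self._old_pi(self._old_encoder(batch["obs"]))}
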
50 changes: 1 addition & 49 deletions rllib/algorithms/appo/tf/appo_tf_rl_module.py
@@ -1,57 +1,9 @@
from typing import Dict, List

from ray.rllib.algorithms.appo.appo import OLD_ACTION_DIST_LOGITS_KEY
from ray.rllib.algorithms.appo.appo_rl_module import APPORLModule
from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import ACTOR
from ray.rllib.core.models.tf.encoder import ENCODER_OUT
from ray.rllib.core.rl_module.rl_module_with_target_networks_interface import (
RLModuleWithTargetNetworksInterface,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf

_, tf, _ = try_import_tf()


class APPOTfRLModule(PPOTfRLModule, APPORLModule):
@override(PPOTfRLModule)
def setup(self):
super().setup()

# If the module is not for inference only, set up the target networks.
if not self.config.inference_only:
self.old_pi.set_weights(self.pi.get_weights())
self.old_encoder.set_weights(self.encoder.get_weights())
self.old_pi.trainable = False
self.old_encoder.trainable = False

@override(RLModuleWithTargetNetworksInterface)
def sync_target_networks(self, tau: float) -> None:
for target_network, current_network in [
(self.old_pi, self.pi),
(self.old_encoder, self.encoder),
]:
for old_var, current_var in zip(
target_network.variables, current_network.variables
):
updated_var = tau * current_var + (1.0 - tau) * old_var
old_var.assign(updated_var)

@override(PPOTfRLModule)
def output_specs_train(self) -> List[str]:
return [
Columns.ACTION_DIST_INPUTS,
Columns.VF_PREDS,
OLD_ACTION_DIST_LOGITS_KEY,
]

@override(PPOTfRLModule)
def _forward_train(self, batch: Dict):
outs = super()._forward_train(batch)
batch = batch.copy()
old_pi_inputs_encoded = self.old_encoder(batch)[ENCODER_OUT][ACTOR]
old_action_dist_logits = tf.stop_gradient(self.old_pi(old_pi_inputs_encoded))
outs[OLD_ACTION_DIST_LOGITS_KEY] = old_action_dist_logits
return outs
pass
55 changes: 8 additions & 47 deletions rllib/algorithms/appo/torch/appo_torch_learner.py
@@ -14,15 +14,7 @@
)
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModule
from ray.rllib.core.rl_module.torch.torch_rl_module import (
TorchDDPRLModuleWithTargetNetworksInterface,
TorchRLModule,
)
from ray.rllib.core.rl_module.rl_module_with_target_networks_interface import (
RLModuleWithTargetNetworksInterface,
)
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
@@ -44,15 +36,18 @@ def compute_loss_for_module(
fwd_out: Dict[str, TensorType],
) -> TensorType:

module = self.module[module_id].unwrapped()
assert isinstance(module, TargetNetworkAPI)

values = fwd_out[Columns.VF_PREDS]
action_dist_cls_train = (
self.module[module_id].unwrapped().get_train_action_dist_cls()
)

action_dist_cls_train = module.get_train_action_dist_cls()
target_policy_dist = action_dist_cls_train.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
)

old_target_policy_dist = action_dist_cls_train.from_logits(
fwd_out[OLD_ACTION_DIST_LOGITS_KEY]
module.forward_target(batch)[OLD_ACTION_DIST_LOGITS_KEY]
)
old_target_policy_actions_logp = old_target_policy_dist.logp(
batch[Columns.ACTIONS]
@@ -184,40 +179,6 @@ def compute_loss_for_module(
# Return the total loss.
return total_loss

@override(TorchLearner)
def _make_modules_ddp_if_necessary(self) -> None:
"""Logic for (maybe) making all Modules within self._module DDP.
This implementation differs from the super's default one in using the special
TorchDDPRLModuleWithTargetNetworksInterface wrapper, instead of the default
TorchDDPRLModule one.
"""

# If the module is a MultiAgentRLModule and nn.Module we can simply assume
# all the submodules are registered. Otherwise, we need to loop through
# each submodule and move it to the correct device.
# TODO (Kourosh): This can result in missing modules if the user does not
# register them in the MultiAgentRLModule. We should find a better way to
# handle this.
if self._distributed:
# Single agent module: Convert to
# `TorchDDPRLModuleWithTargetNetworksInterface`.
if isinstance(self._module, RLModuleWithTargetNetworksInterface):
self._module = TorchDDPRLModuleWithTargetNetworksInterface(self._module)
# Multi agent module: Convert each submodule to
# `TorchDDPRLModuleWithTargetNetworksInterface`.
else:
assert isinstance(self._module, MultiAgentRLModule)
for key in self._module.keys():
sub_module = self._module[key]
if isinstance(sub_module, TorchRLModule):
# Wrap and override the module ID key in self._module.
self._module.add_module(
key,
TorchDDPRLModuleWithTargetNetworksInterface(sub_module),
override=True,
)

@override(APPOLearner)
def _update_module_kl_coeff(self, module_id: ModuleID, config: APPOConfig) -> None:
# Update the current KL value based on the recently measured value.
44 changes: 0 additions & 44 deletions rllib/algorithms/appo/torch/appo_torch_rl_module.py
@@ -1,53 +1,9 @@
from typing import Dict, List, Tuple

from ray.rllib.algorithms.appo.appo import (
OLD_ACTION_DIST_LOGITS_KEY,
)
from ray.rllib.algorithms.appo.appo_rl_module import APPORLModule
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import ACTOR
from ray.rllib.core.models.tf.encoder import ENCODER_OUT
from ray.rllib.core.rl_module.rl_module_with_target_networks_interface import (
RLModuleWithTargetNetworksInterface,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import NetworkType


class APPOTorchRLModule(PPOTorchRLModule, APPORLModule):
@override(PPOTorchRLModule)
def setup(self):
super().setup()

# If the module is not for inference only, update the target networks.
if not self.config.inference_only:
self.old_pi.load_state_dict(self.pi.state_dict())
self.old_encoder.load_state_dict(self.encoder.state_dict())
# We do not train the targets.
self.old_pi.requires_grad_(False)
self.old_encoder.requires_grad_(False)

@override(RLModuleWithTargetNetworksInterface)
def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
return [(self.old_pi, self.pi), (self.old_encoder, self.encoder)]

@override(PPOTorchRLModule)
def output_specs_train(self) -> List[str]:
return [
Columns.ACTION_DIST_INPUTS,
OLD_ACTION_DIST_LOGITS_KEY,
Columns.VF_PREDS,
]

@override(PPOTorchRLModule)
def _forward_train(self, batch: Dict):
outs = super()._forward_train(batch)
old_pi_inputs_encoded = self.old_encoder(batch)[ENCODER_OUT][ACTOR]
old_action_dist_logits = self.old_pi(old_pi_inputs_encoded)
outs[OLD_ACTION_DIST_LOGITS_KEY] = old_action_dist_logits
return outs

@override(PPOTorchRLModule)
def _set_inference_only_state_dict_keys(self) -> None:
# Get the model_parameters from the `PPOTorchRLModule`.
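
The explicit load_state_dict() / requires_grad_(False) handling removed from this torch module appears to move into the new make_target_network helper imported in appo_rl_module.py earlier in the diff. A minimal stand-in, assuming the helper simply deep-copies a torch network and freezes it:

# Hypothetical stand-in for ray.rllib.core.learner.utils.make_target_network;
# the real helper may differ (e.g., in how it handles non-torch frameworks).
import copy

import torch


def make_target_network_sketch(net: torch.nn.Module) -> torch.nn.Module:
    target = copy.deepcopy(net)   # same architecture and (initial) weights
    target.requires_grad_(False)  # target nets are never trained directly
    return target
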
22 changes: 20 additions & 2 deletions rllib/algorithms/dqn/dqn_rainbow_learner.py
@@ -2,6 +2,7 @@
from typing import Any, Dict

from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.learner.utils import update_target_network
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
@@ -42,8 +43,13 @@ def build(self) -> None:
super().build()

# Initially sync target networks (w/ tau=1.0 -> full overwrite).
# TODO (sven): Use TargetNetworkAPI as soon as DQN implements it.
self.module.foreach_module(
lambda mid, module: module.sync_target_networks(tau=1.0)
lambda mid, module: (
module.sync_target_networks(tau=1.0)
if hasattr(module, "sync_target_networks")
else None
)
)

# Prepend a NEXT_OBS from episodes to train batch connector piece (right
@@ -72,7 +78,19 @@ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
timestep - self.metrics.peek(last_update_ts_key, default=0)
>= config.target_network_update_freq
):
module.sync_target_networks(tau=config.tau)
# TODO (sven): Use TargetNetworkAPI as soon as DQN implements it.
if hasattr(module, "sync_target_networks"):
module.sync_target_networks(tau=config.tau)
else:
for (
main_net,
target_net,
) in module.unwrapped().get_target_network_pairs():
update_target_network(
main_net=main_net,
target_net=target_net,
tau=config.tau,
)
# Increase lifetime target network update counter by one.
self.metrics.log_value((module_id, NUM_TARGET_UPDATES), 1, reduce="sum")
# Update the (single-value -> window=1) last updated timestep metric.
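
The DQN/Rainbow learner keeps a legacy sync_target_networks() path until DQN itself adopts TargetNetworkAPI (see the TODOs above). The dispatch boils down to the sketch below; update_target_network is the helper this diff imports at the top of the file, while maybe_update_targets and its module argument are illustrative.

# Sketch of the dual-path target update; the function name is invented, but the
# imported helper and the module methods are the ones referenced in the diff.
from ray.rllib.core.learner.utils import update_target_network


def maybe_update_targets(module, tau: float) -> None:
    if hasattr(module, "sync_target_networks"):
        # Legacy path (e.g., DQN/Rainbow until it implements TargetNetworkAPI).
        module.sync_target_networks(tau=tau)
    else:
        # New-API path: the module exposes (main, target) network pairs.
        for main_net, target_net in module.unwrapped().get_target_network_pairs():
            update_target_network(main_net=main_net, target_net=target_net, tau=tau)
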
6 changes: 4 additions & 2 deletions rllib/algorithms/dqn/torch/dqn_rainbow_torch_rl_module.py
@@ -35,13 +35,15 @@ def setup(self):
# space is a flat space we can use a noisy encoder.
self.uses_noisy_encoder = isinstance(self.encoder, TorchNoisyMLPEncoder)

# If we have target networks, we need to make them not trainable.
# If not an inference-only module (e.g., for evaluation), set up the
# parameter names to be removed or renamed when syncing from the state dict.
if not self.config.inference_only:
# If we have target networks, we need to make them not trainable.
self.target_encoder.requires_grad_(False)
self.af_target.requires_grad_(False)
if self.uses_dueling:
self.vf_target.requires_grad_(False)

# Set the expected and unexpected keys for the inference-only module.
self._set_inference_only_state_dict_keys()
