Skip to content

Commit

Permalink
[RLlib; Offline RL] Make data pipeline better configurable and tuneab…
Browse files Browse the repository at this point in the history
…le for users. (#46777)
  • Loading branch information
simonsays1980 authored Jul 25, 2024
1 parent ff789b9 commit ecfb064
Show file tree
Hide file tree
Showing 7 changed files with 365 additions and 248 deletions.
39 changes: 35 additions & 4 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,9 @@ def __init__(self, algo_class: Optional[type] = None):
self.input_read_method = "read_parquet"
self.input_read_method_kwargs = {}
self.input_read_schema = {}
self.map_batches_kwargs = {}
self.iter_batches_kwargs = {}
self.prelearner_class = None
self.prelearner_module_synch_period = 10
self.dataset_num_iters_per_learner = None
self.input_config = {}
Expand Down Expand Up @@ -2373,6 +2376,9 @@ def offline_data(
input_read_method=NotProvided,
input_read_method_kwargs=NotProvided,
input_read_schema=NotProvided,
map_batches_kwargs=NotProvided,
iter_batches_kwargs=NotProvided,
prelearner_class=NotProvided,
prelearner_module_synch_period=NotProvided,
dataset_num_iters_per_learner=NotProvided,
input_config=NotProvided,
Expand Down Expand Up @@ -2403,10 +2409,12 @@ def offline_data(
offline data from `input_`. The default is `read_json` for JSON files.
See https://docs.ray.io/en/latest/data/api/input_output.html for more
info about available read methods in `ray.data`.
input_read_method_kwargs: kwargs for the `input_read_method`. These will be
passed into the read method without checking. If no arguments are passed
in the default argument `{'override_num_blocks': max(num_learners * 2,
2)}` is used.
input_read_method_kwargs: `kwargs` for the `input_read_method`. These will
be passed into the read method without checking. If no arguments are
passed in the default argument `{'override_num_blocks':
max(num_learners * 2, 2)}` is used. Use these `kwargs` together with
the `map_batches_kwargs` and `iter_batches_kwargs` to tune the
performance of the data pipeline.
input_read_schema: Table schema for converting offline data to episodes.
This schema maps the offline data columns to `ray.rllib.core.columns.
Columns`: {Columns.OBS: 'o_t', Columns.ACTIONS: 'a_t', ...}. Columns in
Expand All @@ -2415,6 +2423,27 @@ def offline_data(
schema used is `ray.rllib.offline.offline_data.SCHEMA`. If your data set
contains already the names in this schema, no `input_read_schema` is
needed.
map_batches_kwargs: `kwargs` for the `map_batches` method. These will be
passed into the `ray.data.Dataset.map_batches` method when sampling
without checking. If no arguments are passed in, the default arguments `{
'concurrency': max(2, num_learners), 'zero_copy_batch': True}` are
used. Use these `kwargs` together with the `input_read_method_kwargs`
and `iter_batches_kwargs` to tune the performance of the data pipeline.
iter_batches_kwargs: `kwargs` for the `iter_batches` method. These will be
passed into the `ray.data.Dataset.iter_batches` method when sampling
without checking. If no arguments are passed in, the default arguments `{
'prefetch_batches': 2, 'local_shuffle_buffer_size':
train_batch_size_per_learner * 4}` are used. Use these `kwargs`
together with the `input_read_method_kwargs` and `map_batches_kwargs`
to tune the performance of the data pipeline.
prelearner_class: An optional `OfflinePreLearner` class that is used to
transform data batches in `ray.data.map_batches` used in the
`OfflineData` class to transform data from columns to batches that can
be used in the `Learner`'s `update` methods. Override the
`OfflinePreLearner` class and pass your derived class in here, if you
need to make some further transformations specific for your data or
loss. The default is `None` which uses the base `OfflinePreLearner`
defined in `ray.rllib.offline.offline_prelearner`.
prelearner_module_synch_period: The period (number of batches converted)
after which the `RLModule` held by the `PreLearner` should sync weights.
The `PreLearner` is used to preprocess batches for the learners. The
Expand Down Expand Up @@ -2470,6 +2499,8 @@ def offline_data(
self.input_read_method_kwargs = input_read_method_kwargs
if input_read_schema is not NotProvided:
self.input_read_schema = input_read_schema
if prelearner_class is not NotProvided:
self.prelearner_class = prelearner_class
if prelearner_module_synch_period is not NotProvided:
self.prelearner_module_synch_period = prelearner_module_synch_period
if dataset_num_iters_per_learner is not NotProvided:
Expand Down
5 changes: 4 additions & 1 deletion rllib/algorithms/bc/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ class (multi-/single-learner setup) and evaluation on
batch,
minibatch_size=self.config.train_batch_size_per_learner,
num_iters=self.config.dataset_num_iters_per_learner,
**self.offline_data.iter_batches_kwargs
if self.config.num_learners > 1
else {},
)

# Log training results.
Expand Down Expand Up @@ -217,7 +220,7 @@ class (multi-/single-learner setup) and evaluation on
# Update weights - after learning on the local worker -
# on all remote workers.
with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
self.workers.sync_weights(
self.env_runner_group.sync_weights(
# Sync weights from learner_group to all EnvRunners.
from_worker_or_learner_group=self.learner_group,
policies=modules_to_update,
Expand Down
6 changes: 3 additions & 3 deletions rllib/core/learner/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,8 @@ def update_from_iterator(
*,
timesteps: Optional[Dict[str, Any]] = None,
minibatch_size: Optional[int] = None,
num_iters: int = 1,
num_iters: int = None,
**kwargs,
):
self._check_is_built()
minibatch_size = minibatch_size or 32
Expand All @@ -1162,8 +1163,7 @@ def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
for batch in iterator.iter_batches(
batch_size=minibatch_size,
_finalize_fn=_finalize_fn,
prefetch_batches=2,
local_shuffle_buffer_size=minibatch_size * 10,
**kwargs,
):
# Update the iteration counter.
i += 1
Expand Down
Loading

0 comments on commit ecfb064

Please sign in to comment.