Skip to content
This repository has been archived by the owner on Feb 13, 2025. It is now read-only.

Commit

Permalink
Merge pull request #68 from facebookresearch/samvelyan/seeding
Browse files Browse the repository at this point in the history
Fixing the seeding issue
  • Loading branch information
samvelyan authored Dec 9, 2022
2 parents 2054e7f + 36b174f commit 574004c
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions minihack/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,14 @@ def __init__(
environment as a dictionary. Defaults to
``minihack.base.MH_DEFAULT_OBS_KEYS``.
seeds (list or None):
A list of random seeds for sampling episodes. If none, the
entire level distribution is used. Defaults to None.
A list of integers used as level seeds for sampling
episodes. The reset()` function samples a seed from this list
uniformly at random and uses it for setting the level.
When the ``sample_seed`` argument of the reset function is
set to False, a random level will not be sampled from this list
during environment resetting.
If None, the entire level distribution is used.
Defaults to None.
penalty_mode (str):
The name of the mode for calculating the time step penalty.
Can be ``constant``, ``exp``, ``square``, ``linear``, or
Expand Down Expand Up @@ -319,10 +325,10 @@ def _get_obs_space_dict(self, space_dict):

return obs_space_dict

def reset(self, *args, **kwargs):
def reset(self, *args, sample_seed=True, **kwargs):
if self.reward_manager is not None:
self.reward_manager.reset()
if self._level_seeds is not None:
if sample_seed and self._level_seeds is not None:
seed = random.choice(self._level_seeds)
self.seed(seed, seed, reseed=False)
return super().reset(*args, **kwargs)
Expand Down

0 comments on commit 574004c

Please sign in to comment.