Added Action masking for Space.sample() #2906

Merged
merged 26 commits into from
Jun 26, 2022
Changes from 1 commit
Commits
c390a6b
Allows a new RNG to be generated with seed=-1 and updated env_checker…
pseudo-rnd-thoughts Jun 8, 2022
654476e
Revert "fixed `gym.vector.make` where the checker was being applied i…
pseudo-rnd-thoughts Jun 8, 2022
ea110ad
Merge branch 'openai:master' into master
pseudo-rnd-thoughts Jun 11, 2022
2e5dc9c
Remove bad pushed commits
pseudo-rnd-thoughts Jun 13, 2022
717bc1f
Merge branch 'openai:master' into master
pseudo-rnd-thoughts Jun 13, 2022
7e73d04
Merge branch 'openai:master' into master
pseudo-rnd-thoughts Jun 15, 2022
5743cd9
Merge branch 'openai:master' into master
pseudo-rnd-thoughts Jun 16, 2022
400b1e9
Fixed spelling in core.py
pseudo-rnd-thoughts Jun 17, 2022
4281c76
Pins pytest to the last py 3.6 version
pseudo-rnd-thoughts Jun 17, 2022
5dee690
Add support for action masking in Space.sample(mask=...)
pseudo-rnd-thoughts Jun 17, 2022
bc6ab4a
Fix action mask
pseudo-rnd-thoughts Jun 17, 2022
1700e9d
Fix action_mask
pseudo-rnd-thoughts Jun 17, 2022
7f46df2
Fix action_mask
pseudo-rnd-thoughts Jun 17, 2022
cd91007
Added docstrings, fixed bugs and added taxi examples
pseudo-rnd-thoughts Jun 19, 2022
be4063e
Fixed bugs
pseudo-rnd-thoughts Jun 19, 2022
2f14eb7
Add tests for sample
pseudo-rnd-thoughts Jun 20, 2022
f52d5d5
Add docstrings and test space sample mask Discrete and MultiBinary
pseudo-rnd-thoughts Jun 20, 2022
5e699e1
Add MultiDiscrete sampling and tests
pseudo-rnd-thoughts Jun 21, 2022
634da12
Remove sample mask from graph
pseudo-rnd-thoughts Jun 21, 2022
f85055c
Update gym/spaces/multi_discrete.py
pseudo-rnd-thoughts Jun 23, 2022
4a4b166
Updates based on Marcus28 and jjshoots for Graph.py
pseudo-rnd-thoughts Jun 23, 2022
eb63c62
Updates based on Marcus28 and jjshoots for Graph.py
pseudo-rnd-thoughts Jun 23, 2022
8918914
jjshoot review
pseudo-rnd-thoughts Jun 24, 2022
a53f0e7
jjshoot review
pseudo-rnd-thoughts Jun 25, 2022
8e71e46
Update assert check
pseudo-rnd-thoughts Jun 25, 2022
875ab44
Update type hints
pseudo-rnd-thoughts Jun 25, 2022
Fix action_mask
pseudo-rnd-thoughts committed Jun 17, 2022
commit 1700e9d63dd40b9537bedb0cb85132a98c61e61a
28 changes: 14 additions & 14 deletions gym/envs/toy_text/taxi.py
@@ -214,20 +214,24 @@ def decode(self, i):
         assert 0 <= i < 5
         return reversed(out)
 
-    def action_mask(self, row, col, pass_loc, dest_idx):
+    def action_mask(self, state: int):
         """Computes an action mask for the action space using the state information."""
-        mask = np.zeros(6, dtype=bool)
-        if row < 5:
+        mask = np.zeros(6, dtype=np.int8)
+        taxi_row, taxi_col, pass_loc, dest_idx = self.decode(state)
+        if taxi_row < 4:
             mask[0] = 1
-        if row > 0:
+        if taxi_row > 0:
             mask[1] = 1
-        if col < 5 and self.desc[1 + row, 2 * col + 2] == b":":
+        if taxi_col < 4 and self.desc[taxi_row + 1, 2 * taxi_col + 2] == b":":
             mask[2] = 1
-        if col > 0 and self.desc[1 + row, 2 * col] == b":":
+        if taxi_col > 0 and self.desc[taxi_row + 1, 2 * taxi_col] == b":":
             mask[3] = 1
-        if pass_loc < 4 and (row, col) == self.locs[pass_loc]:
+        if pass_loc < 4 and (taxi_row, taxi_col) == self.locs[pass_loc]:
             mask[4] = 1
-        if pass_loc == 4 and (row, col) == self.locs[dest_idx]:
+        if pass_loc == 4 and (
+            (taxi_row, taxi_col) == self.locs[dest_idx]
+            or (taxi_row, taxi_col) in self.locs
+        ):
             mask[5] = 1
         return mask
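
The revised `action_mask` can be tried outside Gym with a small NumPy-free port. This is only a sketch: `MAP`, `LOCS`, `encode`, and the free-function form of `decode` are assumptions taken from the usual Taxi layout and state encoding, none of which appear in this diff. It also shows why the bound changed from `row < 5` to `taxi_row < 4`: rows run 0..4, so the old check always passed and "south" was never masked.

```python
# Standalone port of the revised action_mask logic (illustrative only).
# MAP and LOCS are assumed to match the Taxi map; they are not part of this diff.
MAP = [
    "+---------+",
    "|R: | : :G|",
    "| : | : : |",
    "| : : : : |",
    "| | : | : |",
    "|Y| : |B: |",
    "+---------+",
]
LOCS = [(0, 0), (0, 4), (4, 0), (4, 3)]  # R, G, Y, B


def encode(taxi_row, taxi_col, pass_loc, dest_idx):
    """Pack the four state components into one integer (assumed encoding)."""
    return ((taxi_row * 5 + taxi_col) * 5 + pass_loc) * 4 + dest_idx


def decode(state):
    """Inverse of encode: returns (taxi_row, taxi_col, pass_loc, dest_idx)."""
    dest_idx = state % 4
    state //= 4
    pass_loc = state % 5
    state //= 5
    taxi_col = state % 5
    taxi_row = state // 5
    assert 0 <= taxi_row < 5
    return taxi_row, taxi_col, pass_loc, dest_idx


def action_mask(state):
    """Return a list of six 0/1 flags; 1 marks an action that changes the state."""
    mask = [0] * 6
    taxi_row, taxi_col, pass_loc, dest_idx = decode(state)
    here = (taxi_row, taxi_col)
    if taxi_row < 4:  # 0: south, blocked only on the bottom row
        mask[0] = 1
    if taxi_row > 0:  # 1: north, blocked only on the top row
        mask[1] = 1
    if taxi_col < 4 and MAP[taxi_row + 1][2 * taxi_col + 2] == ":":  # 2: east
        mask[2] = 1
    if taxi_col > 0 and MAP[taxi_row + 1][2 * taxi_col] == ":":  # 3: west
        mask[3] = 1
    if pass_loc < 4 and here == LOCS[pass_loc]:  # 4: pickup
        mask[4] = 1
    if pass_loc == 4 and (here == LOCS[dest_idx] or here in LOCS):  # 5: dropoff
        mask[5] = 1
    return mask


top_left = action_mask(encode(0, 0, 0, 1))     # taxi on R, passenger waiting at R
bottom_b = action_mask(encode(4, 3, 4, 3))     # taxi on B, passenger in the taxi
centre = action_mask(encode(2, 2, 0, 1))       # taxi in the open middle row
```

In the dropoff condition the `here == LOCS[dest_idx]` test is already implied by `here in LOCS`; the port keeps both terms only to mirror the commit.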

@@ -239,9 +243,7 @@ def step(self, a):
         self.lastaction = a
         self.renderer.render_step()
 
-        taxi_row, taxi_col, pass_loc, dest_idx = self.decode(s)
-        mask = self.action_mask(taxi_row, taxi_col, pass_loc, dest_idx)
-        return int(s), r, d, {"prob": p, "action_mask": mask}
+        return int(s), r, d, {"prob": p, "action_mask": self.action_mask(s)}
 
     def reset(
         self,
@@ -259,9 +261,7 @@ def reset(
         if not return_info:
             return int(self.s)
         else:
-            taxi_row, taxi_col, pass_loc, dest_idx = self.decode(self.s)
-            mask = self.action_mask(taxi_row, taxi_col, pass_loc, dest_idx)
-            return int(self.s), {"prob": 1, "action_mask": mask}
+            return int(self.s), {"prob": 1, "action_mask": self.action_mask(self.s)}
 
     def render(self, mode="human"):
         if self.render_mode is not None:
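
Both `step` and `reset` now return the mask under `info["action_mask"]`, which is the value this PR intends callers to pass to `Space.sample(mask=...)`. The sketch below illustrates what mask-restricted sampling means, using only the standard library. `sample_masked` is a hypothetical helper, not Gym's implementation, and returning action 0 when no action is valid is an assumption of this sketch.

```python
import random


def sample_masked(mask, rng):
    """Pick uniformly among the indices whose mask entry is non-zero.

    Hypothetical helper for illustration; not part of Gym.
    """
    valid = [action for action, allowed in enumerate(mask) if allowed]
    if not valid:
        # Assumption: fall back to the first action when nothing is allowed.
        return 0
    return rng.choice(valid)


rng = random.Random(0)
mask = [0, 1, 1, 0, 0, 1]  # e.g. north, east and dropoff are allowed
samples = {sample_masked(mask, rng) for _ in range(200)}
fallback = sample_masked([0] * 6, rng)
```

With a mask like this, a random-exploration loop never selects a move into a wall or a pickup or dropoff that cannot succeed.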