Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Action masking for Space.sample() #2906

Merged
merged 26 commits into from
Jun 26, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
c390a6b
Allows a new RNG to be generated with seed=-1 and updated env_checker…
pseudo-rnd-thoughts Jun 8, 2022
654476e
Revert "fixed `gym.vector.make` where the checker was being applied i…
pseudo-rnd-thoughts Jun 8, 2022
ea110ad
Merge branch 'openai:master' into master
pseudo-rnd-thoughts Jun 11, 2022
2e5dc9c
Remove bad pushed commits
pseudo-rnd-thoughts Jun 13, 2022
717bc1f
Merge branch 'openai:master' into master
pseudo-rnd-thoughts Jun 13, 2022
7e73d04
Merge branch 'openai:master' into master
pseudo-rnd-thoughts Jun 15, 2022
5743cd9
Merge branch 'openai:master' into master
pseudo-rnd-thoughts Jun 16, 2022
400b1e9
Fixed spelling in core.py
pseudo-rnd-thoughts Jun 17, 2022
4281c76
Pins pytest to the last py 3.6 version
pseudo-rnd-thoughts Jun 17, 2022
5dee690
Add support for action masking in Space.sample(mask=...)
pseudo-rnd-thoughts Jun 17, 2022
bc6ab4a
Fix action mask
pseudo-rnd-thoughts Jun 17, 2022
1700e9d
Fix action_mask
pseudo-rnd-thoughts Jun 17, 2022
7f46df2
Fix action_mask
pseudo-rnd-thoughts Jun 17, 2022
cd91007
Added docstrings, fixed bugs and added taxi examples
pseudo-rnd-thoughts Jun 19, 2022
be4063e
Fixed bugs
pseudo-rnd-thoughts Jun 19, 2022
2f14eb7
Add tests for sample
pseudo-rnd-thoughts Jun 20, 2022
f52d5d5
Add docstrings and test space sample mask Discrete and MultiBinary
pseudo-rnd-thoughts Jun 20, 2022
5e699e1
Add MultiDiscrete sampling and tests
pseudo-rnd-thoughts Jun 21, 2022
634da12
Remove sample mask from graph
pseudo-rnd-thoughts Jun 21, 2022
f85055c
Update gym/spaces/multi_discrete.py
pseudo-rnd-thoughts Jun 23, 2022
4a4b166
Updates based on Marcus28 and jjshoots for Graph.py
pseudo-rnd-thoughts Jun 23, 2022
eb63c62
Updates based on Marcus28 and jjshoots for Graph.py
pseudo-rnd-thoughts Jun 23, 2022
8918914
jjshoot review
pseudo-rnd-thoughts Jun 24, 2022
a53f0e7
jjshoot review
pseudo-rnd-thoughts Jun 25, 2022
8e71e46
Update assert check
pseudo-rnd-thoughts Jun 25, 2022
875ab44
Update type hints
pseudo-rnd-thoughts Jun 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Remove sample mask from graph
  • Loading branch information
pseudo-rnd-thoughts committed Jun 21, 2022
commit 634da1218c9e34131561264379948483d3a33eb2
25 changes: 11 additions & 14 deletions gym/spaces/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Implementation of a space that represents graph information where nodes and edges can be represented with euclidean space."""
from collections import namedtuple
from typing import NamedTuple, Optional, Sequence, Tuple, Union
from typing import NamedTuple, Optional, Sequence, Union

import numpy as np

Expand Down Expand Up @@ -92,39 +92,36 @@ def _generate_sample_space(
f"Only Box and Discrete can be accepted as a base_space, got {type(base_space)}, you should not have gotten this error."
)

def sample(
self, mask: Optional[Tuple[Optional[np.ndarray], Optional[np.ndarray]]] = None
) -> NamedTuple:
def sample(self, mask: None = None) -> NamedTuple:
"""Generates a single sample graph with num_nodes between 1 (inclusive) and 10 (exclusive) sampled from the Graph.

Args:
mask: An optional tuple for the node space mask and the edge space mask (only valid for Discrete spaces).
The expected shape for the node mask is ``node_space.n`` and edge mask is ``edge_space.n``.
mask: As the number of nodes is determined during sampling, it is not possible to know the mask beforehand.

Returns:
A NamedTuple representing a graph with attributes .nodes, .edges, and .edge_links.
"""
node_mask, edge_mask = mask if mask is not None else (None, None)
if mask is not None:
raise NotImplementedError(
"Graph.sample(mask) is not implemented as the number of nodes is determined within the function."
)

num_nodes = self.np_random.integers(low=1, high=10)

# we only have edges when we have at least 2 nodes
num_edges = 0
if num_nodes > 1:
# maximal number of edges is (n*n) allowing self connections and two way is allowed
# maximal number of edges is (n*n) allowing self connections and two-way is allowed
num_edges = self.np_random.integers(num_nodes * num_nodes)

node_sample_space = self._generate_sample_space(self.node_space, num_nodes)
edge_sample_space = self._generate_sample_space(self.edge_space, num_edges)

sampled_nodes = (
node_sample_space.sample(node_mask)
if node_sample_space is not None
else None
node_sample_space.sample() if node_sample_space is not None else None
)
sampled_edges = (
edge_sample_space.sample(edge_mask)
if edge_sample_space is not None
else None
edge_sample_space.sample() if edge_sample_space is not None else None
)

sampled_edge_links = None
Expand Down
12 changes: 11 additions & 1 deletion tests/spaces/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@ def _update_observed_frequency(obs_sample, obs_freq):

def _exp_freq_fn(_dim: int, _mask: np.ndarray):
if np.any(_mask == 1):
print(f"{_dim=}, {_mask=}")
assert _dim == len(_mask)
return np.ones(_dim) * (n_trials / np.sum(_mask)) * _mask
else:
Expand Down Expand Up @@ -420,6 +419,17 @@ def _chi_squared_test(dim, _mask, exp_freq, obs_freq):
"b": np.array([0, 1, 1], dtype=np.int8),
},
),
(Graph(node_space=Discrete(5), edge_space=Discrete(3)), None),
(
Graph(node_space=Discrete(3), edge_space=Box(low=0, high=1, shape=(5,))),
None,
),
(
Graph(
node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3)
),
None,
),
],
)
def test_composite_space_sample_mask(space, mask):
Expand Down