Source code for palaestrai.types.tuple
from __future__ import annotations
import re
import typing
from typing import Iterable
import gymnasium
import numpy as np
from .space import Space
[docs]
class Tuple(gymnasium.spaces.Tuple, Space):
    """A tuple (i.e., product) of simpler spaces.

    Combines :class:`gymnasium.spaces.Tuple` with the :class:`Space`
    interface used here (vector conversion and parsing from a string).

    Example usage::

        self.observation_space = Tuple((Discrete(2), Discrete(3)))
    """

    # Matches a complete ``Tuple(...)`` string; group 1 is the inner text.
    _TUPLE_RE = re.compile(r"\A\s*?Tuple\((.+)\)\s*\Z")
    # Matches a single (stripped) element that is shaped like a space
    # constructor, e.g. ``Discrete(2)``. Used by ``from_string`` to tell a
    # malformed space (error is propagated) from plain garbage (ignored).
    _PIECE_RE = re.compile(r"\A[A-Za-z]+\(.*\)\Z", re.DOTALL)

    def __init__(
        self,
        spaces,
        seed: int | typing.Sequence[int] | np.random.Generator | None = None,
    ):
        """Create the tuple space.

        :param spaces: Iterable of sub-spaces; forwarded unchanged to
            :class:`gymnasium.spaces.Tuple` (not passed as varargs).
        :param seed: Optional seed, forwarded to
            :class:`gymnasium.spaces.Tuple`.
        """
        gymnasium.spaces.Tuple.__init__(self, spaces, seed)
        Space.__init__(self)

    def to_vector(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Convert *data* to vector form using the contained spaces.

        Element ``data[i]`` is converted with ``self.spaces[i].to_vector``
        and the per-element results are combined with :func:`numpy.array`.
        """
        # NOTE(review): np.array() *stacks* the sub-results rather than
        # concatenating them. The result is a flat vector only when every
        # sub-space returns a scalar. Sub-vectors of differing lengths
        # raise ValueError on NumPy >= 1.24. Confirm the intended shape.
        return np.array(
            [s.to_vector(data[idx]) for idx, s in enumerate(self.spaces)]  # type: ignore[attr-defined]
        )

    def reshape_to_space(self, value: Iterable, **kwargs) -> np.ndarray:
        """Reshape *value* using the contained spaces.

        *value* is first converted with :func:`numpy.array`. Its element
        ``i`` along the first axis is passed to
        ``self.spaces[i].reshape_to_space``, and the results are combined
        with :func:`numpy.array`.
        """
        as_array = np.array(value)
        return np.array(
            [
                s.reshape_to_space(as_array[idx])  # type: ignore[attr-defined]
                for idx, s in enumerate(self.spaces)
            ]
        )

    @staticmethod
    def _split_top_level(text: str) -> list[str]:
        """Split *text* on commas not nested in brackets or quotes.

        Surrounding whitespace is stripped and empty pieces are dropped.
        Unbalanced closing brackets never drive the depth below zero.
        """
        pieces = []
        depth = 0
        quote = None
        start = 0
        for idx, char in enumerate(text):
            if quote is not None:
                # Inside a quoted literal (e.g. a Dict key): only look for
                # the closing quote.
                if char == quote:
                    quote = None
            elif char in "'\"":
                quote = char
            elif char in "([{":
                depth += 1
            elif char in ")]}":
                depth = max(depth - 1, 0)
            elif char == "," and depth == 0:
                pieces.append(text[start:idx])
                start = idx + 1
        pieces.append(text[start:])
        return [piece.strip() for piece in pieces if piece.strip()]

    @classmethod
    def from_string(cls, s):
        """Create a :class:`Tuple` from its string representation.

        Accepts strings of the form ``Tuple(<space>, <space>, ...)``, e.g.
        ``Tuple(Discrete(2), Discrete(3))``. The inner text is split only
        on commas that are not nested inside brackets or quotes, so an
        element may itself contain commas (nested constructors, lists).
        Each element is parsed with :meth:`Space.from_string`.

        An element that fails to parse is ignored as garbage, unless it is
        shaped like a space constructor (``Name(...)``). In that case the
        parse error is propagated.

        :param s: The string to parse.
        :return: A new instance holding the parsed spaces in order.
        :raises RuntimeError: If *s* is not of the form ``Tuple(...)``.
        """
        complete_match = cls._TUPLE_RE.match(s)
        if not complete_match:
            raise RuntimeError(
                "String '%s' does not match '%s'" % (s, cls._TUPLE_RE)
            )
        spaces = []
        for piece in cls._split_top_level(complete_match[1]):
            try:
                spaces.append(Space.from_string(piece))
            except Exception:
                # Best effort: silently skip pieces that are not even shaped
                # like a space; a malformed ``Name(...)`` is a real error.
                if cls._PIECE_RE.match(piece):
                    raise
        # The constructor takes the sub-spaces as ONE iterable. Unpacking
        # them (``cls(*spaces)``) would bind the second space to ``seed``.
        return cls(spaces)