class Discrete(gymnasium.spaces.Discrete, Space):
    r"""A discrete space in :math:`\{ start, start + 1, \dots, start + n-1 \}`.

    Example::

        >>> Discrete(2)
        Discrete(2)
    """

    # Accepts "Discrete(n)" and, optionally, a start offset written as
    # "Discrete(n, start=k)" (k may be negative). Group 1 is n, group 2 is k.
    _RE = re.compile(r"\A\s*?Discrete\((\d+)(?:,\s*start=(-?\d+))?\)\s*\Z")

    def __init__(
        self,
        n: int | np.integer[Any],
        seed: int | np.random.Generator | None = None,
        start: int | np.integer[Any] = 0,
    ):
        # Both bases are initialised explicitly: the gymnasium space first,
        # then the project-level ``Space`` base.
        gymnasium.spaces.Discrete.__init__(self, n, seed, start)
        Space.__init__(self)

    def to_vector(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Flatten the discrete data to a one-hot ndarray of size ``self.n``.

        :param data: A single discrete value, given as an array of shape
            ``()`` or ``(1,)`` (plain scalars are converted with
            ``np.asarray``).
        :returns: A float array of length ``self.n`` that is 1 at position
            ``value - self.start`` and 0 everywhere else.
        :raises AssertionError: If ``data`` does not have shape ``()`` or
            ``(1,)``.
        :raises IndexError: If the value lies outside
            ``[self.start, self.start + self.n)``.
        """
        data = np.asarray(data)
        assert data.shape in ((1,), ()), (
            f"Expected shape (1,) or (); Got {data.shape} instead"
        )
        # Space values are offset by ``start``; shift to a 0-based slot so a
        # space with a non-zero start encodes correctly and negative values
        # do not silently wrap around to the end of the vector.
        index = data.item() - self.start
        if not 0 <= index < self.n:
            raise IndexError(
                f"Value {data.item()} is outside the space "
                f"[{self.start}, {self.start + self.n})"
            )
        transformed = np.zeros(self.n)
        transformed[index] = 1
        return transformed

    def reshape_to_space(self, value: Any, **kwargs) -> np.ndarray:
        """Reshape the flat representation of data into a single number.

        A scalar / 0-d ``value`` is taken to already be a value of the space
        and is returned unchanged as a 0-d array. Otherwise ``value`` is read
        as ``self.n`` scores (for example the one-hot output of
        :meth:`to_vector`), and the space value at the position of the
        largest score (``argmax + self.start``) is returned.

        :kwargs:
            dtype: The dtype used to read the elements of ``value``.
                default: float
        :raises AssertionError: If a non-scalar ``value`` does not contain
            exactly ``self.n`` elements.
        """
        if np.isscalar(value) or np.ndim(value) == 0:
            return np.array(value)
        # Read the whole iterable (no ``count``) so that a wrong length is
        # reported by the check below instead of being silently truncated.
        as_array = np.fromiter(value, kwargs.get("dtype", float))
        assert (
            len(as_array) == self.n
        ), f"Expected {self.n} data points; Got {len(as_array)} instead"
        # Undo the 0-based shift applied by ``to_vector``.
        return np.array(as_array.argmax() + self.start)

    @classmethod
    @functools.cache
    def from_string(cls, s: str) -> "Discrete":
        """Build a ``Discrete`` space from a string such as ``"Discrete(3)"``.

        An optional start offset, as in ``"Discrete(3, start=-1)"``, is also
        accepted; without it the space starts at 0.

        Results are memoised with ``functools.cache``, so repeated calls with
        an equal string return the same instance (including its RNG state).

        :raises RuntimeError: If ``s`` does not match ``Discrete._RE``.
        """
        match = Discrete._RE.match(s)
        if not match:
            raise RuntimeError("String '%s' did not match '%s'" % (s, Discrete._RE))
        start = int(match[2]) if match[2] is not None else 0
        return Discrete(int(match[1]), start=start)