from __future__ import annotations
from typing import List
import logging
from itertools import product
from palaestrai.core.protocol import (
AgentUpdateRequest,
AgentUpdateResponse,
EnvironmentResetNotificationRequest,
EnvironmentResetNotificationResponse,
EnvironmentResetRequest,
EnvironmentResetResponse,
EnvironmentUpdateRequest,
EnvironmentUpdateResponse,
SimulationControllerTerminationRequest,
SimulationControllerTerminationResponse,
)
from .simulation_controller import SimulationController
LOG = logging.getLogger(__name__)
[docs]
class VanillaSimController(SimulationController):
"""
This is our vanilla controller. With this simulation,
environment and agent(s) work alternately and not in
parallel. It is not a continuous simulation.
After each simulation step (independent if done
by environment or agent) the Termination Condition
is called.
"""
def __init__(
self,
rungov_connection,
sim_connection,
agent_conductor_ids,
environment_conductor_ids,
agents,
termination_conditions,
mode,
):
super().__init__(
rungov_connection,
sim_connection,
agent_conductor_ids,
environment_conductor_ids,
agents,
termination_conditions,
mode,
)
[docs]
@staticmethod
def handle_env_update(responses: List[EnvironmentUpdateResponse]):
"""
This method will process a list of environment update responses
and will combine the information to combined lists. Currently
the termination variable is global because we assume, that
the run terminates if at least one environment is env
terminates.
Parameters
----------
responses = list of EnvironmentUpdateResponse
Returns
-------
sensors = list of all sensor_information of all available environments
rewards = list of all environment rewards
termination = boolean which is true if one environment has terminated
"""
sensors = []
rewards = []
termination = False
for response in responses:
sensors.extend(response.sensors)
rewards.extend(response.rewards)
if response.done:
termination = True
return sensors, rewards, termination
[docs]
@staticmethod
def handle_env_reset(responses: list):
"""
This method will process a list of environment reset responses
and will combine the information to combined lists.
Parameters
----------
responses = list of EnvironmentResetResponse
Returns
-------
sensors = list of all sensor_information of all available environments
"""
sensors = []
for response in responses:
sensors.extend(response.sensors)
return sensors
[docs]
@staticmethod
def agent_update(responses: list) -> list:
"""
This method combines all actuator_information of all Agents
and creates on list.
Parameters
----------
responses = List of AgentUpdateResponses
Returns
-------
actuators which is a list of actuator_information
"""
actuators = []
for response in responses:
actuators.extend(response.actuators)
return actuators
[docs]
async def simulation_shutdown(
self, env_termination: bool, rewards: list, additional_results
):
"""
This method will be called when the simulation has terminated
it will send a SimControllerTerminationRequest to the runGov.
The RunGov will respond and will tell if it will be a complete
or partial shutdown.
The complete shutdown includes the conductors while a partial
shutdown is a reset which just deletes the muscle(s) and env(s)
Parameters
----------
env_termination : bool if the environment has terminated
rewards : list of rewards to show the current performance
additional_results : for any additional information
Returns
-------
"""
msg = SimulationControllerTerminationRequest(
sender_simulation_controller_id=self.uid,
receiver_run_governor_id=self.rg_id,
experiment_run_id=self.experiment_run_id,
environment_terminated=env_termination,
last_reward=rewards,
additional_results=additional_results,
)
LOG.debug(
"SimulationController %s(id=0x%x, uid=%s) sending "
"SimulationControllerTerminationRequest(experiment_run_id=%s).",
self.__class__,
id(self),
self.uid,
self.experiment_run_id,
)
response = await self.run_gov_client.send(self.rg_id, msg)
if not isinstance(response, SimulationControllerTerminationResponse):
LOG.critical(
"SimulationController %s(id=0x%x, uid=%s) "
"waited for SimulationControllerTerminationResponse, but got "
"%s instead. Dying without honor, trusting the RunGovernor "
"to handle this disgrace.",
self.__class__,
id(self),
self.uid,
response,
)
await self.stop_simulation(True)
LOG.debug(
"SimulationController %s(id=0x%x, uid=%s) received %s.",
self.__class__,
id(self),
self.uid,
response,
)
if response.complete_shutdown:
await self.stop_simulation(response.complete_shutdown)
# TODO: if episodes > 1: restart
else:
await self.agent_shutdown(response.complete_shutdown)
await self.env_shutdown()
[docs]
async def get_env_update(self, env, actuators):
"""
Sends an EnvironmentUpdateRequest to one env
and collects the Response.
The vanilla simController sends all actuators to all envs
and the env has to select the own actuators. A access list
could be needed if two envs of the same type are used.
Parameters
----------
env : id of the environment
actuators : list of actuatorinformation
Returns
-------
response : EnvironmentUpdateResponse
"""
LOG.debug(
"SimulationController %s(id=0x%x, uid=%s) "
"starting EnvironmentUpdateRequest(experiment_run_id=%s, "
"env=%s, actuators=%s).",
self.__class__,
id(self),
self.uid,
self.experiment_run_id,
str(env),
actuators,
)
msg = EnvironmentUpdateRequest(
experiment_run_id=self.experiment_run_id,
experiment_run_instance_id=self._experiment_run_instance_id,
experiment_run_phase=self._experiment_run_phase,
sender_simulation_controller=self.uid,
receiver_environment=str(env),
actuators=actuators,
)
response = await self.client.send(bytes(str(env), "ascii"), msg)
LOG.debug(
"SimulationController %s(id=0x%x, uid=%s) "
"received EnvironmentUpdateResponse: %s.",
self.__class__,
id(self),
self.uid,
response,
)
if isinstance(response, EnvironmentUpdateResponse):
return response
else:
LOG.error(
"SimulationController %s(id=0x%x, uid=%s) expected "
"EnvironmentUpdateResponse, but got %s instead; ignoring.",
self.__class__,
id(self),
self.uid,
response,
)
return None
[docs]
async def get_agent_updates(self, rewards, env_termination, simtimes):
"""Fetches actions from agents.
This method sends an :class:`~AgentUpdateRequest` to each agent and collects
the responses, which will be returned.
Parameters
----------
rewards : List[RewardInformation]
List of environment rewards
env_termination : bool
``True`` if environment has terminated, ``False`` if not
simtimes : Dict[str, palaestrai.types.SimTime]
Contains time values from the environment. It maps environment UIDs
to :class:`~SimTime`.
Returns
-------
List[AgentUpdateResponse]
List of response messages containing information about the agent's actions
"""
LOG.debug(
"SimulationController %s(id=0x%x, uid=%s) "
"requesting updates from Agents(uid=%s).",
self.__class__,
id(self),
self.uid,
[str(agent.uid) for agent in self._agents],
)
messages = [
AgentUpdateRequest(
sender_simulation_controller_id=self.uid,
receiver_rollout_worker_id=str(agent.uid),
experiment_run_id=self.experiment_run_id,
experiment_run_instance_id=self._experiment_run_instance_id,
experiment_run_phase=self._experiment_run_phase,
sensors=self.access_list([s.uid for s in agent.sensors], [])[
0
], # TODO: quite hacky,
actuators=agent.actuators,
rewards=rewards,
is_terminal=env_termination,
mode=self.mode,
simtimes=simtimes,
)
for agent in self._agents
]
updates = []
responses = await self.client.send_async(
list(zip([str(agent.uid) for agent in self._agents], messages))
)
for agent_id, response in responses.items():
if isinstance(response, AgentUpdateResponse):
LOG.debug(
"SimulationController %s(id=0x%x, uid=%s) "
"received AgentUpdateResponse from Agent(uid=%s).",
self.__class__,
id(self),
self.uid,
agent_id,
)
# TODO: this is quite hacky
if len(response.actuators) > 0:
for agent in self._agents:
if agent.uid == agent_id:
agent.actuators = response.actuators
break
updates.append(response)
else:
LOG.error(
"SimulationController %s(id=0x%x, uid=%s) expected "
"AgentUpdateResponse, but got %s instead; ignoring.",
self.__class__,
id(self),
self.uid,
response,
)
updates.append(None)
return [term for term in updates if term]
[docs]
async def simulation(self):
"""
This is the abstract method implementation of the
simulation task. The vanilla sim controller simulation
start by asking all environments for sensor information.
These information will be sent to the agent(s) which
will respond which their actuator-values.
From there on it will be a ping pong between environment(s)
and agent(s). All available information will be exchanged.
Both, environment as well as agent information, can be
used for the termination.
Returns
-------
"""
termination = False
actuators = self._actuators
env_termination = False
rewards = []
LOG.info(
"SimulationController %s(id=0x%x, uid=%s) starting simulation.",
self.__class__,
id(self),
self.uid,
)
while True:
LOG.debug(
"SimulationController %s(id=0x%x, uid=%s) "
"running new iteration for experiment_run_id=%s; "
"termination: %s.",
self.__class__,
id(self),
self.uid,
self.experiment_run_id,
self.ex_termination,
)
# after we terminated we dont need to ask for new updates
if not termination:
responses = [
term
for term in [
await self.get_env_update(eid, actuators)
for eid in self.env_ids
]
if term
]
termination = self._check_termination(responses)
(
self._sensors,
rewards,
env_termination,
) = self.handle_env_update(responses)
simtimes = {
r.sender_environment_id: r.simtime for r in responses
}
else:
# after the reset we work with the default parameters
termination = False
env_termination = False
rewards = []
simtimes = {}
responses = await self.get_agent_updates(
rewards, env_termination, simtimes
)
termination = termination or self._check_termination(responses)
actuators = self.collect_actuators(agent_responses=responses)
if not await self.continue_simulation(termination, rewards, None):
break
LOG.info(
"SimulationController %s(id=0x%x, uid=%s) " "finished simulation.",
self.__class__,
id(self),
self.uid,
)
def _check_termination(self, responses):
"""
aggregates termination flags of given responses and checks against
self.termination_conditions to determin if a termination state has been reached
"""
return any(
[
term[0].check_termination(term[1])
for term in product(self.termination_conditions, responses)
]
)
[docs]
async def continue_simulation(
self, termination, rewards, additional_results
):
if self.ex_termination:
LOG.info(
"SimulationController %s(id=0x%x, uid=%s) encountered "
"external termination. Performing complete shutdown!",
self.__class__,
id(self),
self.uid,
)
await self.stop_simulation(True)
return False
elif termination:
LOG.info(
"SimulationController %s(id=0x%x, uid=%s) encountered "
"internal termination.",
self.__class__,
id(self),
self.uid,
)
response = await self.send_termination_request(
rewards, additional_results
)
if response.complete_shutdown:
await self.stop_simulation(True)
return False
elif response.restart:
LOG.info(
"SimulationController %s/id=0x%x, uid=%s) initiating "
"restart.",
self.__class__,
id(self),
self.uid,
)
await self.perform_restart()
return True
else:
await self.agent_shutdown(response.complete_shutdown)
await self.env_shutdown()
# TODO: This situation is not handled in the main loop
# Therefore, we will just break the simulation loop.
return False
else:
return True
[docs]
async def send_termination_request(
self, rewards: list, additional_results
):
msg = SimulationControllerTerminationRequest(
sender_simulation_controller_id=self.uid,
receiver_run_governor_id=self.rg_id,
experiment_run_id=self.experiment_run_id,
environment_terminated=True,
last_reward=rewards,
additional_results=additional_results,
)
LOG.debug(
"SimulationController %s(id=0x%x, uid=%s) sending "
"SimulationControllerTerminationRequest(experiment_run_id=%s).",
self.__class__,
id(self),
self.uid,
self.experiment_run_id,
)
response = await self.run_gov_client.send(self.rg_id, msg)
if not isinstance(response, SimulationControllerTerminationResponse):
LOG.critical(
"SimulationController %s(id=0x%x, uid=%s) "
"waited for SimulationControllerTerminationResponse, but got "
"%s instead. Dying without honor, trusting the RunGovernor "
"to handle this disgrace.",
self.__class__,
id(self),
self.uid,
response,
)
await self.stop_simulation(True)
LOG.debug(
"SimulationController %s(id=0x%x, uid=%s) received %s.",
self.__class__,
id(self),
self.uid,
response,
)
return response
[docs]
def collect_actuators(self, agent_responses: list):
"""
collect_actuators takes a list of agent_responses
and combines all available actuators to one list.
Parameters
----------
agent_responses : list[agent_responses]
Returns
-------
actuators : list[actuator]
"""
actuators = []
for response in agent_responses:
# TODO: Workaround, works only for single env
actuators.extend(response.actuators)
return actuators
[docs]
def access_list(self, sensor_list, actuator_list):
"""
access_list takes a list of sensors and actuators
and checks if they are part of the available
sensors/actuators. If that is the case they
will be returned.
Parameters
----------
sensor_list : list
a list with sensor IDs.
actuator_list : list
a list with actuator IDs.
Returns
-------
sensor_list : list
actuator_list : list
"""
return (
[s for s in self._sensors if s.uid in sensor_list],
[a for a in self._actuators if a.uid in actuator_list],
)