from __future__ import annotations
from asyncio import InvalidStateError
import copy
from typing import TYPE_CHECKING, List, Dict, Set, Optional, Union, Any
import uuid
import signal
import logging
import asyncio
from aiomultiprocess import Process
import palaestrai.logging
from palaestrai.core.protocol import (
ExperimentRunStartRequest,
ExperimentRunStartResponse,
SimulationStartRequest,
SimulationStartResponse,
SimulationControllerTerminationRequest,
SimulationControllerTerminationResponse,
ErrorIndicator,
SimulationShutdownRequest,
SimulationShutdownResponse,
ExperimentRunShutdownRequest,
ExperimentRunShutdownResponse,
ShutdownRequest,
ShutdownResponse,
)
from palaestrai.core.protocol.agent_evaluation_req import (
AgentEvaluationRequest,
)
from palaestrai.core.protocol.agent_evaluation_rsp import (
AgentEvaluationResponse,
)
from palaestrai.core.protocol.simulation_flow_change_req import (
SimulationFlowChangeRequest,
)
from palaestrai.core.protocol.simulation_flow_change_rsp import (
SimulationFlowChangeResponse,
)
from palaestrai.logging.utils import LoggingDefaultDict
from palaestrai.types import SimulationFlowControl, Mode
from palaestrai.core import EventStateMachine as ESM
from palaestrai.core import BasicState, RuntimeConfig
from palaestrai.util import spawn_wrapper, MetadataLogFilter
from palaestrai.util.exception import ExperimentAlreadyRunningError
if TYPE_CHECKING:
import aiomultiprocess
import multiprocessing
from palaestrai.experiment import TerminationCondition
from palaestrai.experiment.experiment_run import ExperimentRun
from palaestrai.simulation import SimulationController
LOG = logging.getLogger(__name__)
[docs]
@ESM.monitor(is_mdp_worker=True)
class RunGovernor:
"""
This class implements the Run-Governor.
Upon receiving requests from the executor, a RunGovernor instance
handles a single experiment run by starting it, initializing the
simulation controllers, the environment and the agent conductors,
and, finally, shutting the experiment run down.
The RunGovernor is implemented as state machine and this class
provides the context for the distinct state classes. A freshly
initialized RunGovernor waits in the state PRISTINE until the run
method is called by the executor. See the distinct state classes
for more information.
Parameters
----------
uid : str
The universally unique ID that identifies this run governor
Attributes
----------
uid: str
The UUID of this RunGovernor
termination_condition: :class:`.TerminationCondition`
A reference to the TerminationCondition instance.
run_broker: :class:`.MajorDomoBroker`
The broker for the communication with the simulation
controller, the agents, and the environments.
experiment_run_id: str
The UUID of the current experiment run.
tasks: List[aiomultiprocess.Process]
A list of tasks the RunGovernor has started and that it has
to shut down in the end.
worker: :class:`.MajorDomoWorker`
The major domo worker for handling incoming requests
client: :class:`.MajorDomoClient`
The major domo client for sending requests to other workers.
shutdown: bool
The major kill switch of the RunGovernor. Setting this to false
will stop the RunGovernor after the current state.
state: :class:`.RunGovernorState`
Holds the current state instance. The first state is PRISTINE.
errors: List[Exception]
A list that is used to collect errors raised in the states.
"""
def __init__(self, uid: Optional[str] = None):
    """Create an idle run governor.

    Parameters
    ----------
    uid : str, optional
        Identifier of this run governor. When omitted (or empty), a
        random ``"RunGovernor-<uuid4>"`` identifier is generated.
    """
    self.uid = uid or f"RunGovernor-{uuid.uuid4()}"
    self._state = BasicState.PRISTINE

    # -- Experiment control ---------------------------------------------
    # Index of the phase currently executed (0-based).
    self.current_phase: int = 0
    # Counted episodes per simulation controller UID; replaced per phase.
    self.current_episode_counts: Dict[str, int] = {}
    self.experiment_run: Optional[ExperimentRun] = None
    self._termination_condition: Optional[TerminationCondition] = None
    # Evaluation cadence, evaluation budget and number of parallel
    # workers of the current phase (set in ``_setup_phase``).
    self._evaluate_every: Optional[int] = None
    self.max_eval_episodes: Optional[int] = None
    self._worker: Optional[int] = None
    # Phase-end flags, reset to False in ``_setup_phase``.
    self._has_stop_phase_signal: Optional[bool] = None
    self._is_phase_shutting_down: Optional[bool] = None
    # Last known flow control per simulation controller UID.
    self._simulation_controllers_flow_status: Optional[
        Dict[str, SimulationFlowControl]
    ] = None
    self._pausing_simulation_controllers: Optional[bool] = None

    # -- Receiver information (UIDs) ------------------------------------
    self._simulation_controllers: List[str] = []
    self._environment_conductors: List[str] = []
    self._agent_conductors: List[str] = []

    # -- Receiver synchronization ---------------------------------------
    # UIDs that answered their SimulationStartRequest.
    self._simulation_controllers_active: Set[str] = set()
    # UIDs whose SimulationFlowChangeResponse is still outstanding.
    self._simulation_controllers_pending_response: Set[str] = set()
    # Awaited by ``_start_evaluation`` until enough eval workers run.
    self._future_simulation_controller_active: Optional[asyncio.Future] = (
        None
    )
    self._simulation_controllers_down_event: Optional[asyncio.Event] = None
    # Awaited by ``_change_sc_flow_control`` for a pending flow change.
    self._future_simulation_controllers_flow_change: Optional[
        asyncio.Future
    ] = None
    self._future_shutdown: Optional[asyncio.Future] = None
    self._future_next_phase: Optional[asyncio.Future] = None
    self._phase_launcher_task: Optional[asyncio.Task] = None
    self._shutdown_conductors_task: Optional[asyncio.Task] = None

    # -- Evaluation -----------------------------------------------------
    self._evaluation_start_lock: Optional[asyncio.Lock] = None
    self._eval_flow_control_change_lock: Optional[asyncio.Lock] = None
    self._episode_counts_lock: Optional[asyncio.Lock] = None
    self._simulation_controllers_continuation_task: Optional[
        asyncio.Task
    ] = None
    self._agent_evaluation_lock: Optional[asyncio.Lock] = None
    self._future_agent_evaluation: Optional[asyncio.Future] = None
    self._pending_agent_evaluation_responses: Optional[Set[str]] = None

    # -- Subprocess management ------------------------------------------
    self._processes: List[aiomultiprocess.Process] = []
[docs]
def setup(self):
    """Prepare this governor for serving requests.

    Uses ``self.uid`` as the MDP service name and leaves the governor in
    the INITIALIZING state (passing through PRISTINE first).
    """
    # NOTE(review): presumably invoked by the ``ESM.monitor`` machinery
    # before the worker loop starts -- confirm.
    self._state = BasicState.PRISTINE
    self.mdp_service = self.uid  # type: ignore[attr-defined]
    LOG.debug("%s ready.", self)
    self._state = BasicState.INITIALIZING
@ESM.on(ExperimentRunStartRequest)
async def _handle_experiment_run_start_request(
    self, request: ExperimentRunStartRequest
):
    """Start executing the experiment run carried by ``request``.

    A run governor executes at most one experiment run. If one is already
    assigned, the request is rejected with an
    :class:`ExperimentAlreadyRunningError` and the governor is stopped.
    Otherwise the run is set up, a metadata log filter carrying the run
    UID is installed, and phase execution is scheduled as a background
    task (``_run_all_phases``).

    Parameters
    ----------
    request : ExperimentRunStartRequest
        The executor's request.

    Returns
    -------
    ExperimentRunStartResponse
        ``successful=True`` once phase execution has been scheduled;
        ``successful=False`` together with the causing error otherwise.
    """
    if self.experiment_run is not None:
        LOG.warning(
            '%s got request to start experiment run "%s", '
            'but experiment run "%s" is already running. '
            "Reporting error and continueing. Elect somebody "
            "else instead!",
            self,
            request.experiment_run.uid,  # the run being requested
            self.experiment_run.uid,  # the run already assigned
        )
        # NOTE(review): the message says "continuing", yet stop() is
        # called here, which ends this governor. Confirm this is intended.
        self.stop()  # type: ignore[attr-defined]
        return ExperimentRunStartResponse(
            sender_run_governor_id=request.receiver,
            receiver_executor_id=request.sender,
            experiment_run_id=request.experiment_run_id,
            experiment_run=request.experiment_run,
            successful=False,
            error=ExperimentAlreadyRunningError(self.experiment_run),
        )
    self.current_phase = 0
    self.experiment_run = request.experiment_run
    LOG.debug("%s setting up %s", self, self.experiment_run)
    try:
        await self._setup_run()
        LOG.addFilter(
            MetadataLogFilter(experiment_run_uid=self.experiment_run.uid)
        )
        # Also register the filter in the runtime logging configuration,
        # presumably so that processes spawned later (which receive
        # ``RuntimeConfig().to_dict()``) tag their records as well.
        log_filters = RuntimeConfig().logging.get("filters", {})
        log_filters["metadata"] = {
            "()": "palaestrai.util.metadata_logfilter.MetadataLogFilter",
            "experiment_run_uid": self.experiment_run.uid,
        }
        RuntimeConfig().logging["filters"] = log_filters
        # Tolerate configurations without a "loggers" section and do not
        # attach the same filter twice.
        for k in RuntimeConfig().logging.get("loggers", {}).keys():
            filters = (
                RuntimeConfig().logging["loggers"][k].get("filters", [])
            )
            if "metadata" not in filters:
                filters.append("metadata")
            RuntimeConfig().logging["loggers"][k]["filters"] = filters
        self._state = BasicState.INITIALIZED
    except Exception as e:
        LOG.exception("%s encountered exception during setup: %s", self, e)
        self.stop(e)  # type: ignore[attr-defined]
        return ExperimentRunStartResponse(
            sender_run_governor_id=request.receiver,
            receiver_executor_id=request.sender,
            experiment_run_id=request.experiment_run_id,
            experiment_run=request.experiment_run,
            successful=False,
            error=e,
        )
    self._phase_launcher_task = asyncio.create_task(self._run_all_phases())
    return ExperimentRunStartResponse(
        sender_run_governor_id=request.receiver,
        receiver_executor_id=request.sender,
        experiment_run_id=request.experiment_run_id,
        experiment_run=request.experiment_run,
        successful=True,
        error=None,
    )
async def _run_all_phases(self):
    """Execute every phase of the experiment run, then wait for shutdown.

    For each phase this creates fresh ``_future_next_phase`` and
    ``_future_shutdown`` futures, sets the phase up, sends start requests
    to the simulation controllers, awaits ``_future_next_phase``
    (resolved elsewhere when the phase ends), and joins all of the phase's
    processes. After the last phase it waits up to 15 seconds for
    ``_future_shutdown``; on timeout it requests a phase shutdown. Any
    error is logged, and the governor is always stopped at the end.
    """
    try:
        self._state = BasicState.RUNNING
        n_phases = self.experiment_run.num_phases
        self._simulation_controllers_down_event = asyncio.Event()
        while self.current_phase < n_phases:
            self._future_next_phase = (
                asyncio.get_running_loop().create_future()
            )
            self._future_shutdown = (
                asyncio.get_running_loop().create_future()
            )
            self._simulation_controllers_down_event.clear()
            LOG.debug(
                'Setting up phase %d/%d in experiment run "%s"...',
                self.current_phase + 1,  # +1 for display purposes
                n_phases,
                self.experiment_run.uid,
            )
            await self._setup_phase()
            LOG.debug(
                "%s sending start request(s) to simulation controllers...",
                self,
            )
            ssrq = self._send_simulation_start_requests()
            LOG.debug("%s sent simulation start requests: %s", self, ssrq)
            LOG.debug("%s waiting for phase to end", self)
            await self._future_next_phase
            # Wait for all processes to really end:
            LOG.debug(
                "%s waiting for processes to end: %s",
                self,
                self._processes,
            )
            for p in self._processes:
                await p.join()
            self.current_phase += 1
        LOG.info(
            'Executed all phases in run "%s", shutting down.',
            self.experiment_run.uid,
        )
        # Without any phase no shutdown future exists and there are no
        # conductors to wait for.
        if self._future_shutdown is not None:
            try:
                await asyncio.wait_for(self._future_shutdown, timeout=15)
                LOG.info("All conductors shutdown.")
            except asyncio.TimeoutError:
                # asyncio.TimeoutError is only an alias of the builtin
                # TimeoutError from Python 3.11 on; catching the asyncio
                # name works on all versions.
                LOG.error(
                    "%s timed out while waiting for all processes to end. "
                    "Conductors still active: %s",
                    self,
                    self._agent_conductors + self._environment_conductors,
                )
                await self._request_phase_shutdown()
    except Exception:
        LOG.exception("Running all phases broke")
    finally:
        self.stop()  # type: ignore[attr-defined]
async def _setup_run(self):
self.experiment_run.setup()
assert (
self.experiment_run.run_governor_termination_condition is not None
)
self._termination_condition = (
self.experiment_run.run_governor_termination_condition
)
async def _setup_phase(self):
    """Initialize per-phase bookkeeping and launch the phase's processes.

    Reads the evaluation cadence (``evaluate_every``), the worker count
    and, if evaluation is enabled, the episode count of
    ``self.current_phase``. It then resets episode counters, locks and
    phase flags, and finally starts environment conductors, agent
    conductors and simulation controllers concurrently. All started
    processes end up in ``self._processes``.
    """
    run = self.experiment_run
    phase = self.current_phase

    # Unknown simulation controllers start with an episode count of 0.
    self.current_episode_counts = LoggingDefaultDict(int)
    self._evaluate_every = run.get_evaluate_every(phase)
    self._worker = run.get_worker(phase)
    if self._evaluate_every:
        # episodes * worker / evaluate_every, truncated toward zero.
        total_episodes = run.get_episodes(phase) * self._worker
        self.max_eval_episodes = int(total_episodes / self._evaluate_every)
    else:
        # evaluate_every is None or 0: -1 marks "no evaluation budget".
        self.max_eval_episodes = -1

    if self._evaluate_every is not None and self._worker > self._evaluate_every:
        LOG.warning(
            "For deterministic evaluation the evaluation frequency "
            "(evaluate_every) must be >= worker, otherwise the evaluation "
            "is performed after at least 'evaluate_every' amount of "
            "finished episodes, but the other workers will most likely already "
            "advanced their episode execution in parallel."
        )

    self._eval_flow_control_change_lock = asyncio.Lock()
    self._episode_counts_lock = asyncio.Lock()
    self._evaluation_start_lock = asyncio.Lock()
    self._has_stop_phase_signal = False
    self._is_phase_shutting_down = False
    self._agent_evaluation_lock = asyncio.Lock()

    # Agent evaluation state must not survive from a previous phase.
    if self._future_agent_evaluation is not None:
        LOG.error("Future agent evaluation leaks across phase!")
    if self._pending_agent_evaluation_responses is not None:
        LOG.error("Pending agent evaluation responses leak across phase!")

    LOG.debug("Reset flow status mapping in _setup_phase")
    self._check_reset_flow_status_mapping()

    # gather() yields one list of processes per starter coroutine.
    per_kind = await asyncio.gather(
        self._start_environment_conductors(),
        self._start_agent_conductors(),
        self._start_simulation_controllers(),
    )
    self._processes = [proc for procs in per_kind for proc in procs]
    LOG.debug(
        "%s has processes this phase: %s",
        self,
        [proc.name for proc in self._processes],
    )
@ESM.spawns
async def _start_simulation_controllers(
    self,
    optional_sc_list: Optional[List[SimulationController]] = None,
):
    """Spawn and start one process per simulation controller.

    Parameters
    ----------
    optional_sc_list : list of SimulationController, optional
        Controllers to launch in addition to the already registered ones
        (``_start_evaluation`` uses this for evaluation workers). If
        omitted, the current phase's controllers are launched and their
        UIDs replace ``self._simulation_controllers``.

    Returns
    -------
    list of aiomultiprocess.Process
        The started processes.
    """
    assert self.experiment_run is not None
    replace_registry = optional_sc_list is None
    if replace_registry:
        controllers = list(
            self.experiment_run.simulation_controllers(
                self.current_phase
            ).values()
        )
    else:
        controllers = optional_sc_list
    assert controllers is not None
    controller_uids = [controller.uid for controller in controllers]
    if replace_registry:
        self._simulation_controllers = controller_uids
    else:
        # Extend in place, keeping the existing list object.
        self._simulation_controllers.extend(controller_uids)
    LOG.debug(
        "%s launching simulation controller processes: %s (uids: %s)",
        self,
        list(controllers),
        controller_uids,
    )
    processes = []
    for controller in controllers:
        processes.append(
            Process(
                name=f"SimulationController-{controller.uid}",
                target=spawn_wrapper,
                args=(
                    RunGovernor._get_short_sc_proctitle(controller.uid),
                    RuntimeConfig().to_dict(),
                    controller.run,  # type: ignore[attr-defined]
                ),
            )
        )
    for process in processes:
        process.start()
    return processes
@staticmethod
def _get_short_sc_proctitle(uid):
s = "SimulationController-"
if "-EVALUATE" in uid:
s += "".join(uid.split("-")[:-1])[-6:] + "-EVALUATE"
else:
s += uid[-6:]
return s
@ESM.spawns
async def _start_environment_conductors(self):
    """Spawn and start one process per environment conductor of the phase.

    The conductors' UIDs are recorded in ``self._environment_conductors``.

    Returns
    -------
    list of aiomultiprocess.Process
        The started processes.
    """
    conductors = list(
        self.experiment_run.environment_conductors(
            self.current_phase
        ).values()
    )
    LOG.debug(
        "%s launching environment conductor processes: %s",
        self,
        conductors,
    )
    processes = []
    for conductor in conductors:
        processes.append(
            Process(
                name=f"EnvironmentConductor-{conductor.uid}",
                target=spawn_wrapper,
                args=(
                    f"EnvironmentConductor-{conductor.uid[-6:]}",
                    RuntimeConfig().to_dict(),
                    conductor.run,  # type: ignore[attr-defined]
                ),
            )
        )
    self._environment_conductors = [
        conductor.uid for conductor in conductors
    ]
    for process in processes:
        process.start()
    return processes
@ESM.spawns
async def _start_agent_conductors(self):
    """Spawn and start one process per agent conductor of the phase.

    The conductors' UIDs are recorded in ``self._agent_conductors``.

    Returns
    -------
    list of aiomultiprocess.Process
        The started processes.
    """
    conductors = list(
        self.experiment_run.agent_conductors(self.current_phase).values()
    )
    LOG.debug(
        "%s launching agent conductor processes: %s",
        self,
        conductors,
    )
    processes = []
    for conductor in conductors:
        processes.append(
            Process(
                name=f"AgentConductor-{conductor.uid}",
                target=spawn_wrapper,
                args=(
                    f"AgentConductor-{conductor.uid[-6:]}",
                    RuntimeConfig().to_dict(),
                    conductor.run,  # type: ignore[attr-defined]
                ),
            )
        )
    self._agent_conductors = [conductor.uid for conductor in conductors]
    for process in processes:
        process.start()
    return processes
def _create_evaluation_sc(self) -> Optional[SimulationController]:
    """Build a new evaluation simulation controller (not yet started).

    The controller is created from the class and configuration of the
    current phase's first simulation controller, with:

    * ``mode=Mode.EVALUATE``,
    * deep-copied agent configurations without ``load`` and ``replay``,
    * a single ``EnvironmentTerminationCondition``,
    * ``"-EVALUATE"`` appended to its UID.

    Returns
    -------
    SimulationController or None
        The new controller, or ``None`` if it could not be created. The
        reason is logged.
    """
    try:
        assert self.experiment_run is not None
        orig_simulation_controllers = list(
            (
                self.experiment_run.simulation_controllers(
                    self.current_phase
                )
            ).values()
        )
        if not orig_simulation_controllers:
            LOG.error(
                "Current phase %s must have at least one simulation controller, but none is present",
                self.current_phase,
            )
            # Nothing to copy from; bail out explicitly instead of running
            # into an IndexError below.
            return None
        orig_sc: SimulationController = orig_simulation_controllers[0]
        agent_configurations = copy.deepcopy(orig_sc._agent_configurations)
        for agent_conf in agent_configurations.values():
            # No explicit brain loading is needed: the muscle initially
            # gets its model from the currently operating brain.
            agent_conf.pop("load", None)
            # Replay is not needed during evaluation.
            agent_conf.pop("replay", None)
        sc = orig_sc.__class__(
            agent_conductor_ids=orig_sc._agent_conductor_ids,
            environment_conductor_ids=orig_sc._environment_conductor_ids,
            agents=agent_configurations,
            mode=Mode.EVALUATE,
            termination_conditions=[
                # No MaxEpisodesTerminationCondition here: this condition
                # only applies inside the SC, while the evaluation
                # worker's episode budget is handled by
                # _handle_simulation_controller_termination_request.
                {
                    "name": "palaestrai.experiment:EnvironmentTerminationCondition",
                    "params": {},
                }
            ],
        )
        sc.uid += "-EVALUATE"
        return sc
    except Exception:
        LOG.exception(
            "%s could not create and start a new simulation controller for evaluation",
            self,
        )
        return None
async def _start_evaluation(self):
    """Spawn ``self._worker`` evaluation workers and wait until they run.

    Does nothing if paused evaluation workers already exist. Otherwise it
    creates the evaluation simulation controllers, starts their
    processes, sends them start requests in EVALUATE mode, and awaits
    ``_future_simulation_controller_active``. That future is resolved by
    ``_handle_simulation_start_response`` once enough evaluation workers
    are active. The future is always cleared afterwards.

    Raises
    ------
    RuntimeError
        If not all evaluation simulation controllers could be created.
    Exception
        Any error raised while starting the workers is logged and
        re-raised.
    """
    scuid_dict = self._get_flow_status_mapping()
    if scuid_dict is not None and any(
        scuid.endswith("-EVALUATE")
        for scuid in scuid_dict[SimulationFlowControl.PAUSE]
    ):
        LOG.error(
            "There are already evaluation workers active, will not start further evaluation!"
        )
        return
    LOG.info(
        "Starting evaluation after %d episodes "
        'in phase %d of experiment run "%s".',
        sum(self.current_episode_counts.values()),
        self.current_phase + 1,
        self.experiment_run.uid,
    )
    if self._future_simulation_controller_active is None:
        self._future_simulation_controller_active = (
            asyncio.get_running_loop().create_future()
        )
    future = self._future_simulation_controller_active
    try:
        simulation_controllers = [
            self._create_evaluation_sc() for _ in range(self._worker)
        ]
        if any(sc is None for sc in simulation_controllers):
            LOG.error(
                "There were None evaluation simulation controllers created: %s",
                simulation_controllers,
            )
            # Fail with a clear error instead of an AttributeError on
            # ``None.uid`` further down.
            raise RuntimeError(
                "Could not create all evaluation simulation controllers"
            )
        # The evaluation SC processes are not awaited here. The end of an
        # evaluation simulation is signalled by a
        # SimulationControllerTerminationRequest, which is handled
        # separately.
        self._processes += await self._start_simulation_controllers(
            simulation_controllers
        )
        _ = self._send_simulation_start_requests(
            simulation_controllers=[
                sc.uid for sc in simulation_controllers
            ],
            phase_config={
                "episodes": self.max_eval_episodes,
                "worker": self._worker,
                "mode": Mode.EVALUATE.name.lower(),
            },
        )
        if await future != "EVALUATE":
            LOG.error(
                "Continued not because all evaluation workers are active",
            )
    except Exception:
        LOG.error(
            "%s could not create and start a new simulation controller for evaluation",
            self,
        )
        raise
    finally:
        # Never leave a stale future behind. A leftover future would be
        # reused by the next evaluation even though it is already
        # resolved or abandoned.
        self._future_simulation_controller_active = None
@ESM.requests
def _send_simulation_start_requests(
    self,
    simulation_controllers: Optional[List[str]] = None,
    phase_config: Optional[Dict[str, Any]] = None,
):
    """Build one :class:`SimulationStartRequest` per simulation controller.

    Unless at least one receiver is an evaluation worker (UID ending in
    ``"-EVALUATE"``), the set of active simulation controllers is reset
    first.

    Parameters
    ----------
    simulation_controllers : list of str, optional
        Receiver UIDs; defaults to all registered simulation controllers.
    phase_config : dict, optional
        Phase configuration to send; defaults to the configuration of
        ``self.current_phase``.

    Returns
    -------
    list of SimulationStartRequest
        The requests, presumably dispatched by the ``ESM.requests``
        decorator.
    """
    assert self.experiment_run is not None
    receivers = (
        self._simulation_controllers
        if simulation_controllers is None
        else simulation_controllers
    )
    LOG.debug(
        "Send start requests to simulation controllers: %s",
        receivers,
    )
    config = (
        self.experiment_run.phase_configuration(self.current_phase)
        if phase_config is None
        else phase_config
    )
    starts_evaluation_workers = any(
        uid.endswith("-EVALUATE") for uid in receivers
    )
    if not starts_evaluation_workers:
        self._simulation_controllers_active = set()
    requests = []
    for sc_uid in receivers:
        requests.append(
            SimulationStartRequest(
                sender_run_governor_id=self.uid,
                receiver_simulation_controller_id=sc_uid,
                experiment_run_id=self.experiment_run.uid,
                experiment_run_instance_id=self.experiment_run.instance_uid,
                experiment_run_phase=self.current_phase,
                experiment_run_phase_id=self.experiment_run.get_phase_name(
                    self.current_phase,
                ),
                experiment_run_phase_configuration=config,
                # Every worker has its own simulation controller; all of
                # them are started at episode 0.
                episode=0,
            )
        )
    return requests
@ESM.on(SimulationStartResponse)
def _handle_simulation_start_response(
    self, response: SimulationStartResponse
):
    """Record that a simulation controller has started and is running.

    The sender is added to the active set and its flow status is set to
    CONTINUE. Once at least ``self._worker`` evaluation controllers are
    active, the pending ``_future_simulation_controller_active`` (awaited
    by ``_start_evaluation``) is resolved with ``"EVALUATE"``.
    """
    LOG.debug(
        "%s got simulation start response from %s", self, response.sender
    )
    assert self._simulation_controllers_flow_status is not None
    assert self._worker is not None
    self._simulation_controllers_active |= {response.sender}
    self._simulation_controllers_flow_status[response.sender] = (
        SimulationFlowControl.CONTINUE
    )
    future = self._future_simulation_controller_active
    # Resolve at most once. A late or duplicate response that arrives
    # after the future was resolved (or cancelled) but before it was
    # cleared would otherwise raise asyncio.InvalidStateError.
    if future is None or future.done():
        return
    n_active_eval = sum(
        1
        for sc_uid in self._simulation_controllers_active
        if sc_uid.endswith("-EVALUATE")
    )
    if n_active_eval >= self._worker:
        future.set_result("EVALUATE")
def _calc_worker_pause_offset(self):
episodes_elapsed = sum(
[
v
for k, v in self.current_episode_counts.items()
if not k.endswith("-EVALUATE")
]
)
return (
self._evaluate_every is not None
and (episodes_elapsed + min(self._worker, self._evaluate_every))
% self._evaluate_every
)
def _has_more_eval_episodes(self):
episodes_elapsed = {
v
for k, v in self.current_episode_counts.items()
if k.endswith("-EVALUATE")
}
episodes_elapsed = (
max(episodes_elapsed) if len(episodes_elapsed) > 0 else 0
)
return episodes_elapsed < self.max_eval_episodes
def _should_start_pausing_simulation_controllers(self):
    """Tell whether training workers should start being paused for evaluation.

    True only if all of the following hold:

    * an evaluation cadence is configured,
    * the phase has no stop signal,
    * no evaluation worker is running,
    * evaluation episodes remain,
    * ``_calc_worker_pause_offset()`` is 0.
    """
    LOG.debug(
        "Relevant vars for determining to pause SCs: "
        "evaluate_every: %s "
        "_has_stop_phase_signal: %s "
        "worker: %s "
        "current_episode_counts: %s",
        self._evaluate_every,
        self._has_stop_phase_signal,
        self._worker,
        self.current_episode_counts,
    )
    mapping = self._get_flow_status_mapping()
    assert mapping is not None
    eval_worker_running = any(
        sc_uid.endswith("-EVALUATE")
        for sc_uid in mapping[SimulationFlowControl.CONTINUE]
    )
    if self._evaluate_every is None or self._has_stop_phase_signal:
        return False
    if eval_worker_running or not self._has_more_eval_episodes():
        return False
    return self._calc_worker_pause_offset() == 0
def _should_pause_all_workers(self):
# Pause all workers that are "leftover" when worker > evaluate_every
# In this case they do not get paused "on the go" as the evaluation should
# be performed "in the middle" of a parallel episode of some workers
return (
self._evaluate_every is not None
and not self._has_stop_phase_signal
and self._worker > self._evaluate_every
and (self._evaluate_every - self._calc_worker_pause_offset()) == 1
)
def _compute_pausing_simulation_controller_signal(self):
if self._pausing_simulation_controllers is None:
LOG.error("Pausing simulation controller flag not set!")
if (
not self._pausing_simulation_controllers
and self._should_start_pausing_simulation_controllers()
):
LOG.debug(
"Starting to pause simulation controllers (current status: %s) (with the next worker) (%s)",
self._simulation_controllers_flow_status,
self,
)
self._pausing_simulation_controllers = True
# The current episode end of the worker only indicates
# to pause all following workers
return True
return False
async def _maybe_pause_simulation_controller(self, sc_uid: str):
if self._compute_pausing_simulation_controller_signal():
return
elif self._pausing_simulation_controllers:
if sc_uid.endswith("-EVALUATE"):
LOG.error(
"The evaluation worker should not be paused by the "
"_maybe_pause_simulation_controller method! "
"Inconsistencies may occur later on!"
)
if self._should_pause_all_workers():
scuid_dict = self._get_flow_status_mapping()
assert scuid_dict is not None
to_be_paused_scuids = scuid_dict[
SimulationFlowControl.CONTINUE
]
if sc_uid not in to_be_paused_scuids:
LOG.error(
"Should pause all left over workers on termination requestion of "
"simulation controller %s, but itself is computed not to be paused",
sc_uid,
)
else:
to_be_paused_scuids = {sc_uid}
await self._change_sc_flow_control(
to_be_paused_scuids, SimulationFlowControl.PAUSE
)
def _should_start_evaluation(self, mode: Mode):
    """Tell whether an evaluation round should begin now.

    Parameters
    ----------
    mode : Mode
        Mode of the simulation controller whose termination request is
        being handled.

    Returns
    -------
    bool
        True if the phase is neither stopping nor shutting down, ``mode``
        is not EVALUATE, and every non-evaluation simulation controller is
        paused.
    """
    assert self._simulation_controllers_flow_status is not None
    training_all_paused = all(
        flow == SimulationFlowControl.PAUSE
        for sc_uid, flow in self._simulation_controllers_flow_status.items()
        if not sc_uid.endswith("-EVALUATE")
    )
    if self._has_stop_phase_signal or self._is_phase_shutting_down:
        return False
    return mode != Mode.EVALUATE and training_all_paused
async def _change_sc_flow_control(
    self, scuids: Set[str], flow: SimulationFlowControl
):
    """Request ``flow`` for the given simulation controllers and wait.

    Sends a :class:`SimulationFlowChangeRequest` to every UID in
    ``scuids`` and awaits ``_future_simulation_controllers_flow_change``.
    That future is presumably resolved by the flow-change response
    handler once all addressed controllers have answered. Changes are
    serialized through ``_eval_flow_control_change_lock``.

    Parameters
    ----------
    scuids : set of str
        UIDs of the simulation controllers to change. An empty set is a
        no-op.
    flow : SimulationFlowControl
        Requested flow, normally CONTINUE or PAUSE.
    """
    assert self._eval_flow_control_change_lock is not None
    assert self.experiment_run is not None
    if not scuids:
        # No request would be sent, so presumably nothing would ever
        # resolve the future below. Awaiting it would block forever while
        # holding the lock.
        LOG.warning(
            "%s requested flow change to %s for no simulation "
            "controllers; nothing to do.",
            self,
            flow,
        )
        return
    LOG.debug(
        "Change flow of simulation controllers (%s) to %s "
        'in phase %d of experiment run "%s".',
        scuids,
        flow,
        self.current_phase + 1,
        self.experiment_run.uid,
    )
    async with self._eval_flow_control_change_lock:
        if self._future_simulation_controllers_flow_change is not None:
            LOG.error(
                "%s leaks future (%s) for flow change from other changing request",
                self,
                self._future_simulation_controllers_flow_change,
            )
        flow_status_mapping = self._get_flow_status_mapping()
        assert flow_status_mapping is not None
        # Project which controllers will be running after this change:
        # continuing adds ``scuids``, any other flow removes them.
        running_now = set(flow_status_mapping[SimulationFlowControl.CONTINUE])
        if flow == SimulationFlowControl.CONTINUE:
            running_after = running_now | set(scuids)
        else:
            running_after = running_now - set(scuids)
        eval_scuids = {
            scuid for scuid in running_after if scuid.endswith("-EVALUATE")
        }
        normal_scuids = running_after - eval_scuids
        if eval_scuids and normal_scuids:
            LOG.warning(
                "%s normal (%s) and evaluation (%s) worker will be run (continued) in parallel, but they usually should not.",
                self,
                normal_scuids,
                eval_scuids,
            )
        self._future_simulation_controllers_flow_change = (
            asyncio.get_running_loop().create_future()
        )
        try:
            _ = self._request_simulation_flow_change(flow, list(scuids))
            await self._future_simulation_controllers_flow_change
        finally:
            # Never leak the future into the next flow change, even if
            # waiting was cancelled or failed.
            self._future_simulation_controllers_flow_change = None
async def _pause_eval_and_continue_worker_simulation_controllers(
    self, eval_sc_uid: str
):
    """Pause an evaluation worker and resume training once all are paused.

    If no evaluation episodes remain, nothing is paused and the
    evaluation worker keeps running. Otherwise ``eval_sc_uid`` is paused.
    If no other controller was running at that point, all paused
    non-evaluation controllers are continued afterwards.

    Parameters
    ----------
    eval_sc_uid : str
        UID of the evaluation simulation controller whose termination
        request is being handled; expected to end with ``"-EVALUATE"``.
    """
    if not eval_sc_uid.endswith("-EVALUATE"):
        LOG.error(
            "Calling _pause_eval_and_continue_worker_simulation_controllers not with an evaluation worker!"
        )
    mapping = self._get_flow_status_mapping()
    if mapping is None:
        LOG.error(
            "No paused simulation controllers found that could be continued!"
        )
        return
    running = mapping[SimulationFlowControl.CONTINUE]
    paused = mapping[SimulationFlowControl.PAUSE]
    if eval_sc_uid not in running:
        LOG.error(
            "Evaluation worker with uid %s should be paused, but it is not CONTINUEd",
            eval_sc_uid,
        )
    if len(running) == 0 or len(paused) == 0:
        LOG.error(
            "No actual flip can be performed, "
            "because active (%s) or "
            "paused simulations controller list (%s) "
            "is empty.",
            running,
            paused,
        )
    # True if no controller other than ``eval_sc_uid`` is running.
    all_evaluation_workers_paused = len(running - {eval_sc_uid}) == 0
    # CAUTION: DO NOT BLINDLY REMOVE, THIS MAY BE RELEVANT FOR FUTURE USE
    # if all_evaluation_workers_paused:
    #     await self._agent_evaluation()
    if not self._has_more_eval_episodes():
        LOG.info(
            "No more evaluation episodes left, going to continue evaluation worker "
            "to shut them down."
        )
        return
    await self._change_sc_flow_control(
        {eval_sc_uid},
        SimulationFlowControl.PAUSE,
    )
    if not all_evaluation_workers_paused:
        return
    self._compute_pausing_simulation_controller_signal()
    await self._change_sc_flow_control(
        {
            sc_uid
            for sc_uid in mapping[SimulationFlowControl.PAUSE]
            if not sc_uid.endswith("-EVALUATE")
        },
        SimulationFlowControl.CONTINUE,
    )
async def _continue_eval_worker_simulation_controllers(self):
    """Resume the paused evaluation workers for the next evaluation round.

    Inconsistencies are logged but do not stop the resumption: controllers
    still running, or a paused-evaluation count different from
    ``self._worker``.
    """
    mapping = self._get_flow_status_mapping()
    if mapping is None:
        LOG.error(
            "No paused simulation controllers found that could be continued!"
        )
        return
    running = mapping[SimulationFlowControl.CONTINUE]
    if len(running) != 0:
        LOG.error(
            "There are still simulation controllers running: %s!"
            "Starting the evaluation worker will lead to inconsistent states!",
            running,
        )
    paused_eval = {
        sc_uid
        for sc_uid in mapping[SimulationFlowControl.PAUSE]
        if sc_uid.endswith("-EVALUATE")
    }
    if len(paused_eval) != self._worker:
        LOG.error(
            "Not exactly worker-amount many paused evaluation worker found that could be continued, but %s!",
            paused_eval,
        )
    await self._change_sc_flow_control(
        paused_eval,
        SimulationFlowControl.CONTINUE,
    )
async def _maybe_continue_simulation_controllers(
    self, ended_sc_uid: str, process_pid: int
):
    """Resume paused controllers after one ended, then shut the ended one down.

    If no controller other than ``ended_sc_uid`` is still running and the
    remaining paused controllers are all of one kind (all evaluation or
    all training workers), they are continued. Errors are logged. In any
    case ``ended_sc_uid`` is shut down via
    ``_shutdown_simulation_controller`` and
    ``_simulation_controllers_continuation_task`` is cleared.

    Parameters
    ----------
    ended_sc_uid : str
        UID of the simulation controller that ended.
    process_pid : int
        PID of its process, forwarded to ``_shutdown_simulation_controller``.
    """
    try:
        mapping = self._get_flow_status_mapping()
        if mapping is not None:
            others_running = (
                mapping[SimulationFlowControl.CONTINUE] - {ended_sc_uid}
            )
            if len(others_running) == 0:
                paused = mapping[SimulationFlowControl.PAUSE] - {ended_sc_uid}
                eval_flags = [
                    sc_uid.endswith("-EVALUATE") for sc_uid in paused
                ]
                single_kind = all(eval_flags) or not any(eval_flags)
                if single_kind and len(paused) > 0:
                    await self._change_sc_flow_control(
                        paused,
                        SimulationFlowControl.CONTINUE,
                    )
        assert self._simulation_controllers_flow_status is not None
        # CAUTION: DO NOT BLINDLY REMOVE, THIS MAY BE RELEVANT FOR FUTURE USE
        # if (
        #     self._evaluate_every is not None
        #     and len(
        #         set(self._simulation_controllers_flow_status.keys())
        #         - {ended_sc_uid}
        #     )
        #     == 0
        # ):
        #     await self._agent_evaluation()
    except Exception:
        LOG.exception(
            "Maybe-continuing simulation controllers failed after SC %s ended",
            ended_sc_uid,
        )
    finally:
        if ended_sc_uid not in self._simulation_controllers_active:
            LOG.error(
                "The evaluation simulation controller %s is done and should logically be shutdown, "
                "but it is no longer listed as active (only active simulation controllers: %s).",
                ended_sc_uid,
                self._simulation_controllers_active,
            )
        self._shutdown_simulation_controller(ended_sc_uid, process_pid)
        self._simulation_controllers_continuation_task = None
@ESM.on(SimulationControllerTerminationRequest)
async def _handle_simulation_controller_termination_request(
    self, request: SimulationControllerTerminationRequest
):
    """Decide how a simulation controller proceeds after its termination.

    The run governor's termination condition yields the flow control for
    the requesting simulation controller. Depending on the flow, the
    worker either continues, is restarted, or is told to shut down. On
    ``STOP_PHASE`` the stop-phase signal is also set. Along the way this
    handler drives intermittent evaluation:

    1. it starts evaluation workers initially;
    2. it pauses an evaluation worker and resumes training once all
       evaluation workers are paused;
    3. it resumes already existing, paused evaluation workers for
       subsequent evaluation rounds.

    Parameters
    ----------
    request : SimulationControllerTerminationRequest
        Sent by a simulation controller whose termination condition fired
        (presumably at an episode end).

    Returns
    -------
    SimulationControllerTerminationResponse
        Carries the decided ``flow_control`` and the ``restart`` /
        ``complete_shutdown`` flags.
    """
    LOG.debug(
        "%s, workers: %s", request, self._simulation_controllers_active
    )
    assert self.experiment_run is not None
    assert self._termination_condition is not None
    if request.sender not in self._simulation_controllers_active:
        LOG.error(
            "Sender of termination request %s is no longer active, but only: %s",
            request.sender,
            self._simulation_controllers_active,
        )
    # The async lock may technically not be needed, but it makes explicit
    # that reading and writing current_episode_counts must be
    # transactional across concurrently handled requests.
    assert self._episode_counts_lock is not None
    async with self._episode_counts_lock:
        current_episode = self.current_episode_counts[request.sender]
        # phase_flow_control is assumed to be called only once per request
        # by the RunGovernor.
        flow, _ = self._termination_condition.phase_flow_control(
            self, request
        )
        if flow.value > SimulationFlowControl.CONTINUE.value:
            self.current_episode_counts[request.sender] += 1
            LOG.debug(
                "Increased episode counts of %s (dict: %s, flow: %s)",
                request.sender,
                self.current_episode_counts,
                flow,
            )
    # CONTINUE is a valid flow control result. Return early here:
    if flow == SimulationFlowControl.CONTINUE:
        return SimulationControllerTerminationResponse(
            sender_run_governor_id=request.receiver,
            receiver_simulation_controller_id=request.sender,
            experiment_run_instance_id=self.experiment_run.instance_uid,
            experiment_run_id=request.experiment_run_id,
            experiment_run_phase=self.current_phase,
            # Every worker has its own SimulationController
            # and, therefore, gets its respective episode
            episode=current_episode,
            restart=False,
            complete_shutdown=False,
            flow_control=flow,
        )
    if request.mode == Mode.EVALUATE:
        # (2.) Pause the current eval worker; continue the normal workers
        # once all eval workers are paused.
        await self._pause_eval_and_continue_worker_simulation_controllers(
            request.sender
        )
    else:
        await self._maybe_pause_simulation_controller(request.sender)
        if self._should_start_evaluation(request.mode):
            # All training workers are paused now because evaluation
            # starts. The pausing signal set in
            # _maybe_pause_simulation_controller can be turned off.
            self._pausing_simulation_controllers = False
            LOG.info(
                "Starting evaluation %s "
                "in phase %s of experiment run %s",
                None,  # TODO: Get and check the current eval episode
                self.current_phase + 1,
                self.experiment_run.uid,
            )
            if any(
                scuid.endswith("-EVALUATE")
                for scuid in self._simulation_controllers_active
            ):
                # (3.) Reuse the already existing (active) but paused
                # evaluation workers. The flipping method of (2.) cannot
                # be reused, because here all workers are paused, whereas
                # at (2.) the eval worker is still CONTINUE-ing.
                await self._continue_eval_worker_simulation_controllers()
            else:
                # (1.) Start the eval worker processes initially.
                assert self._evaluation_start_lock is not None
                async with self._evaluation_start_lock:
                    await self._start_evaluation()
    if flow.value <= SimulationFlowControl.RESTART.value:
        # Restart/soft-reset of this particular worker:
        LOG.info(
            'Restarting simulation worker "%s" '
            "in phase %d for episode %d "
            'of experiment run "%s"',
            request.sender,
            self.current_phase + 1,
            current_episode + 1,
            self.experiment_run.uid,
        )
        return SimulationControllerTerminationResponse(
            sender_run_governor_id=request.receiver,
            receiver_simulation_controller_id=request.sender,
            experiment_run_instance_id=self.experiment_run.instance_uid,
            experiment_run_id=request.experiment_run_id,
            experiment_run_phase=self.current_phase,
            # Every worker has its own SimulationController and therefore gets its respective episode
            episode=current_episode,
            restart=True,
            complete_shutdown=False,
            flow_control=flow,
        )
    LOG.info(
        'Signalling simulation worker "%s" to shut down '
        'for phase %d in experiment run "%s"',
        request.sender,
        self.current_phase + 1,  # +1 for display purposes only.
        self.experiment_run.uid,
    )
    # The requesting SC itself is told to shut down completely below.
    if flow == SimulationFlowControl.STOP_PHASE:
        # Only flag the end of the phase here. Shutting down the remaining
        # workers and conductors happens elsewhere.
        # NOTE(review): an earlier comment mentioned a future set in the
        # SIGCHLD event handler -- confirm where that happens.
        LOG.info(
            "Stopping phase %d of experiment run %s",
            self.current_phase + 1,  # +1 for display purposes only.
            self.experiment_run.uid,
        )
        self._has_stop_phase_signal = True
    return SimulationControllerTerminationResponse(
        sender_run_governor_id=request.receiver,
        receiver_simulation_controller_id=request.sender,
        experiment_run_instance_id=self.experiment_run.instance_uid,
        experiment_run_id=request.experiment_run_id,
        experiment_run_phase=self.current_phase,
        # Every worker has its own SimulationController and therefore gets its respective episode
        episode=current_episode,
        restart=False,
        complete_shutdown=True,
        flow_control=flow,
    )
@ESM.requests
def _request_simulation_flow_change(
    self,
    simulation_flow: SimulationFlowControl,
    simulation_controllers: List[str],
):
    """Ask the given simulation controllers to change their flow.

    The addressed controller UIDs are remembered as pending so that each
    incoming :class:`SimulationFlowChangeResponse` can be ticked off.

    :param simulation_flow: Flow control the controllers should adopt.
    :param simulation_controllers: UIDs of the addressed controllers.
    :return: One interrupting, non-overwriting
        :class:`SimulationFlowChangeRequest` per controller, handed to the
        ``ESM.requests`` decorator.
    """
    assert self.experiment_run is not None
    run = self.experiment_run
    self._simulation_controllers_pending_response = set(
        simulation_controllers
    )

    def build(receiver: str) -> SimulationFlowChangeRequest:
        return SimulationFlowChangeRequest(
            sender_run_governor_id=self.uid,
            receiver_simulation_controller_id=receiver,
            experiment_run_id=run.uid,
            experiment_run_instance_id=run.instance_uid,
            experiment_run_phase=self.current_phase,
            simulation_flow=simulation_flow,
            interrupt=True,
            overwrite_flow=False,
        )

    return [build(receiver) for receiver in simulation_controllers]
@ESM.on(SimulationFlowChangeResponse)
async def _handle_simulation_flow_change_response(
    self, response: SimulationFlowChangeResponse
):
    """Record the flow a simulation controller reports having adopted.

    The reported flow is stored in ``_simulation_controllers_flow_status``
    and the sender is removed from the set of controllers whose response is
    pending. Once no response is pending any longer, the flow-change future
    is resolved with ``True``.

    Inconsistencies are logged, not raised: a flow identical to the
    previously recorded one, a flow other than CONTINUE/PAUSE, and a sender
    that was not pending.
    """
    assert (
        self._future_simulation_controllers_flow_change is not None
        and not self._future_simulation_controllers_flow_change.done()
    ), "Future for SC flow control change is not present!"
    assert self._simulation_controllers_flow_status is not None

    previous_flow = self._simulation_controllers_flow_status.get(
        response.sender, None
    )
    if response.requested_flow == previous_flow:
        LOG.error(
            "Flow (%s) was not actually changed for simulation controller: %s",
            response.requested_flow,
            response.sender,
        )
    if response.requested_flow not in (
        SimulationFlowControl.CONTINUE,
        SimulationFlowControl.PAUSE,
    ):
        LOG.error(
            "Changed flow for %s is not CONTINUE and not PAUSE, but %s",
            response.sender,
            response.requested_flow,
        )
    self._simulation_controllers_flow_status[response.sender] = (
        response.requested_flow
    )
    LOG.debug(
        # Trailing space: the two sentences were previously glued together.
        "Noticed SimulationFlowChangeResponse (%s). "
        "_simulation_controllers_flow_status after processing: %s",
        response,
        self._simulation_controllers_flow_status,
    )
    if response.sender not in self._simulation_controllers_pending_response:
        LOG.error(
            "%s sent response for %s simulation flow change request, "
            "but is internally not pending for response!",
            response.sender,
            response.requested_flow,
        )
    self._simulation_controllers_pending_response -= {response.sender}
    if not self._simulation_controllers_pending_response:
        self._future_simulation_controllers_flow_change.set_result(True)
def _get_flow_status_mapping(
    self,
) -> Optional[Dict[SimulationFlowControl, Set[str]]]:
    """Group simulation controller UIDs by their recorded flow status.

    :return: A dict with exactly the keys ``PAUSE`` and ``CONTINUE``. Each
        maps to the set of controller UIDs currently in that state.
        Controllers in any other state are left out. Returns ``None``
        (after logging an error) when no flow status mapping exists.
    """
    status_by_controller = self._simulation_controllers_flow_status
    if status_by_controller is None:
        LOG.error("Flow status mapping is not present!")
        return None
    paused: Set[str] = set()
    running: Set[str] = set()
    for controller_uid, flow in status_by_controller.items():
        if flow == SimulationFlowControl.PAUSE:
            paused.add(controller_uid)
        elif flow == SimulationFlowControl.CONTINUE:
            running.add(controller_uid)
    return {
        SimulationFlowControl.PAUSE: paused,
        SimulationFlowControl.CONTINUE: running,
    }
def _check_reset_flow_status_mapping(self):
    """Start a fresh flow status mapping.

    Entries still present in the old mapping are reported as a leak before
    it is replaced with an empty dict. The controllers that should start
    paused are then recomputed via
    ``_should_start_pausing_simulation_controllers()``.
    """
    leftover = self._simulation_controllers_flow_status
    if leftover:
        LOG.error(
            "Current flow status mapping is not empty, but leaking: %s",
            leftover,
        )
    self._simulation_controllers_flow_status = {}
    pausing = self._should_start_pausing_simulation_controllers()
    self._pausing_simulation_controllers = pausing
async def _agent_evaluation(self):
    """Request an evaluation from all agent conductors and await it.

    The whole exchange runs under ``_agent_evaluation_lock``. Phase
    shutdown acquires the same lock before shutting conductors down, so a
    running evaluation is allowed to finish first.

    If there are no agent conductors, nobody could ever answer, so the
    evaluation is skipped. Otherwise this would wait forever while
    holding the lock, and phase shutdown would block with it.
    The evaluation future is always cleared afterwards, also when the
    request fails or the await is cancelled, so no stale future leaks
    into the next evaluation.
    """
    # TODO: Querying the DB based on the config may lead to a race
    # condition, because data in the DB only becomes accessible
    # eventually. (The original note was truncated at this point.)
    assert self._agent_evaluation_lock
    async with self._agent_evaluation_lock:
        if self._future_agent_evaluation is not None:
            LOG.error("Future agent evaluation is already present!")
        if not self._agent_conductors:
            LOG.warning(
                "No agent conductors present, skipping agent evaluation"
            )
            return
        self._future_agent_evaluation = (
            asyncio.get_running_loop().create_future()
        )
        try:
            self._request_agent_evaluation()
            await self._future_agent_evaluation
        finally:
            self._future_agent_evaluation = None
@ESM.requests
def _request_agent_evaluation(self):
    """Ask every agent conductor to evaluate its agents.

    The evaluation episode is derived from the episode counters of the
    evaluation workers (keys ending in ``-EVALUATE``). Diverging counters
    are logged and the largest one is used. All agent conductors are
    registered as pending responders before the requests are built.

    :return: One :class:`AgentEvaluationRequest` per agent conductor
        (empty if there are none), handed to the ``ESM.requests``
        decorator.
    :raises ValueError: If there are agent conductors to address but no
        ``-EVALUATE`` episode counter exists. Previously this surfaced as
        an opaque ``max() arg is an empty sequence``.
    """
    evaluation_episodes = {
        count
        for worker, count in self.current_episode_counts.items()
        if worker.endswith("-EVALUATE")
    }
    if len(evaluation_episodes) != 1:
        LOG.error(
            "Different evaluation episode counters: %s",
            self.current_episode_counts,
        )
    if self._pending_agent_evaluation_responses is not None:
        LOG.error("Pending agent evaluation response is already present!")
    self._pending_agent_evaluation_responses = set(self._agent_conductors)
    if not self._agent_conductors:
        return []
    if not evaluation_episodes:
        raise ValueError(
            "Cannot request agent evaluation: no '-EVALUATE' episode "
            f"counter present in {self.current_episode_counts}"
        )
    assert self.experiment_run is not None
    # The counters hold the number of episodes already computed, so
    # subtract one to get the index of the latest evaluation episode.
    evaluation_episode = max(evaluation_episodes) - 1
    return [
        AgentEvaluationRequest(
            sender_run_governor=self.uid,
            receiver_agent_conductor=agent_conductor,
            experiment_run_uid=self.experiment_run.uid,
            experiment_run_instance_uid=self.experiment_run.instance_uid,
            experiment_run_phase=self.current_phase,
            evaluation_episode=evaluation_episode,
        )
        for agent_conductor in self._agent_conductors
    ]
@ESM.on(AgentEvaluationResponse)
def _handle_agent_evaluation_response(
    self, response: AgentEvaluationResponse
):
    """Tick off one agent conductor's evaluation response.

    When the last pending conductor has answered, the pending set is
    cleared (set to ``None``) and the agent evaluation future is resolved
    with ``True``.
    """
    assert self._future_agent_evaluation is not None
    waiting_for = self._pending_agent_evaluation_responses
    assert waiting_for is not None
    waiting_for.discard(response.sender)
    if waiting_for:
        return
    self._pending_agent_evaluation_responses = None
    self._future_agent_evaluation.set_result(True)
async def _request_phase_shutdown(self):
    """Shut down the simulation controllers and conductors of the phase.

    Steps, in order:

    1. Log an error if no STOP_PHASE signal was seen (the shutdown still
       proceeds).
    2. Return early if a phase shutdown is already in progress
       (``_is_phase_shutting_down``); otherwise mark it as in progress.
    3. Await ``_future_simulation_controller_active`` if it exists.
    4. Send shutdown requests to all active simulation controllers and wait
       for ``_simulation_controllers_down_event``.
    5. While holding ``_agent_evaluation_lock``, send shutdown requests to
       agent and environment conductors.
    """
    if not self._has_stop_phase_signal:
        LOG.error(
            "No STOP_PHASE signal was sent, but requesting to shut phase down"
        )
    # Re-entrancy guard: this coroutine may be scheduled from several
    # places (e.g. _shutdown_simulation_controller and
    # _handle_shutdown_request); only the first invocation proceeds.
    if self._is_phase_shutting_down:
        return
    self._is_phase_shutting_down = True
    if self._future_simulation_controller_active is not None:
        LOG.log(
            palaestrai.logging.ASYNCIO_LOG_LEVEL,
            "%s waits on future_simulation_controller_active",
            self,
        )
        await self._future_simulation_controller_active
    LOG.info("Send requests for shutting down simulation controllers.")
    if len(self._simulation_controllers_active) > 0:
        _ = self._send_simulation_shutdown_requests()
        # The event is set in _shutdown_simulation_controller once no
        # simulation controller is active any longer.
        await self._simulation_controllers_down_event.wait()
    LOG.info("Send requests for shutting down conductors.")
    # Sync with the agent evaluation lock to make sure that the agent evaluation
    # is still performed before shutting down the conductors,
    # if the evaluation is still running when the shutdown is requested.
    async with self._agent_evaluation_lock:
        if len(self._agent_conductors) > 0:
            _ = self._send_agent_conductor_shutdown_requests()
        if len(self._environment_conductors) > 0:
            _ = self._send_environment_conductor_shutdown_requests()
@ESM.requests
def _send_agent_conductor_shutdown_requests(self):
    """Build one :class:`ShutdownRequest` per known agent conductor.

    :return: The requests, handed to the ``ESM.requests`` decorator.
    """
    targets = self._agent_conductors
    LOG.debug("%s requesting shut down of %s", self, targets)

    def build(conductor_uid: str) -> ShutdownRequest:
        return ShutdownRequest(
            sender=self.uid,
            receiver=conductor_uid,
            experiment_run_id=self.experiment_run.uid,
            experiment_run_instance_id=self.experiment_run.instance_uid,
            experiment_run_phase=self.current_phase,
        )

    return [build(conductor_uid) for conductor_uid in targets]
@ESM.requests
def _send_environment_conductor_shutdown_requests(self):
    """Build one :class:`ShutdownRequest` per known environment conductor.

    :return: The requests, handed to the ``ESM.requests`` decorator.
    """
    targets = self._environment_conductors
    LOG.debug("%s requesting shutdown of %s", self, targets)

    def build(conductor_uid: str) -> ShutdownRequest:
        return ShutdownRequest(
            sender=self.uid,
            receiver=conductor_uid,
            experiment_run_id=self.experiment_run.uid,
            experiment_run_instance_id=self.experiment_run.instance_uid,
            experiment_run_phase=self.current_phase,
        )

    return [build(conductor_uid) for conductor_uid in targets]
@ESM.on(ShutdownResponse)
def _handle_shutdown_response(self, response: ShutdownResponse):
    """Track conductor shutdown acknowledgments.

    The sender is removed from the lists of running agent and environment
    conductors. Once no conductor is left, both ``_future_next_phase`` and
    ``_future_shutdown`` are resolved with ``True``.

    Each future is resolved independently and only if it is not done yet.
    A double stop (e.g. when hitting Ctrl-C) may have completed one of them
    already. Previously both ``set_result`` calls shared one
    ``try/except InvalidStateError``. An already-done next-phase future
    therefore skipped resolving the shutdown future.
    """
    sender = response.sender
    self._agent_conductors = [
        acuid for acuid in self._agent_conductors if acuid != sender
    ]
    self._environment_conductors = [
        ecuid for ecuid in self._environment_conductors if ecuid != sender
    ]
    LOG.debug(
        "%s got %s, conductors still up: %s",
        self,
        response,
        self._agent_conductors + self._environment_conductors,
    )
    if self._agent_conductors or self._environment_conductors:
        return
    assert self._future_next_phase is not None
    assert self._future_shutdown is not None
    for future in (self._future_next_phase, self._future_shutdown):
        if not future.done():
            future.set_result(True)
@ESM.requests
def _send_simulation_shutdown_requests(self):
    """Build a :class:`SimulationShutdownRequest` per active controller.

    :return: The requests, handed to the ``ESM.requests`` decorator.
    """
    active = self._simulation_controllers_active
    LOG.debug(
        "%s sending SimulationShutdownRequest(s) to %s", self, active
    )

    def build(controller_uid: str) -> SimulationShutdownRequest:
        return SimulationShutdownRequest(
            sender=self.uid,
            receiver=controller_uid,
            experiment_run_id=self.experiment_run.uid,
            experiment_run_instance_id=self.experiment_run.instance_uid,
            experiment_run_phase=self.current_phase,
        )

    return [build(controller_uid) for controller_uid in active]
def _shutdown_simulation_controller(self, sc_uid: str, process_pid: int):
    """Deregister a finished simulation controller and its process.

    The controller is dropped from the active set and from the flow status
    mapping. Its process (``process_pid``) is forgotten as well. When no
    controller remains active and ``_evaluation_start_lock`` is not held,
    ``_simulation_controllers_down_event`` is set and a phase shutdown task
    is scheduled.
    """
    LOG.info("Shutting down simulation controller %s", sc_uid)
    self._simulation_controllers_active -= {sc_uid}
    flow_status = self._simulation_controllers_flow_status
    if flow_status is not None:
        flow_status.pop(sc_uid, None)
    evaluation_worker_left = any(
        controller_uid.endswith("-EVALUATE")
        for controller_uid in self._simulation_controllers_active
    )
    if not evaluation_worker_left:
        LOG.debug("No evaluation workers are active")
    self._remove_simulation_controller_process(process_pid)
    # The evaluation start lock must be checked because evaluations may
    # still have to be performed after the normal SCs are done.
    assert self._evaluation_start_lock is not None
    still_busy = bool(
        self._simulation_controllers_active
    ) or self._evaluation_start_lock.locked()
    if still_busy:
        return
    LOG.info("No simulation controllers are active any longer")
    assert self._simulation_controllers_down_event is not None
    self._simulation_controllers_down_event.set()
    shutdown_coroutine = self._request_phase_shutdown()
    self._shutdown_conductors_task = asyncio.create_task(shutdown_coroutine)
def _remove_simulation_controller_process(self, process_pid: int):
    """Forget the child process whose PID is ``process_pid``.

    If no process named ``SimulationController-…`` remains afterwards, the
    set of active simulation controllers is cleared too. Controllers still
    listed as active at that point have no backing process any longer,
    which is logged as an error.
    """
    self._processes = [p for p in self._processes if p.pid != process_pid]
    # any() short-circuits; no need to materialize a list just to test
    # whether it is empty.
    if any(
        p.name.startswith("SimulationController-") for p in self._processes
    ):
        return
    if self._simulation_controllers_active:
        LOG.error(
            "All simulation controller processes end, "
            "but some are still logically active (%s)",
            self._simulation_controllers_active,
        )
        self._simulation_controllers_active = set()
@ESM.on(SimulationShutdownResponse)
def _handle_simulation_shutdown_response(
    self, response: SimulationShutdownResponse
):
    """Acknowledge a simulation controller's shutdown response.

    The response is only logged at debug level; no state is changed here.
    """
    LOG.debug("Got %s", response)
@ESM.on(ExperimentRunShutdownRequest)
async def _handle_shutdown_request(
    self, request: ExperimentRunShutdownRequest
):
    """Shut the experiment run down on request of the executor.

    Awaits ``_request_phase_shutdown()``, cancels the phase launcher task if
    one exists, stops this run governor and answers with a successful
    :class:`ExperimentRunShutdownResponse`.
    """
    run_label = self.experiment_run.uid if self.experiment_run else "(no run)"
    LOG.info('Shutdown of experiment run "%s" requested', run_label)
    await self._request_phase_shutdown()
    launcher = self._phase_launcher_task
    if launcher is not None:
        launcher.cancel()
    self.stop()  # type: ignore[attr-defined]
    response = ExperimentRunShutdownResponse(
        sender_run_governor_id=self.uid,
        receiver_executor_id=request.sender,
        experiment_run_id=request.experiment_run_id,
        successful=True,
        error=None,
    )
    return response
@ESM.on(ErrorIndicator)
def _raise_error_indicator(self, error: ErrorIndicator):
    """Enter the error state and re-raise an error reported by a peer.

    The exception carried by the indicator is raised if present. Otherwise
    a :class:`RuntimeError` with the reported message is raised.
    """
    self._state = BasicState.ERROR
    LOG.critical(
        "%s received error from %s: %s",
        self,
        error.sender,
        error.error_message,
    )
    # TODO: Shutdown everything
    carried = error.exception
    if carried is None:
        raise RuntimeError(error.error_message)
    raise carried
@ESM.on(signal.SIGCHLD)
def _handle_child(
    self, process: Union[aiomultiprocess.Process, multiprocessing.Process]
):
    """React to the end of a child process.

    A non-zero exit code puts the run governor into the error state and
    stops it with a :class:`RuntimeError`. Independently of the exit code,
    if the process is a simulation controller (name prefix
    ``SimulationController-``), a task is scheduled that calls
    ``_maybe_continue_simulation_controllers`` with the controller's name
    and PID.
    """
    LOG.debug("%s saw process %s end.", self, process.name)
    if process.exitcode != 0:
        self._state = BasicState.ERROR
        LOG.critical(
            # Previously "rc %s)." with a stray closing parenthesis.
            "Subprocess %s exited prematurely with rc %s. "
            "Cannot continue simulation.",
            process.name,
            process.exitcode,
        )
        self.stop(  # type: ignore[attr-defined]
            RuntimeError(
                f"Subprocess {process.name} ended prematurely "
                f"with rc {process.exitcode}"
            )
        )
    prefix = "SimulationController-"
    if not process.name.startswith(prefix):
        return
    simulation_controller_name = process.name.removeprefix(prefix)
    LOG.debug(
        "SimulationController %s has signalled that it has finished.",
        simulation_controller_name,
    )
    assert isinstance(process.pid, int)
    # NOTE(review): each call overwrites the previous task reference, and
    # asyncio keeps only weak references to tasks. Confirm that earlier
    # continuation tasks are kept alive elsewhere while they still run.
    self._simulation_controllers_continuation_task = asyncio.create_task(
        self._maybe_continue_simulation_controllers(
            simulation_controller_name, process.pid
        )
    )
def __str__(self):
    """Return a compact, human-readable summary of this run governor."""
    run = self.experiment_run
    run_uid = "(None)" if run is None else run.uid
    parts = (
        f"uid={self.uid}",
        f"state={self._state.name}",
        f"experiment_run_uid={run_uid}",
        f"phase={self.current_phase}",
    )
    return "RunGovernor(" + ", ".join(parts) + ")"