# Source code for palaestrai.experiment.run_governor

from __future__ import annotations

from asyncio import InvalidStateError

import copy
from typing import TYPE_CHECKING, List, Dict, Set, Optional, Union, Any

import uuid
import signal
import logging
import asyncio

from aiomultiprocess import Process

import palaestrai.logging
from palaestrai.core.protocol import (
    ExperimentRunStartRequest,
    ExperimentRunStartResponse,
    SimulationStartRequest,
    SimulationStartResponse,
    SimulationControllerTerminationRequest,
    SimulationControllerTerminationResponse,
    ErrorIndicator,
    SimulationShutdownRequest,
    SimulationShutdownResponse,
    ExperimentRunShutdownRequest,
    ExperimentRunShutdownResponse,
    ShutdownRequest,
    ShutdownResponse,
)
from palaestrai.core.protocol.agent_evaluation_req import (
    AgentEvaluationRequest,
)
from palaestrai.core.protocol.agent_evaluation_rsp import (
    AgentEvaluationResponse,
)
from palaestrai.core.protocol.simulation_flow_change_req import (
    SimulationFlowChangeRequest,
)
from palaestrai.core.protocol.simulation_flow_change_rsp import (
    SimulationFlowChangeResponse,
)
from palaestrai.logging.utils import LoggingDefaultDict
from palaestrai.types import SimulationFlowControl, Mode
from palaestrai.core import EventStateMachine as ESM
from palaestrai.core import BasicState, RuntimeConfig
from palaestrai.util import spawn_wrapper, MetadataLogFilter
from palaestrai.util.exception import ExperimentAlreadyRunningError

if TYPE_CHECKING:
    import aiomultiprocess
    import multiprocessing
    from palaestrai.experiment import TerminationCondition
    from palaestrai.experiment.experiment_run import ExperimentRun
    from palaestrai.simulation import SimulationController

LOG = logging.getLogger(__name__)


[docs] @ESM.monitor(is_mdp_worker=True) class RunGovernor: """ This class implements the Run-Governor. Upon receiving requests from the executor, a RunGovernor instance handles a single experiment run by starting it, initialize the simulation controllers, the environment and the agent conductors, and, finally, shutting the experiment run down. The RunGovernor is implemented as state machine and this class provides the context for the distinct state classes. A freshly initialized RunGovernor waits in the state PRISTINE until the run method is called by the executor. See the distinct state classes for more information. Parameters ---------- uid : str The universally unique ID that identifies this run governor Attributes ---------- uid: str The UUID of this RunGovernor termination_condition: :class:`.TerminationCondition` A reference to the TerminationCondition instance. run_broker: :class:`.MajorDomoBroker` The broker for the communication with the simulation controller, the agents, and the environments. experiment_run_id: str The UUID of the current experiment run. tasks: List[aiomultiprocess.Process] A list of tasks the RunGovernor has started and that it has to shut down in the end. worker: :class:`.MajorDomoWorker` The major domo worker for handling incoming requests client: :class:`.MajorDomoClient` The major domo client for sending requests to other workers. shutdown: bool The major kill switch of the RunGovernor. Setting this to false will stop the RunGovernor after the current state. state: :class:`.RunGovernorState` Holds the current state instance. The first state is PRISTINE. errors: List[Exception] A list that is used to collect errors raised in the states. 
""" def __init__(self, uid: Optional[str] = None): self.uid = uid if uid else "RunGovernor-%s" % str(uuid.uuid4()) self._state = BasicState.PRISTINE # Experiment control: self.current_phase: int = 0 self.current_episode_counts: Dict[str, int] = {} self.experiment_run: Optional[ExperimentRun] = None self._termination_condition: Optional[TerminationCondition] = None self._evaluate_every: Optional[int] = None self.max_eval_episodes: Optional[int] = None self._worker: Optional[int] = None self._has_stop_phase_signal: Optional[bool] = None self._is_phase_shutting_down: Optional[bool] = None self._simulation_controllers_flow_status: Optional[ Dict[str, SimulationFlowControl] ] = None self._pausing_simulation_controllers: Optional[bool] = None # Receiver information: self._simulation_controllers: List[str] = [] self._environment_conductors: List[str] = [] self._agent_conductors: List[str] = [] # Receiver synchronization: self._simulation_controllers_active: Set[str] = set() self._simulation_controllers_pending_response: Set[str] = set() self._future_simulation_controller_active: Optional[asyncio.Future] = ( None ) self._simulation_controllers_down_event: Optional[asyncio.Event] = None self._future_simulation_controllers_flow_change: Optional[ asyncio.Future ] = None self._future_shutdown: Optional[asyncio.Future] = None self._future_next_phase: Optional[asyncio.Future] = None self._phase_launcher_task: Optional[asyncio.Task] = None self._shutdown_conductors_task: Optional[asyncio.Task] = None # Evaluation self._evaluation_start_lock: Optional[asyncio.Lock] = None self._eval_flow_control_change_lock: Optional[asyncio.Lock] = None self._episode_counts_lock: Optional[asyncio.Lock] = None self._simulation_controllers_continuation_task: Optional[ asyncio.Task ] = None self._agent_evaluation_lock: Optional[asyncio.Lock] = None self._future_agent_evaluation: Optional[asyncio.Future] = None self._pending_agent_evaluation_responses: Optional[Set[str]] = None # Subprocess 
management: self._processes: List[aiomultiprocess.Process] = []
[docs] def setup(self): self._state = BasicState.PRISTINE self.mdp_service = self.uid # type: ignore[attr-defined] LOG.debug("%s ready.", self) self._state = BasicState.INITIALIZING
@ESM.on(ExperimentRunStartRequest) async def _handle_experiment_run_start_request( self, request: ExperimentRunStartRequest ): if self.experiment_run is not None: LOG.warning( '%s got request to start experiment run "%s", ' 'but experiment run "%s" is already running. ' "Reporting error and continueing. Elect somebody " "else instead!", self, self.experiment_run.uid, request.experiment_run.uid, ) self.stop() # type: ignore[attr-defined] return ExperimentRunStartResponse( sender_run_governor_id=request.receiver, receiver_executor_id=request.sender, experiment_run_id=request.experiment_run_id, experiment_run=request.experiment_run, successful=False, error=ExperimentAlreadyRunningError(self.experiment_run), ) self.current_phase = 0 self.experiment_run = request.experiment_run LOG.debug("%s setting up %s", self, self.experiment_run) try: await self._setup_run() LOG.addFilter( MetadataLogFilter(experiment_run_uid=self.experiment_run.uid) ) log_filters = RuntimeConfig().logging.get("filters", {}) log_filters["metadata"] = { "()": "palaestrai.util.metadata_logfilter.MetadataLogFilter", "experiment_run_uid": self.experiment_run.uid, } RuntimeConfig().logging["filters"] = log_filters for k in RuntimeConfig().logging["loggers"].keys(): filters = ( RuntimeConfig().logging["loggers"][k].get("filters", []) ) filters.append("metadata") RuntimeConfig().logging["loggers"][k]["filters"] = filters self._state = BasicState.INITIALIZED except Exception as e: LOG.exception("%s encountered exception during setup: %s", self, e) self.stop(e) # type: ignore[attr-defined] return ExperimentRunStartResponse( sender_run_governor_id=request.receiver, receiver_executor_id=request.sender, experiment_run_id=request.experiment_run_id, experiment_run=request.experiment_run, successful=False, error=e, ) self._phase_launcher_task = asyncio.create_task(self._run_all_phases()) return ExperimentRunStartResponse( sender_run_governor_id=request.receiver, receiver_executor_id=request.sender, 
experiment_run_id=request.experiment_run_id, experiment_run=request.experiment_run, successful=True, error=None, ) async def _run_all_phases(self): try: self._state = BasicState.RUNNING n_phases = self.experiment_run.num_phases self._simulation_controllers_down_event = asyncio.Event() while self.current_phase < n_phases: self._future_next_phase = ( asyncio.get_running_loop().create_future() ) self._future_shutdown = ( asyncio.get_running_loop().create_future() ) self._simulation_controllers_down_event.clear() LOG.debug( 'Setting up phase %d/%d in experiment run "%s"...', self.current_phase + 1, # +1 for display purposes n_phases, self.experiment_run.uid, ) await self._setup_phase() LOG.debug( "%s sending start request(s) to simulation controllers...", self, ) ssrq = self._send_simulation_start_requests() LOG.debug("%s sent simulation start requests: %s", self, ssrq) LOG.debug("%s waiting for phase to end", self) await self._future_next_phase # Wait for all processes to really end: LOG.debug( "%s waiting for processes to end: %s", self, self._processes, ) for p in self._processes: await p.join() self.current_phase += 1 LOG.info( 'Executed all phases in run "%s", shutting down.', self.experiment_run.uid, ) try: await asyncio.wait_for(self._future_shutdown, timeout=15) LOG.info("All conductors shutdown.") except TimeoutError: LOG.error( "%s timed out while waiting for all processes to end. 
" "Conductors still active: %s", self, self._agent_conductors + self._environment_conductors, ) await self._request_phase_shutdown() except Exception: LOG.exception("Running all phases broke") finally: self.stop() # type: ignore[attr-defined] async def _setup_run(self): self.experiment_run.setup() assert ( self.experiment_run.run_governor_termination_condition is not None ) self._termination_condition = ( self.experiment_run.run_governor_termination_condition ) async def _setup_phase(self): self.current_episode_counts = LoggingDefaultDict( int ) # Default value: 0 self._evaluate_every = self.experiment_run.get_evaluate_every( self.current_phase ) self._worker = self.experiment_run.get_worker(self.current_phase) self.max_eval_episodes = ( int( ( self.experiment_run.get_episodes(self.current_phase) * self._worker ) / self._evaluate_every ) if self._evaluate_every else -1 ) if self._evaluate_every is not None: if self._worker > self._evaluate_every: LOG.warning( "For deterministic evaluation the evaluation frequency " "(evaluate_every) must be >= worker, otherwise the evaluation " "is performed after at least 'evaluate_every' amount of " "finished episodes, but the other workers will most likely already " "advanced their episode execution in parallel." 
) self._eval_flow_control_change_lock = asyncio.Lock() self._episode_counts_lock = asyncio.Lock() self._evaluation_start_lock = asyncio.Lock() self._has_stop_phase_signal = False self._is_phase_shutting_down = False self._agent_evaluation_lock = asyncio.Lock() if self._future_agent_evaluation is not None: LOG.error("Future agent evaluation leaks across phase!") if self._pending_agent_evaluation_responses is not None: LOG.error("Pending agent evaluation responses leak across phase!") LOG.debug("Reset flow status mapping in _setup_phase") self._check_reset_flow_status_mapping() ps = await asyncio.gather( self._start_environment_conductors(), self._start_agent_conductors(), self._start_simulation_controllers(), ) # Returns a nested list self._processes = [p for pl in ps for p in pl] # Flatten "ps" list LOG.debug( "%s has processes this phase: %s", self, [p.name for p in self._processes], ) @ESM.spawns async def _start_simulation_controllers( self, optional_sc_list: Optional[List[SimulationController]] = None, ): assert self.experiment_run is not None simulation_controllers = optional_sc_list if simulation_controllers is None: simulation_controllers = list( self.experiment_run.simulation_controllers( self.current_phase ).values() ) assert simulation_controllers is not None # assert all(isinstance(sc, SimulationController) for sc in simulation_controllers) sc_uids = [sc.uid for sc in simulation_controllers] if optional_sc_list is None: self._simulation_controllers = sc_uids else: self._simulation_controllers += sc_uids LOG.debug( "%s launching simulation controller processes: %s (uids: %s)", self, list(simulation_controllers), sc_uids, ) sc_processes = [ Process( name=f"SimulationController-{sc.uid}", target=spawn_wrapper, args=( RunGovernor._get_short_sc_proctitle(sc.uid), RuntimeConfig().to_dict(), sc.run, # type: ignore[attr-defined] ), ) for sc in simulation_controllers ] for p in sc_processes: p.start() return sc_processes @staticmethod def 
_get_short_sc_proctitle(uid): s = "SimulationController-" if "-EVALUATE" in uid: s += "".join(uid.split("-")[:-1])[-6:] + "-EVALUATE" else: s += uid[-6:] return s @ESM.spawns async def _start_environment_conductors(self): environment_conductors = ( self.experiment_run.environment_conductors(self.current_phase) ).values() LOG.debug( "%s launching environment conductor processes: %s", self, list(environment_conductors), ) ec_processes = [ Process( name=f"EnvironmentConductor-{ec.uid}", target=spawn_wrapper, args=( f"EnvironmentConductor-{ec.uid[-6:]}", RuntimeConfig().to_dict(), ec.run, # type: ignore[attr-defined] ), ) for ec in environment_conductors ] self._environment_conductors = [ ec.uid for ec in environment_conductors ] for p in ec_processes: p.start() return ec_processes @ESM.spawns async def _start_agent_conductors(self): agent_conductors = ( self.experiment_run.agent_conductors(self.current_phase) ).values() LOG.debug( "%s launching agent conductor processes: %s", self, list(agent_conductors), ) ac_processes = [ Process( name=f"AgentConductor-{ac.uid}", target=spawn_wrapper, args=( f"AgentConductor-{ac.uid[-6:]}", RuntimeConfig().to_dict(), ac.run, # type: ignore[attr-defined] ), ) for ac in agent_conductors ] self._agent_conductors = [ac.uid for ac in agent_conductors] for p in ac_processes: p.start() return ac_processes def _create_evaluation_sc(self) -> Optional[SimulationController]: sc = None try: assert self.experiment_run is not None orig_simulation_controllers = list( ( self.experiment_run.simulation_controllers( self.current_phase ) ).values() ) if len(orig_simulation_controllers) == 0: LOG.error( "Current phase %s must have at least one simulation controller, but none is present", self.current_phase, ) orig_sc: SimulationController = orig_simulation_controllers[0] agent_configurations = copy.deepcopy(orig_sc._agent_configurations) for agent_conf in agent_configurations.values(): # Load can be removed and there is no need for explicit load of 
brain, because the Muscle initially # gets the model from the currently operating brain. agent_conf.pop("load", None) # Remove replay as there is no need for in evaluation agent_conf.pop("replay", None) sc = orig_sc.__class__( agent_conductor_ids=orig_sc._agent_conductor_ids, environment_conductor_ids=orig_sc._environment_conductor_ids, agents=agent_configurations, mode=Mode.EVALUATE, termination_conditions=[ # Here, the MaxEpisodeTermCond is not needed, because this TermCond # is only used in the SC. The max episodes (of 1) for the eval worker # is handled by the method # _handle_simulation_controller_termination_request { "name": "palaestrai.experiment:EnvironmentTerminationCondition", "params": {}, } ], ) sc.uid += "-EVALUATE" except Exception: LOG.exception( "%s could not create and start a new simulation controller for evaluation", self, ) return sc async def _start_evaluation(self): scuid_dict = self._get_flow_status_mapping() if scuid_dict is not None and any( scuid.endswith("-EVALUATE") for scuid in scuid_dict[SimulationFlowControl.PAUSE] ): LOG.error( "There are already evaluation workers active, will not start further evaluation!" ) return LOG.info( "Starting evaluation after %d episodes " 'in phase %d of experiment run "%s".', sum(self.current_episode_counts.values()), self.current_phase + 1, self.experiment_run.uid, ) if self._future_simulation_controller_active is None: self._future_simulation_controller_active = ( asyncio.get_running_loop().create_future() ) try: simulation_controllers = [ self._create_evaluation_sc() for _ in range(self._worker) ] if any(sc is None for sc in simulation_controllers): LOG.error( "There were None evaluation simulation controllers created: %s", simulation_controllers, ) # The evaluation sc task does not have to be awaited for. 
If the simulation ends it is signalled by # the SimulationControllerTerminationRequest which is reacted on self._processes += await self._start_simulation_controllers( simulation_controllers ) _ = self._send_simulation_start_requests( simulation_controllers=[ sc.uid for sc in simulation_controllers ], phase_config={ "episodes": self.max_eval_episodes, "worker": self._worker, "mode": Mode.EVALUATE.name.lower(), }, ) await self._future_simulation_controller_active if ( self._future_simulation_controller_active.result() != "EVALUATE" ): LOG.error( "Continued not because all evaluation workers are active", ) self._future_simulation_controller_active = None except Exception as e: LOG.error( "%s could not create and start a new simulation controller for evaluation", self, ) raise e @ESM.requests def _send_simulation_start_requests( self, simulation_controllers: Optional[List[str]] = None, phase_config: Optional[Dict[str, Any]] = None, ): assert self.experiment_run is not None if simulation_controllers is None: simulation_controllers = self._simulation_controllers LOG.debug( "Send start requests to simulation controllers: %s", simulation_controllers, ) if phase_config is None: phase_config = self.experiment_run.phase_configuration( self.current_phase ) if not any( [sc_uid.endswith("-EVALUATE") for sc_uid in simulation_controllers] ): self._simulation_controllers_active = set() return [ SimulationStartRequest( sender_run_governor_id=self.uid, receiver_simulation_controller_id=sc_uid, experiment_run_id=self.experiment_run.uid, experiment_run_instance_id=self.experiment_run.instance_uid, experiment_run_phase=self.current_phase, experiment_run_phase_id=self.experiment_run.get_phase_name( self.current_phase, ), experiment_run_phase_configuration=phase_config, # Every worker has its own SimulationController and therefore gets its respective episode # For the continued training phase, the initial episode may already be advanced # when using an intermittent evaluation phase. 
# The _get_and_increase_eval_episode method is locked by the outer # _evaluation_start_lock around the start_evaluation method. episode=0, ) for sc_uid in simulation_controllers ] @ESM.on(SimulationStartResponse) def _handle_simulation_start_response( self, response: SimulationStartResponse ): LOG.debug( "%s got simulation start response from %s", self, response.sender ) assert self._simulation_controllers_flow_status is not None assert self._worker is not None self._simulation_controllers_active |= {response.sender} self._simulation_controllers_flow_status[response.sender] = ( SimulationFlowControl.CONTINUE ) if ( self._future_simulation_controller_active is not None and len( { sc_uid for sc_uid in self._simulation_controllers_active if sc_uid.endswith("-EVALUATE") } ) >= self._worker ): self._future_simulation_controller_active.set_result("EVALUATE") def _calc_worker_pause_offset(self): episodes_elapsed = sum( [ v for k, v in self.current_episode_counts.items() if not k.endswith("-EVALUATE") ] ) return ( self._evaluate_every is not None and (episodes_elapsed + min(self._worker, self._evaluate_every)) % self._evaluate_every ) def _has_more_eval_episodes(self): episodes_elapsed = { v for k, v in self.current_episode_counts.items() if k.endswith("-EVALUATE") } episodes_elapsed = ( max(episodes_elapsed) if len(episodes_elapsed) > 0 else 0 ) return episodes_elapsed < self.max_eval_episodes def _should_start_pausing_simulation_controllers(self): LOG.debug( "Relevant vars for determining to pause SCs: " "evaluate_every: %s " "_has_stop_phase_signal: %s " "worker: %s " "current_episode_counts: %s", self._evaluate_every, self._has_stop_phase_signal, self._worker, self.current_episode_counts, ) scuid_dict = self._get_flow_status_mapping() assert scuid_dict is not None all_eval_paused = not any( scuid for scuid in scuid_dict[SimulationFlowControl.CONTINUE] if scuid.endswith("-EVALUATE") ) return ( self._evaluate_every is not None and not self._has_stop_phase_signal and 
all_eval_paused and self._has_more_eval_episodes() and self._calc_worker_pause_offset() == 0 ) def _should_pause_all_workers(self): # Pause all workers that are "leftover" when worker > evaluate_every # In this case they do not get paused "on the go" as the evaluation should # be performed "in the middle" of a parallel episode of some workers return ( self._evaluate_every is not None and not self._has_stop_phase_signal and self._worker > self._evaluate_every and (self._evaluate_every - self._calc_worker_pause_offset()) == 1 ) def _compute_pausing_simulation_controller_signal(self): if self._pausing_simulation_controllers is None: LOG.error("Pausing simulation controller flag not set!") if ( not self._pausing_simulation_controllers and self._should_start_pausing_simulation_controllers() ): LOG.debug( "Starting to pause simulation controllers (current status: %s) (with the next worker) (%s)", self._simulation_controllers_flow_status, self, ) self._pausing_simulation_controllers = True # The current episode end of the worker only indicates # to pause all following workers return True return False async def _maybe_pause_simulation_controller(self, sc_uid: str): if self._compute_pausing_simulation_controller_signal(): return elif self._pausing_simulation_controllers: if sc_uid.endswith("-EVALUATE"): LOG.error( "The evaluation worker should not be paused by the " "_maybe_pause_simulation_controller method! " "Inconsistencies may occur later on!" 
) if self._should_pause_all_workers(): scuid_dict = self._get_flow_status_mapping() assert scuid_dict is not None to_be_paused_scuids = scuid_dict[ SimulationFlowControl.CONTINUE ] if sc_uid not in to_be_paused_scuids: LOG.error( "Should pause all left over workers on termination requestion of " "simulation controller %s, but itself is computed not to be paused", sc_uid, ) else: to_be_paused_scuids = {sc_uid} await self._change_sc_flow_control( to_be_paused_scuids, SimulationFlowControl.PAUSE ) def _should_start_evaluation(self, mode: Mode): assert self._simulation_controllers_flow_status is not None all_workers_paused = all( status == SimulationFlowControl.PAUSE for scuid, status in self._simulation_controllers_flow_status.items() if not scuid.endswith("-EVALUATE") ) return ( not self._has_stop_phase_signal and not self._is_phase_shutting_down and mode != Mode.EVALUATE and all_workers_paused ) async def _change_sc_flow_control( self, scuids: Set[str], flow: SimulationFlowControl ): assert self._eval_flow_control_change_lock is not None assert self.experiment_run is not None LOG.debug( f"Change flow of simulation controllers (%s) to %s " f'in phase %d of experiment run "%s".', scuids, flow, self.current_phase + 1, self.experiment_run.uid, ) async with self._eval_flow_control_change_lock: if self._future_simulation_controllers_flow_change is not None: LOG.error( "%s leaks future (%s) for flow change from other changing request", self, self._future_simulation_controllers_flow_change, ) flow_status_mapping = self._get_flow_status_mapping() assert flow_status_mapping is not None continued_scuids = flow_status_mapping[ SimulationFlowControl.CONTINUE ] continued_scuids = continued_scuids.union(scuids) eval_scuids, normal_scuids = { scuid for scuid in continued_scuids if scuid.endswith("-EVALUATE") }, { scuid for scuid in continued_scuids if not scuid.endswith("-EVALUATE") } if len(eval_scuids) > 0 and len(normal_scuids) > 0: LOG.warning( "%s normal (%s) and evaluation 
(%s) worker will be run (continued) in parallel, but they usually should not.", self, normal_scuids, eval_scuids, ) self._future_simulation_controllers_flow_change = ( asyncio.get_running_loop().create_future() ) _ = self._request_simulation_flow_change(flow, list(scuids)) await self._future_simulation_controllers_flow_change self._future_simulation_controllers_flow_change = None async def _pause_eval_and_continue_worker_simulation_controllers( self, eval_sc_uid: str ): if not eval_sc_uid.endswith("-EVALUATE"): LOG.error( "Calling _pause_eval_and_continue_worker_simulation_controllers not with an evaluation worker!" ) scuid_dict = self._get_flow_status_mapping() if scuid_dict is None: LOG.error( "No paused simulation controllers found that could be continued!" ) return if eval_sc_uid not in scuid_dict[SimulationFlowControl.CONTINUE]: LOG.error( "Evaluation worker with uid %s should be paused, but it is not CONTINUEd", eval_sc_uid, ) if ( len(scuid_dict[SimulationFlowControl.CONTINUE]) == 0 or len(scuid_dict[SimulationFlowControl.PAUSE]) == 0 ): LOG.error( "No actual flip can be performed, " "because active (%s) or " "paused simulations controller list (%s) " "is empty.", scuid_dict[SimulationFlowControl.CONTINUE], scuid_dict[SimulationFlowControl.PAUSE], ) all_evaluation_workers_paused = ( len(scuid_dict[SimulationFlowControl.CONTINUE] - {eval_sc_uid}) == 0 ) # CAUTION: DO NOT BLINDLY REMOVE, THIS MAY BE RELEVANT FOR FUTURE USE # if all_evaluation_workers_paused: # await self._agent_evaluation() if not self._has_more_eval_episodes(): LOG.info( "No more evaluation episodes left, going to continue evaluation worker " "to shut them down." 
) return await self._change_sc_flow_control( {eval_sc_uid}, SimulationFlowControl.PAUSE, ) if all_evaluation_workers_paused: self._compute_pausing_simulation_controller_signal() await self._change_sc_flow_control( { scuid for scuid in scuid_dict[SimulationFlowControl.PAUSE] if not scuid.endswith("-EVALUATE") }, SimulationFlowControl.CONTINUE, ) async def _continue_eval_worker_simulation_controllers(self): scuid_dict = self._get_flow_status_mapping() if scuid_dict is None: LOG.error( "No paused simulation controllers found that could be continued!" ) return if len(scuid_dict[SimulationFlowControl.CONTINUE]) != 0: LOG.error( "There are still simulation controllers running: %s!" "Starting the evaluation worker will lead to inconsistent states!", scuid_dict[SimulationFlowControl.CONTINUE], ) eval_paused_scuids = { scuid for scuid in scuid_dict[SimulationFlowControl.PAUSE] if scuid.endswith("-EVALUATE") } if len(eval_paused_scuids) != self._worker: LOG.error( "Not exactly worker-amount many paused evaluation worker found that could be continued, but %s!", eval_paused_scuids, ) await self._change_sc_flow_control( eval_paused_scuids, SimulationFlowControl.CONTINUE, ) async def _maybe_continue_simulation_controllers( self, ended_sc_uid: str, process_pid: int ): try: scuid_dict = self._get_flow_status_mapping() if scuid_dict is not None: if ( len( scuid_dict[SimulationFlowControl.CONTINUE] - {ended_sc_uid} ) == 0 ): paused_sc_uids = scuid_dict[ SimulationFlowControl.PAUSE ] - {ended_sc_uid} if ( all( scuid.endswith("-EVALUATE") for scuid in paused_sc_uids ) or not any( scuid.endswith("-EVALUATE") for scuid in paused_sc_uids ) ) and len(paused_sc_uids) > 0: await self._change_sc_flow_control( paused_sc_uids, SimulationFlowControl.CONTINUE, ) assert self._simulation_controllers_flow_status is not None # CAUTION: DO NOT BLINDLY REMOVE, THIS MAY BE RELEVANT FOR FUTURE USE # if ( # self._evaluate_every is not None # and len( # set(self._simulation_controllers_flow_status.keys()) 
# - {ended_sc_uid} # ) # == 0 # ): # await self._agent_evaluation() except Exception: LOG.exception( "Maybe-continuing simulation controllers failed after SC %s ended", ended_sc_uid, ) finally: if ended_sc_uid not in self._simulation_controllers_active: LOG.error( "The evaluation simulation controller %s is done and should logically be shutdown, " "but it is no longer listed as active (only active simulation controllers: %s).", ended_sc_uid, self._simulation_controllers_active, ) self._shutdown_simulation_controller(ended_sc_uid, process_pid) self._simulation_controllers_continuation_task = None @ESM.on(SimulationControllerTerminationRequest) async def _handle_simulation_controller_termination_request( self, request: SimulationControllerTerminationRequest ): LOG.debug( "%s, workers: %s", request, self._simulation_controllers_active ) assert self.experiment_run is not None assert self._termination_condition is not None if request.sender not in self._simulation_controllers_active: LOG.error( "Sender of termination request %s is no longer active, but only: %s", request.sender, self._simulation_controllers_active, ) # The async lock may technically not be needed, # but should indicate that the reading and writing # from and to the current_episode_counts have to be # transactional across the request handling tasks assert self._episode_counts_lock is not None async with self._episode_counts_lock: current_episode = self.current_episode_counts[request.sender] # The phase_flow_control is assumed to be called only once by the RunGovernor flow, _ = self._termination_condition.phase_flow_control( self, request ) if flow.value > SimulationFlowControl.CONTINUE.value: self.current_episode_counts[request.sender] += 1 LOG.debug( "Increased episode counts of %s (dict: %s, flow: %s)", request.sender, self.current_episode_counts, flow, ) # CONTINUE is a valid flow control result. 
Return early here: if flow == SimulationFlowControl.CONTINUE: return SimulationControllerTerminationResponse( sender_run_governor_id=request.receiver, receiver_simulation_controller_id=request.sender, experiment_run_instance_id=self.experiment_run.instance_uid, experiment_run_id=request.experiment_run_id, experiment_run_phase=self.current_phase, # Every worker has its own SimulationController # and, therefore, gets its respective episode episode=current_episode, restart=False, complete_shutdown=False, flow_control=flow, ) if request.mode == Mode.EVALUATE: # (2.) Pause current eval worker and continue normal workers when all eval workers are paused await self._pause_eval_and_continue_worker_simulation_controllers( request.sender ) else: await self._maybe_pause_simulation_controller(request.sender) if self._should_start_evaluation(request.mode): # Already paused all simulation controllers, because now evaluating. # So further simulation controllers needs to be paused and the signal, # set in _maybe_pause_simulation_controller, can be turned down self._pausing_simulation_controllers = False LOG.info( "Starting evaluation %s " "in phase %s of experiment run %s", None, # TODO: Get and check the current eval episode self.current_phase + 1, self.experiment_run.uid, ) if any( scuid.endswith("-EVALUATE") for scuid in self._simulation_controllers_active ): # (3.) If next evaluation should be performed, reuse already existing (active), # but paused evaluation workers. There cannot be used a unified flipping method # because here all workers are paused, other than at (2.), where the eval worker # is at CONTINUE await self._continue_eval_worker_simulation_controllers() else: # (1.) 
Start eval worker processes initially assert self._evaluation_start_lock is not None async with self._evaluation_start_lock: await self._start_evaluation() if flow.value <= SimulationFlowControl.RESTART.value: # Restart/soft-reset of this particular worker: LOG.info( 'Restarting simulation worker "%s" ' "in phase %d for episode %d " 'of experiment run "%s"', request.sender, self.current_phase + 1, current_episode + 1, self.experiment_run.uid, ) return SimulationControllerTerminationResponse( sender_run_governor_id=request.receiver, receiver_simulation_controller_id=request.sender, experiment_run_instance_id=self.experiment_run.instance_uid, experiment_run_id=request.experiment_run_id, experiment_run_phase=self.current_phase, # Every worker has its own SimulationController and therefore gets its respective episode episode=current_episode, restart=True, complete_shutdown=False, flow_control=flow, ) LOG.info( 'Signalling simulation worker "%s" to shut down ' 'for phase %d in experiment run "%s"', request.sender, self.current_phase + 1, # +1 for display purposes only. self.experiment_run.uid, ) # Potentially shut down all workers, so sending # SimulationShutdownRequest to all running simulation controllers, # also the SimController that requests the termination, because # itself did not or has not been shutdown yet if flow == SimulationFlowControl.STOP_PHASE: # Create task to ask next step, shutdown all conductors. # Its only a matter of time until the last one goes down... # We start this task here, but we try to wait for the # SC process to actually end using a future… # The future is set in the SIGCHLD event handler. 
LOG.info( "Stopping phase %d of experiment run %s", self.current_phase, self.experiment_run.uid, ) self._has_stop_phase_signal = True return SimulationControllerTerminationResponse( sender_run_governor_id=request.receiver, receiver_simulation_controller_id=request.sender, experiment_run_instance_id=self.experiment_run.instance_uid, experiment_run_id=request.experiment_run_id, experiment_run_phase=self.current_phase, # Every worker has its own SimulationController and therefore gets its respective episode episode=current_episode, restart=False, complete_shutdown=True, flow_control=flow, ) @ESM.requests def _request_simulation_flow_change( self, simulation_flow: SimulationFlowControl, simulation_controllers: List[str], ): assert self.experiment_run is not None self._simulation_controllers_pending_response = set( simulation_controllers ) return [ SimulationFlowChangeRequest( sender_run_governor_id=self.uid, receiver_simulation_controller_id=sc_uid, experiment_run_id=self.experiment_run.uid, experiment_run_instance_id=self.experiment_run.instance_uid, experiment_run_phase=self.current_phase, simulation_flow=simulation_flow, interrupt=True, overwrite_flow=False, ) for sc_uid in simulation_controllers ] @ESM.on(SimulationFlowChangeResponse) async def _handle_simulation_flow_change_response( self, response: SimulationFlowChangeResponse ): assert ( self._future_simulation_controllers_flow_change is not None and not self._future_simulation_controllers_flow_change.done() ), "Future for SC flow control change is not present!" 
assert self._simulation_controllers_flow_status is not None if ( response.requested_flow == self._simulation_controllers_flow_status.get( response.sender, None ) ): LOG.error( "Flow (%s) was not actually changed for simulation controller: %s", response.requested_flow, response.sender, ) if ( response.requested_flow != SimulationFlowControl.CONTINUE and response.requested_flow != SimulationFlowControl.PAUSE ): LOG.error( "Changed flow for %s is not CONTINUE and not PAUSE, but %s", response.sender, response.requested_flow, ) self._simulation_controllers_flow_status[response.sender] = ( response.requested_flow ) LOG.debug( "Noticed SimulationFlowChangeResponse (%s)." "_simulation_controllers_flow_status after processing: %s", response, self._simulation_controllers_flow_status, ) if ( response.sender not in self._simulation_controllers_pending_response ): LOG.error( "%s sent response for %s simulation flow change request, " "but is internally not pending for response!", response.sender, response.requested_flow, ) self._simulation_controllers_pending_response -= {response.sender} if len(self._simulation_controllers_pending_response) == 0: self._future_simulation_controllers_flow_change.set_result(True) def _get_flow_status_mapping( self, ) -> Optional[Dict[SimulationFlowControl, Set[str]]]: if self._simulation_controllers_flow_status is None: LOG.error("Flow status mapping is not present!") return None return { SimulationFlowControl.PAUSE: { scuid for scuid, status in self._simulation_controllers_flow_status.items() if status == SimulationFlowControl.PAUSE }, SimulationFlowControl.CONTINUE: { scuid for scuid, status in self._simulation_controllers_flow_status.items() if status == SimulationFlowControl.CONTINUE }, } def _check_reset_flow_status_mapping(self): if ( self._simulation_controllers_flow_status is not None and len(self._simulation_controllers_flow_status) > 0 ): LOG.error( "Current flow status mapping is not empty, but leaking: %s", 
self._simulation_controllers_flow_status, ) self._simulation_controllers_flow_status = {} self._pausing_simulation_controllers = ( self._should_start_pausing_simulation_controllers() ) async def _agent_evaluation(self): # TODO: Currently with DB query over config may lead to race condition, # because data in DB is only eventually accessible; without assert self._agent_evaluation_lock async with self._agent_evaluation_lock: if self._future_agent_evaluation is not None: LOG.error("Future agent evaluation is already present!") self._future_agent_evaluation = ( asyncio.get_running_loop().create_future() ) self._request_agent_evaluation() await self._future_agent_evaluation self._future_agent_evaluation = None @ESM.requests def _request_agent_evaluation(self): evaluation_episodes = { v for k, v in self.current_episode_counts.items() if k.endswith("-EVALUATE") } if len(evaluation_episodes) != 1: LOG.error( "Different evaluation episode counters: %s", self.current_episode_counts, ) if self._pending_agent_evaluation_responses is not None: LOG.error("Pending agent evaluation response is already present!") self._pending_agent_evaluation_responses = set(self._agent_conductors) return [ AgentEvaluationRequest( sender_run_governor=self.uid, receiver_agent_conductor=agent_conductor, experiment_run_uid=self.experiment_run.uid, experiment_run_instance_uid=self.experiment_run.instance_uid, experiment_run_phase=self.current_phase, # The evaluation episodes are the amount of episodes already computed, # so subtract one to get the index evaluation_episode=max(evaluation_episodes) - 1, ) for agent_conductor in self._agent_conductors ] @ESM.on(AgentEvaluationResponse) def _handle_agent_evaluation_response( self, response: AgentEvaluationResponse ): assert self._future_agent_evaluation is not None assert self._pending_agent_evaluation_responses is not None self._pending_agent_evaluation_responses -= {response.sender} if len(self._pending_agent_evaluation_responses) == 0: 
self._pending_agent_evaluation_responses = None self._future_agent_evaluation.set_result(True) async def _request_phase_shutdown(self): if not self._has_stop_phase_signal: LOG.error( "No STOP_PHASE signal was sent, but requesting to shut phase down" ) if self._is_phase_shutting_down: return self._is_phase_shutting_down = True if self._future_simulation_controller_active is not None: LOG.log( palaestrai.logging.ASYNCIO_LOG_LEVEL, "%s waits on future_simulation_controller_active", self, ) await self._future_simulation_controller_active LOG.info("Send requests for shutting down simulation controllers.") if len(self._simulation_controllers_active) > 0: _ = self._send_simulation_shutdown_requests() await self._simulation_controllers_down_event.wait() LOG.info("Send requests for shutting down conductors.") # Sync with the agent evaluation lock to make sure that the agent evaluation # is still performed before shutting down the conductors, # if the evaluation is still running when the shutdown is requested. 
async with self._agent_evaluation_lock: if len(self._agent_conductors) > 0: _ = self._send_agent_conductor_shutdown_requests() if len(self._environment_conductors) > 0: _ = self._send_environment_conductor_shutdown_requests() @ESM.requests def _send_agent_conductor_shutdown_requests(self): LOG.debug( "%s requesting shut down of %s", self, self._agent_conductors ) return [ ShutdownRequest( sender=self.uid, receiver=acuid, experiment_run_id=self.experiment_run.uid, experiment_run_instance_id=self.experiment_run.instance_uid, experiment_run_phase=self.current_phase, ) for acuid in self._agent_conductors ] @ESM.requests def _send_environment_conductor_shutdown_requests(self): LOG.debug( "%s requesting shutdown of %s", self, self._environment_conductors ) return [ ShutdownRequest( sender=self.uid, receiver=ecuid, experiment_run_id=self.experiment_run.uid, experiment_run_instance_id=self.experiment_run.instance_uid, experiment_run_phase=self.current_phase, ) for ecuid in self._environment_conductors ] @ESM.on(ShutdownResponse) def _handle_shutdown_response(self, response: ShutdownResponse): """Tracks shutdown acknowledgments and completes shutdown when all stop""" self._agent_conductors = [ acuid for acuid in self._agent_conductors if not acuid == response.sender ] self._environment_conductors = [ ecuid for ecuid in self._environment_conductors if not ecuid == response.sender ] LOG.debug( "%s got %s, conductors still up: %s", self, response, self._agent_conductors + self._environment_conductors, ) if ( len(self._agent_conductors) + len(self._environment_conductors) == 0 ): assert self._future_next_phase is not None assert self._future_shutdown is not None try: self._future_next_phase.set_result(True) self._future_shutdown.set_result(True) except InvalidStateError: pass # Sometimes when hitting Ctrl-C, there's a double stop. 
@ESM.requests
def _send_simulation_shutdown_requests(self):
    """Build one ``SimulationShutdownRequest`` per active controller.

    Returns
    -------
    list of SimulationShutdownRequest
        Addressed to every UID in ``_simulation_controllers_active``.
    """
    LOG.debug(
        "%s sending SimulationShutdownRequest(s) to %s",
        self,
        self._simulation_controllers_active,
    )
    shutdown_requests = []
    for controller_uid in self._simulation_controllers_active:
        shutdown_requests.append(
            SimulationShutdownRequest(
                sender=self.uid,
                receiver=controller_uid,
                experiment_run_id=self.experiment_run.uid,
                experiment_run_instance_id=self.experiment_run.instance_uid,
                experiment_run_phase=self.current_phase,
            )
        )
    return shutdown_requests


def _shutdown_simulation_controller(self, sc_uid: str, process_pid: int):
    """Forget a finished simulation controller; maybe start phase shutdown.

    Parameters
    ----------
    sc_uid : str
        UID of the simulation controller that is gone.
    process_pid : int
        PID of its process, which is dropped from ``_processes``.
    """
    LOG.info("Shutting down simulation controller %s", sc_uid)
    self._simulation_controllers_active -= {sc_uid}

    flow_status = self._simulation_controllers_flow_status
    if flow_status is not None:
        flow_status.pop(sc_uid, None)

    evaluation_worker_left = any(
        active_uid.endswith("-EVALUATE")
        for active_uid in self._simulation_controllers_active
    )
    if not evaluation_worker_left:
        LOG.debug("No evaluation workers are active")

    self._remove_simulation_controller_process(process_pid)

    # The check on the evaluation_start_lock is need when the normal SCs are done but still
    # evaluations has to be performed
    assert self._evaluation_start_lock is not None
    if (
        self._simulation_controllers_active
        or self._evaluation_start_lock.locked()
    ):
        return

    LOG.info("No simulation controllers are active any longer")
    assert self._simulation_controllers_down_event is not None
    self._simulation_controllers_down_event.set()
    self._shutdown_conductors_task = asyncio.create_task(
        self._request_phase_shutdown()
    )


def _remove_simulation_controller_process(self, process_pid: int):
    """Drop the process with the given PID from ``_processes``.

    If no process whose name starts with ``SimulationController-``
    remains although controllers are still marked active, this is logged
    as an error and the set of active controllers is cleared.
    """
    remaining = []
    for proc in self._processes:
        if proc.pid != process_pid:
            remaining.append(proc)
    self._processes = remaining

    controller_process_left = any(
        proc.name.startswith("SimulationController-")
        for proc in self._processes
    )
    if controller_process_left:
        return
    if self._simulation_controllers_active:
        LOG.error(
            "All simulation controller processes end, "
            "but some are still logically active (%s)",
            self._simulation_controllers_active,
        )
        self._simulation_controllers_active = set()
@ESM.on(SimulationShutdownResponse)
def _handle_simulation_shutdown_response(
    self, response: SimulationShutdownResponse
):
    """Log the received response; no further bookkeeping is done here."""
    LOG.debug("Got %s", response)


@ESM.on(ExperimentRunShutdownRequest)
async def _handle_shutdown_request(
    self, request: ExperimentRunShutdownRequest
):
    """Shut the current phase down, stop, and acknowledge the request.

    Returns
    -------
    ExperimentRunShutdownResponse
        Addressed to the sender of the request, flagged as successful.
    """
    run_label = self.experiment_run.uid if self.experiment_run else "(no run)"
    LOG.info('Shutdown of experiment run "%s" requested', run_label)

    await self._request_phase_shutdown()

    launcher_task = self._phase_launcher_task
    if launcher_task is not None:
        launcher_task.cancel()
    self.stop()  # type: ignore[attr-defined]

    return ExperimentRunShutdownResponse(
        sender_run_governor_id=self.uid,
        receiver_executor_id=request.sender,
        experiment_run_id=request.experiment_run_id,
        successful=True,
        error=None,
    )


@ESM.on(ErrorIndicator)
def _raise_error_indicator(self, error: ErrorIndicator):
    """Switch to the ERROR state and re-raise the reported error.

    Raises
    ------
    Exception
        ``error.exception`` if the indicator carries one.
    RuntimeError
        Carrying ``error.error_message`` otherwise.
    """
    self._state = BasicState.ERROR
    LOG.critical(
        "%s received error from %s: %s",
        self,
        error.sender,
        error.error_message,
    )
    # TODO: Shutdown everything
    if error.exception is None:
        raise RuntimeError(error.error_message)
    raise error.exception


@ESM.on(signal.SIGCHLD)
def _handle_child(
    self, process: Union[aiomultiprocess.Process, multiprocessing.Process]
):
    """React to a child process that has ended.

    A non-zero exit code switches to the ERROR state and calls ``stop``
    with a ``RuntimeError``. Independently of the exit code, the end of
    a process named ``SimulationController-<name>`` schedules
    ``_maybe_continue_simulation_controllers`` as a task.
    """
    LOG.debug("%s saw process %s end.", self, process.name)

    exited_cleanly = process.exitcode == 0
    if not exited_cleanly:
        self._state = BasicState.ERROR
        LOG.critical(
            "Subprocess %s exited prematurely with rc %s). "
            "Cannot continue simulation.",
            process.name,
            process.exitcode,
        )
        failure = RuntimeError(
            f"Subprocess {process.name} ended prematurely "
            f"with rc {process.exitcode}"
        )
        self.stop(failure)  # type: ignore[attr-defined]

    if not process.name.startswith("SimulationController-"):
        return

    simulation_controller_name = process.name.removeprefix(
        "SimulationController-"
    )
    LOG.debug(
        "SimulationController %s has signalled that it is has finished.",
        simulation_controller_name,
    )
    assert isinstance(process.pid, int)
    continuation = self._maybe_continue_simulation_controllers(
        simulation_controller_name, process.pid
    )
    self._simulation_controllers_continuation_task = asyncio.create_task(
        continuation
    )


def __str__(self):
    """Return a one-line summary: uid, state, experiment run uid, phase."""
    head = f"RunGovernor(uid={self.uid}, state={self._state.name}, "
    if self.experiment_run is not None:
        run_uid = self.experiment_run.uid
    else:
        run_uid = "(None)"
    tail = f"experiment_run_uid={run_uid}, phase={self.current_phase})"
    return head + tail