Source code for palaestrai.experiment.run_governor

from __future__ import annotations
from typing import TYPE_CHECKING, List, Dict, Set, Optional, Union

import uuid
import signal
import logging
import asyncio
from collections import defaultdict

from aiomultiprocess import Process

from palaestrai.core.protocol import (
    ExperimentRunStartRequest,
    ExperimentRunStartResponse,
    SimulationStartRequest,
    SimulationStartResponse,
    SimulationControllerTerminationRequest,
    SimulationControllerTerminationResponse,
    ErrorIndicator,
    SimulationShutdownRequest,
    SimulationShutdownResponse,
    ExperimentRunShutdownRequest,
    ExperimentRunShutdownResponse,
    ShutdownRequest,
    ShutdownResponse,
)
from palaestrai.util import spawn_wrapper
from palaestrai.types import SimulationFlowControl
from palaestrai.core import EventStateMachine as ESM
from palaestrai.core import BasicState, RuntimeConfig
from palaestrai.util.exception import ExperimentAlreadyRunningError

if TYPE_CHECKING:
    import aiomultiprocess
    import multiprocessing
    from palaestrai.experiment import TerminationCondition
    from palaestrai.experiment.experiment_run import ExperimentRun

LOG = logging.getLogger(__name__)


@ESM.monitor(is_mdp_worker=True)
class RunGovernor:
    """This class implements the Run-Governor.

    Upon receiving requests from the executor, a RunGovernor instance
    handles a single experiment run by starting it, initializing the
    simulation controllers, the environment conductors and the agent
    conductors, and, finally, shutting the experiment run down.

    The RunGovernor is event-driven: its handlers are registered through
    the ``ESM`` (event state machine) decorators and react to protocol
    messages and to ``SIGCHLD``. A freshly constructed RunGovernor is in
    the state ``BasicState.PRISTINE``.

    Parameters
    ----------
    uid : str, optional
        The unique ID that identifies this run governor. If omitted (or
        empty), ``"RunGovernor-<uuid4>"`` is generated.

    Attributes
    ----------
    uid : str
        The ID of this RunGovernor.
    current_phase : int
        Zero-based index of the phase that is currently set up/executed.
    current_episode_counts : Dict[str, int]
        Number of termination requests seen per simulation controller
        (keyed by the sender's ID); reset at the start of each phase.
    experiment_run : Optional[ExperimentRun]
        The experiment run handled by this governor, ``None`` until an
        :class:`ExperimentRunStartRequest` was accepted.
    """

    def __init__(self, uid: Optional[str] = None):
        self.uid = uid if uid else "RunGovernor-%s" % str(uuid.uuid4())
        self._state = BasicState.PRISTINE

        # Experiment control:
        self.current_phase: int = 0
        self.current_episode_counts: Dict[str, int] = {}
        self.experiment_run: Optional[ExperimentRun] = None
        self._termination_condition: Optional[TerminationCondition] = None

        # Receiver information:
        self._simulation_controllers: List[str] = []
        self._environment_conductors: List[str] = []
        self._agent_conductors: List[str] = []

        # Receiver synchronization:
        self._simulation_controllers_active: Set[str] = set()
        self._future_simulation_controllers_down: Optional[asyncio.Future] = (
            None
        )
        self._future_shutdown: Optional[asyncio.Future] = None
        self._future_next_phase: Optional[asyncio.Future] = None
        self._phase_launcher_task: Optional[asyncio.Task] = None
        self._shutdown_conductors_task: Optional[asyncio.Task] = None

        # Subprocess management:
        self._processes: List[aiomultiprocess.Process] = []

    def setup(self):
        """Publish this governor's service name and enter INITIALIZING."""
        self._state = BasicState.PRISTINE
        self.mdp_service = self.uid  # type: ignore[attr-defined]
        LOG.debug("%s ready.", self)
        self._state = BasicState.INITIALIZING

    @ESM.on(ExperimentRunStartRequest)
    async def _handle_experiment_run_start_request(
        self, request: ExperimentRunStartRequest
    ):
        """Set up the requested experiment run and launch its phases.

        Returns an :class:`ExperimentRunStartResponse` whose ``successful``
        flag is ``False`` if a run is already active or if setting up the
        run raised an exception, and ``True`` otherwise.
        """
        if self.experiment_run is not None:
            # The first placeholder is the run that was *requested*, the
            # second one the run that is already being executed.
            LOG.warning(
                '%s got request to start experiment run "%s", '
                'but experiment run "%s" is already running. '
                "Reporting error and continueing. Elect somebody "
                "else instead!",
                self,
                request.experiment_run.uid,
                self.experiment_run.uid,
            )
            # NOTE(review): stop() is called although the message says
            # "continueing" -- confirm that stopping is intended here.
            self.stop()  # type: ignore[attr-defined]
            return ExperimentRunStartResponse(
                sender_run_governor_id=request.receiver,
                receiver_executor_id=request.sender,
                experiment_run_id=request.experiment_run_id,
                experiment_run=request.experiment_run,
                successful=False,
                error=ExperimentAlreadyRunningError(self.experiment_run),
            )

        self.current_phase = 0
        self.experiment_run = request.experiment_run
        LOG.debug("%s setting up %s", self, self.experiment_run)
        try:
            await self._setup_run()
            self._state = BasicState.INITIALIZED
        except Exception as e:
            LOG.exception("%s encountered exception during setup: %s", self, e)
            self.stop(e)  # type: ignore[attr-defined]
            return ExperimentRunStartResponse(
                sender_run_governor_id=request.receiver,
                receiver_executor_id=request.sender,
                experiment_run_id=request.experiment_run_id,
                experiment_run=request.experiment_run,
                successful=False,
                error=e,
            )

        # The phases run in a background task so that the response can be
        # returned to the executor right away.
        self._phase_launcher_task = asyncio.create_task(self._run_all_phases())
        return ExperimentRunStartResponse(
            sender_run_governor_id=request.receiver,
            receiver_executor_id=request.sender,
            experiment_run_id=request.experiment_run_id,
            experiment_run=request.experiment_run,
            successful=True,
            error=None,
        )

    async def _run_all_phases(self):
        """Execute all phases of the experiment run, one after another.

        For each phase, fresh synchronization futures are created, the
        phase's processes are started, and the start requests are sent.
        The method then waits until the phase is reported as finished and
        all subprocesses have been joined. After the last phase, it waits
        (at most 15 seconds) for the shutdown future, asks the remaining
        simulation controllers to shut down, and stops the governor.
        """
        self._state = BasicState.RUNNING
        n_phases = self.experiment_run.num_phases
        while self.current_phase < n_phases:
            loop = asyncio.get_running_loop()
            self._future_next_phase = loop.create_future()
            self._future_simulation_controllers_down = loop.create_future()
            self._future_shutdown = loop.create_future()
            LOG.debug(
                'Setting up phase %d/%d in experiment run "%s"...',
                self.current_phase + 1,  # +1 for display purposes
                n_phases,
                self.experiment_run.uid,
            )
            await self._setup_phase()
            LOG.debug(
                "%s sending start request(s) to simulation controllers...",
                self,
            )
            ssrq = self._send_simulation_start_requests()
            LOG.debug("%s sent simulation start requests: %s", self, ssrq)
            LOG.debug("%s waiting for phase to end", self)
            await self._future_next_phase

            # Wait for all processes to really end:
            LOG.debug(
                "%s waiting for processes to end: %s", self, self._processes
            )
            for p in self._processes:
                await p.join()
            self.current_phase += 1

        LOG.info(
            'Executed all phases in run "%s", shutting down.',
            self.experiment_run.uid,
        )
        # The future only exists if at least one phase has been executed.
        if self._future_shutdown is not None:
            try:
                await asyncio.wait_for(self._future_shutdown, timeout=15)
            except asyncio.TimeoutError:
                # asyncio.TimeoutError is only an alias of the builtin
                # TimeoutError from Python 3.11 on; catching the asyncio
                # name works for all versions.
                LOG.error(
                    "%s timed out while waiting for all processes to end. "
                    "Conductors still active: %s",
                    self,
                    self._agent_conductors + self._environment_conductors,
                )
        await self._request_simulation_controllers_shutdown()
        self.stop()  # type: ignore[attr-defined]

    async def _setup_run(self):
        """Set up the experiment run and fetch its termination condition."""
        self.experiment_run.setup(
            f"tcp://127.0.0.1:{RuntimeConfig().executor_bus_port}"
        )
        assert (
            self.experiment_run.run_governor_termination_condition is not None
        )
        self._termination_condition = (
            self.experiment_run.run_governor_termination_condition
        )

    async def _setup_phase(self):
        """Start all processes of the current phase and remember them."""
        self.current_episode_counts = defaultdict(int)  # Default value: 0
        ps = await asyncio.gather(
            self._start_environment_conductors(),
            self._start_agent_conductors(),
            self._start_simulation_controllers(),
        )  # Returns a nested list
        self._processes = [p for pl in ps for p in pl]  # Flatten "ps" list
        LOG.debug(
            "%s has processes this phase: %s",
            self,
            [p.name for p in self._processes],
        )

    @ESM.spawns
    async def _start_simulation_controllers(self):
        """Start one process per simulation controller of this phase.

        Returns the list of started processes.
        """
        simulation_controllers = (
            self.experiment_run.simulation_controllers(self.current_phase)
        ).values()
        LOG.debug(
            "%s lauching simulation controller processes: %s",
            self,
            list(simulation_controllers),
        )
        sc_processes = [
            Process(
                name=f"SimulationController-{sc.uid}",
                target=spawn_wrapper,
                args=(
                    f"SimulationController-{sc.uid[-6:]}",
                    RuntimeConfig().to_dict(),
                    sc.run,
                ),
            )
            for sc in simulation_controllers
        ]
        self._simulation_controllers = [
            sc.uid for sc in simulation_controllers
        ]
        for p in sc_processes:
            p.start()
        return sc_processes

    @ESM.spawns
    async def _start_environment_conductors(self):
        """Start one process per environment conductor of this phase.

        Returns the list of started processes.
        """
        environment_conductors = (
            self.experiment_run.environment_conductors(self.current_phase)
        ).values()
        LOG.debug(
            "%s lauching environment conductor processes: %s",
            self,
            list(environment_conductors),
        )
        ec_processes = [
            Process(
                name=f"EnvironmentConductor-{ec.uid}",
                target=spawn_wrapper,
                args=(
                    f"EnvironmentConductor-{ec.uid[-6:]}",
                    RuntimeConfig().to_dict(),
                    ec.run,
                ),
            )
            for ec in environment_conductors
        ]
        self._environment_conductors = [
            ec.uid for ec in environment_conductors
        ]
        for p in ec_processes:
            p.start()
        return ec_processes

    @ESM.spawns
    async def _start_agent_conductors(self):
        """Start one process per agent conductor of this phase.

        Returns the list of started processes.
        """
        agent_conductors = (
            self.experiment_run.agent_conductors(self.current_phase)
        ).values()
        LOG.debug(
            "%s lauching agent conductor processes: %s",
            self,
            list(agent_conductors),
        )
        ac_processes = [
            Process(
                name=f"AgentConductor-{ac.uid}",
                target=spawn_wrapper,
                args=(
                    f"AgentConductor-{ac.uid[-6:]}",
                    RuntimeConfig().to_dict(),
                    ac.run,
                ),
            )
            for ac in agent_conductors
        ]
        self._agent_conductors = [ac.uid for ac in agent_conductors]
        for p in ac_processes:
            p.start()
        return ac_processes

    @ESM.requests
    def _send_simulation_start_requests(self):
        """Build one start request per simulation controller.

        Also clears the set of active simulation controllers; it is
        refilled as the start responses arrive.
        """
        self._simulation_controllers_active = set()
        return [
            SimulationStartRequest(
                sender_run_governor_id=self.uid,
                receiver_simulation_controller_id=sc_uid,
                experiment_run_id=self.experiment_run.uid,
                experiment_run_instance_id=self.experiment_run.instance_uid,
                experiment_run_phase=self.current_phase,
                experiment_run_phase_id=self.experiment_run.get_phase_name(
                    self.current_phase
                ),
                experiment_run_phase_configuration=self.experiment_run.phase_configuration(
                    self.current_phase
                ),
            )
            for sc_uid in self._simulation_controllers
        ]

    @ESM.on(SimulationStartResponse)
    def _handle_simulation_start_response(
        self, response: SimulationStartResponse
    ):
        """Mark the responding simulation controller as active."""
        LOG.debug(
            "%s got simulation start response from %s", self, response.sender
        )
        self._simulation_controllers_active |= {response.sender}

    @ESM.on(SimulationControllerTerminationRequest)
    async def _handle_simulation_controller_termination_request(
        self, request: SimulationControllerTerminationRequest
    ):
        """Decide whether a simulation controller restarts or shuts down.

        The decision is delegated to the termination condition's
        ``phase_flow_control``. Flow values up to ``RESTART`` lead to a
        restart response; everything else leads to a shutdown response.
        On ``STOP_PHASE``, a task that shuts down all conductors is
        started in addition.
        """
        LOG.debug(
            "%s, workers: %s", request, self._simulation_controllers_active
        )
        assert self.experiment_run is not None
        self.current_episode_counts[request.sender] += 1

        assert self._termination_condition is not None
        flow = self._termination_condition.phase_flow_control(self, request)

        if flow.value <= SimulationFlowControl.RESTART.value:
            # Restart/soft-reset of this particular worker:
            LOG.info(
                'Restarting simulation worker "%s" '
                'in phase %d of experiment run "%s"',
                request.sender,
                self.current_phase + 1,
                self.experiment_run.uid,
            )
            return SimulationControllerTerminationResponse(
                sender_run_governor_id=request.receiver,
                receiver_simulation_controller_id=request.sender,
                experiment_run_instance_id=self.experiment_run.instance_uid,
                experiment_run_id=request.experiment_run_id,
                experiment_run_phase=self.current_phase,
                restart=(flow.value <= SimulationFlowControl.RESTART.value),
                complete_shutdown=False,
                flow_control=flow,
            )

        LOG.info(
            'Signalling simulation worker "%s" to shut down '
            'for phase %d in experiment run "%s"',
            request.sender,
            self.current_phase + 1,  # +1 for display purposes only.
            self.experiment_run.uid,
        )

        # Potentially shut down all workers, so sending
        # SimulationShutdownRequest to all running simulation controllers,
        # also the SimController that requests the termination, because
        # itself did not or has not been shutdown yet
        if flow == SimulationFlowControl.STOP_PHASE:
            # Create task to ask next step, shutdown all conductors.
            # Its only a matter of time until the last one goes down...
            # We start this task here, but we try to wait for the
            # SC process to actually end using a future...
            # The future is set in the SIGCHLD event handler.
            self._simulation_controllers_active -= {request.sender}
            self._shutdown_conductors_task = asyncio.create_task(
                self._request_conductors_shutdown()
            )
        return SimulationControllerTerminationResponse(
            sender_run_governor_id=request.receiver,
            receiver_simulation_controller_id=request.sender,
            experiment_run_instance_id=self.experiment_run.instance_uid,
            experiment_run_id=request.experiment_run_id,
            experiment_run_phase=self.current_phase,
            restart=False,
            complete_shutdown=True,
            flow_control=flow,
        )

    async def _request_conductors_shutdown(self):
        """Shut down all conductors once the simulation controllers ended."""
        LOG.debug("%s waiting to shut down all conductors...", self)
        if self._future_simulation_controllers_down is not None:
            await self._future_simulation_controllers_down
        _ = self._send_agent_conductor_shutdown_requests()
        _ = self._send_environment_conductor_shutdown_requests()

    async def _request_simulation_controllers_shutdown(self):
        """Ask all active simulation controllers to shut down."""
        _ = self._send_simulation_shutdown_requests()

    @ESM.requests
    def _send_agent_conductor_shutdown_requests(self):
        """Build one shutdown request per known agent conductor."""
        LOG.debug(
            "%s requesting shut down of %s", self, self._agent_conductors
        )
        return [
            ShutdownRequest(
                sender=self.uid,
                receiver=acuid,
                experiment_run_id=self.experiment_run.uid,
                experiment_run_instance_id=self.experiment_run.instance_uid,
                experiment_run_phase=self.current_phase,
            )
            for acuid in self._agent_conductors
        ]

    @ESM.requests
    def _send_environment_conductor_shutdown_requests(self):
        """Build one shutdown request per known environment conductor."""
        LOG.debug(
            "%s requesting shutdown of %s", self, self._environment_conductors
        )
        return [
            ShutdownRequest(
                sender=self.uid,
                receiver=ecuid,
                experiment_run_id=self.experiment_run.uid,
                experiment_run_instance_id=self.experiment_run.instance_uid,
                experiment_run_phase=self.current_phase,
            )
            for ecuid in self._environment_conductors
        ]

    @ESM.on(ShutdownResponse)
    def _handle_shutdown_response(self, response: ShutdownResponse):
        """Forget the conductor that shut down; end the phase on the last.

        Once neither agent nor environment conductors remain, the
        "next phase" and "shutdown" futures are resolved.
        """
        self._agent_conductors = [
            acuid
            for acuid in self._agent_conductors
            if acuid != response.sender
        ]
        self._environment_conductors = [
            ecuid
            for ecuid in self._environment_conductors
            if ecuid != response.sender
        ]
        LOG.debug(
            "%s got %s, conductors still up: %s",
            self,
            response,
            self._agent_conductors + self._environment_conductors,
        )
        if not self._agent_conductors and not self._environment_conductors:
            assert self._future_next_phase is not None
            assert self._future_shutdown is not None
            # A late/duplicate response must not resolve a future twice;
            # that would raise asyncio.InvalidStateError.
            if not self._future_next_phase.done():
                self._future_next_phase.set_result(True)
            if not self._future_shutdown.done():
                self._future_shutdown.set_result(True)

    @ESM.requests
    def _send_simulation_shutdown_requests(self):
        """Build one shutdown request per active simulation controller."""
        LOG.debug(
            "%s sending SimulationShutdownRequest(s) to %s",
            self,
            self._simulation_controllers_active,
        )
        if self._future_simulation_controllers_down is None:
            # This can happen if we get this before a phase is started.
            # In this case, we're probably not returning any shutdown
            # request, as there's simply no simulation controller active.
            # Still, we have to stay consistent:
            self._future_simulation_controllers_down = (
                asyncio.get_running_loop().create_future()
            )
        return [
            SimulationShutdownRequest(
                sender=self.uid,
                receiver=sc_uid,
                experiment_run_id=self.experiment_run.uid,
                experiment_run_instance_id=self.experiment_run.instance_uid,
                experiment_run_phase=self.current_phase,
            )
            for sc_uid in self._simulation_controllers_active
        ]

    @ESM.on(SimulationShutdownResponse)
    def _handle_simulation_shutdown_response(
        self, response: SimulationShutdownResponse
    ):
        """Mark a simulation controller as down; signal when all are."""
        self._simulation_controllers_active -= {response.sender}
        if not self._simulation_controllers_active:
            assert self._future_simulation_controllers_down is not None
            # The SIGCHLD handler resolves the same future; whoever comes
            # second must not call set_result() again.
            if not self._future_simulation_controllers_down.done():
                self._future_simulation_controllers_down.set_result(True)

    @ESM.on(ExperimentRunShutdownRequest)
    async def _handle_shutdown_request(
        self, request: ExperimentRunShutdownRequest
    ):
        """Abort the run: cancel the phase task, stop the controllers.

        Returns an :class:`ExperimentRunShutdownResponse` after the
        "simulation controllers down" future has been resolved.
        """
        LOG.info(
            'Shutdown of experiment run "%s" requested',
            self.experiment_run.uid if self.experiment_run else "(no run)",
        )
        if self._phase_launcher_task is not None:
            self._phase_launcher_task.cancel()

        # Create the future to make sure we have not dead tasks:
        if self._future_simulation_controllers_down is None:
            self._future_simulation_controllers_down = (
                asyncio.get_running_loop().create_future()
            )
        await self._request_simulation_controllers_shutdown()
        # NOTE(review): if no simulation controller is active and no
        # SimulationController process ends afterwards, nothing visible
        # here resolves this future -- confirm this cannot hang.
        await self._future_simulation_controllers_down
        self.stop()  # type: ignore[attr-defined]
        return ExperimentRunShutdownResponse(
            sender_run_governor_id=self.uid,
            receiver_executor_id=request.sender,
            experiment_run_id=request.experiment_run_id,
            successful=True,
            error=None,
        )

    @ESM.on(ErrorIndicator)
    def _raise_error_indicator(self, error: ErrorIndicator):
        """Enter the ERROR state and re-raise the reported error.

        Raises
        ------
        Exception
            ``error.exception`` if it is set, otherwise a
            :class:`RuntimeError` carrying ``error.error_message``.
        """
        self._state = BasicState.ERROR
        LOG.critical(
            "%s received error from %s: %s",
            self,
            error.sender,
            error.error_message,
        )
        # TODO: Shutdown everything
        if error.exception is not None:
            raise error.exception
        raise RuntimeError(error.error_message)

    @ESM.on(signal.SIGCHLD)
    def _handle_child(
        self, process: Union[aiomultiprocess.Process, multiprocessing.Process]
    ):
        """React to a terminated subprocess.

        A non-zero exit code puts the governor into the ERROR state and
        stops it. The process is removed from the bookkeeping list; if it
        was the last simulation controller process, the "simulation
        controllers down" future is resolved.
        """
        LOG.debug("%s saw process %s end.", self, process.name)
        if process.exitcode != 0:
            self._state = BasicState.ERROR
            LOG.critical(
                "Subprocess %s exited prematurely with rc %s). "
                "Cannot continue simulation.",
                process.name,
                process.exitcode,
            )
            self.stop(  # type: ignore[attr-defined]
                RuntimeError(
                    f"Subprocess {process.name} ended prematurely "
                    f"with rc {process.exitcode}"
                )
            )
        self._processes = [p for p in self._processes if p.pid != process.pid]

        sc_prefix = "SimulationController-"
        if process.name.startswith(sc_prefix) and not any(
            p.name.startswith(sc_prefix) for p in self._processes
        ):
            LOG.debug("%s saw last simulation controller end.", self)
            # If this was the last SC, we can shutdown all conductors.
            # To let the respective task continue, set this future:
            assert self._future_simulation_controllers_down is not None
            # The future may already have been resolved by the
            # SimulationShutdownResponse handler; resolving it twice
            # would raise asyncio.InvalidStateError.
            if not self._future_simulation_controllers_down.done():
                self._future_simulation_controllers_down.set_result(True)

    def __str__(self):
        return (
            f"RunGovernor(uid={self.uid}, state={self._state.name}, "
            f"experiment_run_uid="
            f"{self.experiment_run.uid if self.experiment_run is not None else '(None)'}, "
            f"phase={self.current_phase})"
        )