# Source code for palaestrai.experiment.max_episodes_termination_condition

from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Optional, Union, Tuple, Any

import logging

from .termination_condition import TerminationCondition
from ..core.protocol import SimulationControllerTerminationRequest
from ..types import SimulationFlowControl, Mode

if TYPE_CHECKING:
    from palaestrai.experiment import RunGovernor

LOG = logging.getLogger(__name__)


class MaxEpisodesTerminationCondition(TerminationCondition):
    """Ends a phase once its configured episode budget has been used up.

    This termination condition acts on phase level only. It reads the
    number of episodes and the number of workers of the current phase (and,
    if defined, the evaluation interval) from the experiment run object held
    by the run governor, and compares the resulting budget against the
    per-simulation-controller episode counters of the run governor.

    Examples
    --------

    Consider the following experiment phase definition::

        schedule:
          Training:
            phase_config:
              mode: train
              worker: 2
              episodes: 100
            simulation:
              conditions:
              - name: palaestrai.experiment:MaxEpisodesTerminationCondition
                params: {}
              name: palaestrai.simulation:TakingTurns
        run_config:
          condition:
            name: palaestrai.experiment:MaxEpisodesTerminationCondition
            params: {}

    With this configuration, the phase ends as soon as both workers
    (``worker: 2``) have completed 100 episodes (``episodes: 100``).
    """

    def phase_flow_control(
        self,
        run_governor: RunGovernor,
        message: SimulationControllerTerminationRequest,
    ) -> Tuple[SimulationFlowControl, Any]:
        """Decide how the phase proceeds after an episode has finished.

        Parameters
        ----------
        run_governor : RunGovernor
            Supplies ``experiment_run``, ``current_phase``,
            ``current_episode_counts`` and ``max_eval_episodes``.
        message : SimulationControllerTerminationRequest
            The termination request of a simulation controller. Any other
            message type is ignored.

        Returns
        -------
        Tuple[SimulationFlowControl, Any]
            The second element is always ``None``. The first element is

            * ``CONTINUE`` if the message is of another type, no experiment
              run is available, a phase lookup raises ``KeyError``, or the
              sender has not reached its episode limit yet;
            * ``STOP_PHASE`` if the sum of all episode counters (including
              the episode being reported) reaches the overall budget;
            * ``RESTART`` if the sender's name ends in ``-EVALUATE`` and it
              is still below ``run_governor.max_eval_episodes``;
            * ``STOP_SIMULATION`` if only the sender has reached its limit.
        """
        if not isinstance(message, SimulationControllerTerminationRequest):
            return SimulationFlowControl.CONTINUE, None

        experiment_run = run_governor.experiment_run
        if experiment_run is None:
            LOG.warning(
                "MaxEpisodesTerminationCondition cannot control flow: "
                "Run governor has no experiment run object!"
            )
            return SimulationFlowControl.CONTINUE, None

        try:
            episode_limit = experiment_run.get_episodes(
                run_governor.current_phase
            )
            worker_count = experiment_run.get_worker(
                run_governor.current_phase
            )
            eval_interval = experiment_run.get_evaluate_every(
                run_governor.current_phase
            )
        except KeyError:
            # A phase without an episode limit may run indefinitely.
            return SimulationFlowControl.CONTINUE, None

        assert episode_limit is not None
        assert worker_count is not None

        # Overall budget: every worker runs ``episode_limit`` episodes, plus
        # the evaluation episodes if an evaluation interval is configured.
        # NOTE(review): the evaluation term multiplies by the worker count a
        # second time -- confirm this matches how evaluations are scheduled.
        episode_budget = episode_limit * worker_count
        expected_evaluations = None
        if eval_interval is not None:
            expected_evaluations = (
                episode_budget // eval_interval
            ) * worker_count
            episode_budget += expected_evaluations

        # Work on a copy with the sender's just-finished episode already
        # counted; the run governor's own counters are left untouched here.
        sender_uid = message.sender
        projected_counts = copy.deepcopy(run_governor.current_episode_counts)
        projected_counts[sender_uid] += 1

        # Budget exhausted: end the phase, logging any inconsistencies.
        if sum(projected_counts.values()) >= episode_budget:
            someone_unfinished = any(
                count < episode_limit
                for uid, count in projected_counts.items()
                if not uid.endswith("-EVALUATE")
            )
            if someone_unfinished:
                LOG.error(
                    "%s computed STOP_PHASE, "
                    "but not all SimulationControllers are done with their episodes! "
                    "(current_episode_counts: %s, sum_max_episodes: %s, max_episodes: %s, mode: %s)",
                    self,
                    projected_counts,
                    episode_budget,
                    episode_limit,
                    message.mode,
                    stack_info=True,
                )

            performed_evaluations = sum(
                count
                for uid, count in projected_counts.items()
                if uid.endswith("-EVALUATE")
            )
            if expected_evaluations is not None:
                if expected_evaluations > performed_evaluations:
                    LOG.error(
                        "%s computed STOP_PHASE, but not all evaluations were performed!",
                        self,
                    )
                elif expected_evaluations < performed_evaluations:
                    LOG.error(
                        "%s noticed internal error: Too many evaluations were performed!",
                        self,
                    )
            return SimulationFlowControl.STOP_PHASE, None

        # Budget not exhausted: decide for the sending controller only.
        sender_count = projected_counts[sender_uid]

        if sender_uid.endswith("-EVALUATE"):
            eval_limit = run_governor.max_eval_episodes
            assert eval_limit is not None
            if eval_limit <= 0:
                LOG.error(
                    "Internal error: Max eval episodes should be calculated when "
                    "evaluation workers are present and MaxEpisodesTerminationCondition is applied!"
                )
            # Evaluation controllers are restarted until they hit their own
            # limit, then shut down.
            if sender_count < eval_limit:
                return SimulationFlowControl.RESTART, None
            if sender_count >= eval_limit:
                return SimulationFlowControl.STOP_SIMULATION, None

        verdict = (
            SimulationFlowControl.STOP_SIMULATION
            if sender_count >= episode_limit
            else SimulationFlowControl.CONTINUE
        )
        return verdict, None