# Source code for palaestrai.store.receiver

from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, Tuple, Callable, Optional

import re
import uuid
import queue
import signal
import inspect
import asyncio
import logging
import datetime
from collections import deque

import jsonpickle
import jsonpickle.ext.numpy as jsonpickle_numpy
import ruamel.yaml as yml
from numpy.random import RandomState
import sqlalchemy
import sqlalchemy.engine
import sqlalchemy.exc
import sqlalchemy.orm
from sqlalchemy import select, Text, func
from sqlalchemy.orm.attributes import flag_modified

import palaestrai.core.MDP as MDP
import palaestrai.core.protocol as proto
from palaestrai.types import SimulationFlowControl
from palaestrai.core.runtime_config import RuntimeConfig
from palaestrai.core.serialisation import deserialize
from . import database_model as dbm

LOG = logging.getLogger(__name__)


class StoreReceiver:
    """The message receiver of the palaestrAI store.

    The store hooks into the global communication, reading every message that
    is being exchanged between :class:`Executor`, :class:`RunGovernor`,
    :class:`AgentConductor`, :class:`Environment`, :class:`Brain`, and
    :class:`Muscle` instances. From these messages, it reads all relevant
    status information in order to relay them to the store database for later
    analysis of experiments.

    Parameters
    ----------
    message_queue : queue.Queue
        Queue from which raw (multi-part) messages are read. Putting
        :attr:`SHUTDOWN_SENTINEL` on the queue asks :meth:`run` to stop.
    """

    # Number of consecutive failed write attempts tolerated before the store
    # disables itself.
    MAX_RETRIES = 5
    # Captures everything between the first "." and the last "-" of an
    # environment worker ID (used to key simtimes in ``MuscleAction``).
    _SIMTIMES_ENVKEY_RE = re.compile(r"\.(.*)-[^-]*\Z")
    SHUTDOWN_SENTINEL = None

    def __init__(self, message_queue: queue.Queue):
        self._running = True
        self._uid = uuid.uuid4()
        self._incoming_queue = message_queue
        # Messages handed to a handler but not yet committed to the DB:
        self._inflight_cache: deque = deque()
        self._incoming_task: Optional[asyncio.Task] = None
        self._db_engine = None
        self._db_session_maker: sqlalchemy.orm.sessionmaker | None = None
        self._db_session: sqlalchemy.orm.Session | None = None
        # Always defined so that __str__ works even if the lookup below fails.
        self._store_uri = None

        # Every *Request/*Response class of the protocol module is known,
        # but only those listed explicitly below get a handler.
        self._message_dispatch: dict[Any, Callable | None] = {
            v: None
            for k, v in proto.__dict__.items()
            if inspect.isclass(v) and k.endswith(("Request", "Response"))
        }
        self._message_dispatch.update(
            {
                proto.ExperimentRunStartRequest: self._write_experiment,
                proto.SimulationStartRequest: self._write_experiment_run_phase,
                proto.EnvironmentSetupResponse: self._write_environment,
                proto.EnvironmentStartResponse: self._write_static_state,
                proto.EnvironmentResetResponse: self._reset_environment,
                proto.EnvironmentUpdateResponse: self._write_world_state,
                proto.AgentSetupRequest: self._write_agent,
                proto.AgentSetupResponse: self._write_muscles,
                proto.MuscleUpdateRequest: self._write_muscle_actions,
                proto.SimulationControllerTerminationResponse: self._invalidate_cache,
            }
        )

        try:
            self._store_uri = RuntimeConfig().store_uri
            if not self._store_uri:
                raise KeyError
        except KeyError:
            LOG.error(
                "The storage subsystem has no store_uri configured, "
                "I'm going to disable myself. :-( "
                "If you want to employ me, set the 'store_uri' runtime "
                "configuration parameter.",
            )
            self.disable()

        # Caches to avoid lookup queries:
        self._environment_ticks: Dict[Tuple, int] = {}
        self._known_agents: Dict[Tuple, int] = {}
        self._known_environments: Dict[Tuple, int] = {}

    def disable(self):
        """Disables the store completely."""
        LOG.debug("Disabling the storage backend.")
        for message_class in self._message_dispatch:
            # Disable all handlers (values only; the key set is unchanged).
            self._message_dispatch[message_class] = None
        if self._db_session:
            # Explicitly close session here or we will see "session used in
            # wrong thread" errors, because the garbage collector runs in a
            # different thread.
            self._db_session.close()
            self._db_session = None
        self._inflight_cache.clear()

    @property
    def uid(self):
        """The unique ID generated for this receiver instance."""
        return self._uid

    @property
    def _dbh(self) -> sqlalchemy.orm.Session | None:
        """The database session, created lazily on first access.

        Creates engine and session maker on first use. If the session cannot
        be created, the store disables itself and ``None`` is returned.
        """
        if self._db_engine is None:
            self._db_engine = sqlalchemy.create_engine(
                RuntimeConfig().store_uri,
                json_serializer=jsonpickle.dumps,
                json_deserializer=jsonpickle.loads,
            )
            self._db_session_maker = sqlalchemy.orm.sessionmaker()
            self._db_session_maker.configure(bind=self._db_engine)
        if self._db_session is None:
            try:
                self._db_session = self._db_session_maker()
                LOG.debug(
                    "%s connected to: %s",
                    self,
                    RuntimeConfig().store_uri,
                )
            except (
                sqlalchemy.exc.OperationalError,
                sqlalchemy.exc.ArgumentError,
            ) as e:
                LOG.error(
                    "%s could not connect to %s: %s. "
                    "I'm going to say good-bye to this cruel world now!",
                    self,
                    RuntimeConfig().store_uri,
                    e,
                )
                self.disable()
        return self._db_session

    @property
    def _is_enabled(self):
        """``True`` as long as at least one message handler is registered."""
        return not all(x is None for x in self._message_dispatch.values())

    async def _maybe_commit(self, force: bool = False):
        """Database commit handling

        Takes care of committing elements to the database. Unless a commit
        is forced, pending objects are buffered until the configured buffer
        size is reached. A successful commit also clears the in-flight
        message cache.

        Parameters
        ----------
        force : bool, default: False
            Instructs the method to force the commit, regardless of any
            cache settings
        """
        session = self._dbh
        assert session is not None

        # There's a magic number here.
        # The reason is simply buffering. Writing out every single update is
        # too expensive in terms of I/O. Buffering all doesn't work as well.
        # So we keep a small amount of updates and write them out in bulk.
        # Just enough to be more efficient, but not so much as to cause
        # memory issues.
        # The number is just an educated guess, really.
        if len(session.new) < RuntimeConfig().store_buffer_size and not force:
            return
        LOG.debug(
            "%s committing %d items to the database", self, len(session.new)
        )
        session.commit()
        self._inflight_cache.clear()

    async def _read_next_incoming_message(self):
        """Blocking-read one message from the queue in the default executor.

        Returns ``None`` if the read raises :class:`ValueError`.
        """
        try:
            return await asyncio.get_running_loop().run_in_executor(
                None, self._incoming_queue.get
            )
        except ValueError:  # Queue might be closed on shutdown
            return None

    async def run(self):
        """Run the store.

        Reads messages from the incoming queue and dispatches them to the
        handlers until the shutdown sentinel is received (or the store is
        interrupted/terminated) and the in-flight cache has been flushed.
        Failed writes are retried with exponential back-off; after
        :attr:`MAX_RETRIES` failures the store disables itself but keeps
        draining the queue.
        """
        loop = asyncio.get_running_loop()
        loop.add_signal_handler(signal.SIGINT, self._interrupt)
        loop.add_signal_handler(signal.SIGTERM, self._terminate)

        jsonpickle_numpy.register_handlers()
        jsonpickle.set_preferred_backend("simplejson")
        jsonpickle.set_encoder_options("simplejson", ignore_nan=True)

        LOG.info("Connecting to database at %s", RuntimeConfig().store_uri)

        retries = 0
        read_retries = 0
        while self._running or len(self._inflight_cache) > 0:
            await asyncio.sleep(
                min(2, (2**read_retries - 1)) * 0.1
            )  # Yield to event loop for signals

            if retries > StoreReceiver.MAX_RETRIES:
                LOG.critical(
                    "%s cannot write to the database after %d retries: I will disable myself now.",
                    self,
                    StoreReceiver.MAX_RETRIES,
                )
                self.disable()
                self._inflight_cache.clear()
                # Reset, or this branch is taken on every following iteration
                # and the queue (including the shutdown sentinel) is never
                # read again.
                retries = 0
                continue

            if retries > 0 or (
                not self._running and len(self._inflight_cache) > 0
            ):
                # We need to retry, so sleep for a bit and then try again:
                await asyncio.sleep(2**retries - 1)
                # Replay from a detached snapshot: handlers may commit, which
                # clears the in-flight cache, so each message is put back
                # right before it is written. Messages not yet attempted are
                # re-queued on failure instead of being duplicated or lost.
                pending = deque(self._inflight_cache)
                self._inflight_cache.clear()
                try:
                    while pending:
                        cached = pending.popleft()
                        self._inflight_cache.append(cached)
                        await self.write(cached)
                    retries = 0  # If all of this worked, we can reset
                    await self._maybe_commit(force=True)
                    LOG.debug(
                        "%s was successful retrying, %d items left",
                        self,
                        len(self._inflight_cache),
                    )
                except sqlalchemy.exc.DBAPIError:  # :-(
                    self._inflight_cache.extend(pending)
                    retries += 1
                continue

            try:
                msg = self._incoming_queue.get_nowait()
                read_retries = 0
            except (
                AttributeError,
                ValueError,
            ):  # Queue might be closed on shutdown
                break
            except (
                queue.Empty,
                TimeoutError,
            ):  # Nothing to see here, loop and try again
                read_retries += 1
                continue

            if msg is StoreReceiver.SHUTDOWN_SENTINEL:
                self._running = False
                continue
            if not self._is_enabled:
                continue  # Just drain the queue

            _msg_type, _msg_uid, msg_obj = StoreReceiver._read(msg)
            if not isinstance(msg_obj, list):
                msg_obj = [msg_obj]
            for item in msg_obj:
                try:
                    self._inflight_cache.append(item)
                    await self.write(item)
                    retries = 0  # Reset write retries counter
                except (
                    sqlalchemy.exc.NoResultFound,
                    sqlalchemy.exc.MultipleResultsFound,
                    sqlalchemy.exc.IntegrityError,
                ):
                    # All these mean that the last message was
                    # (1) a metadata message and that (2) some cruft was
                    # left in the DB. We try to continue, but we must
                    # first pop the offending message:
                    _ = self._inflight_cache.pop()
                except sqlalchemy.exc.DBAPIError as e:
                    if e.connection_invalidated:
                        retries += 1
                        continue
                    LOG.critical(
                        "Encountered a fatal error, cannot continue "
                        "to write data to the database: %s",
                        e,
                    )
                    # We still need to continue to retrieve messages from
                    # the incoming queue, even if we immediately
                    # throw them away afterwards. So we disable ourselves
                    # and continue:
                    self.disable()
                    self._inflight_cache.clear()
                    continue

        if self._db_session is not None:
            await self._maybe_commit(force=True)
            self._db_session.close()
            self._db_session = None
        if self._db_engine is not None:
            self._db_engine.dispose()
        try:
            self._incoming_queue.close()
        except Exception:  # Might already be closed, but that's ok.
            pass
        LOG.info("%s has shut down.", self)

    def _interrupt(self):
        """SIGINT handler: stop reading, but flush what is in flight."""
        LOG.info("%s has been interrupted, draining queue.", self)
        self._running = False

    def _terminate(self):
        """SIGTERM handler: stop, cancel the reader task, close the queue."""
        self._running = False
        if self._incoming_task:
            self._incoming_task.cancel()
            self._incoming_task = None
        self._incoming_queue.close()
        LOG.warning(
            "%s is being terminated. Input queue is closed, "
            "messages might be lost. Will try to commit %d messages "
            "from the inflight queue to the database.",
            self,
            len(self._inflight_cache),
        )

    async def write(self, message):
        """Main method called to write a message to the buffer.

        Dispatches ``message`` to the handler registered for its class.
        ``None`` and messages without a handler are ignored. Schema-level
        database errors disable the store.
        """
        if message is None:
            return
        handler = self._message_dispatch.get(message.__class__)
        if handler is None:  # Unknown class, or known but not handled
            StoreReceiver._handle_unknown_message(message)
            return
        try:
            await handler(message)
        except (
            sqlalchemy.exc.NoForeignKeysError,
            sqlalchemy.exc.ProgrammingError,
        ) as e:
            LOG.exception(
                "%s disables itself since "
                "the developers are too stupid "
                "to get the schema right: %s",
                self,
                e,
            )
            self.disable()

    @staticmethod
    def _handle_unknown_message(message):
        """Log (at debug level) a message no handler exists for."""
        if isinstance(message, str):
            # Python parses some of the heartbeat messages to strings.
            # This doesn't concern us, but outputting a warning just because
            # we parsed some random stuff into a str isn't exactly
            # user-friendly.
            return
        LOG.debug(
            "Store received message %s, but cannot handle it - ignoring",
            message.__class__,
        )

    async def _write_experiment(self, msg: proto.ExperimentRunStartRequest):
        """Record experiment, experiment run, and a new run instance.

        Raises
        ------
        sqlalchemy.exc.IntegrityError
            If the commit fails, e.g., because the run instance already
            exists. The session is rolled back before re-raising.
        """
        from palaestrai.experiment.experiment_run import ExperimentRun

        experiment_name = msg.experiment_run.experiment_uid or (
            "Dummy Experiment record "
            "for ExperimentRun %s" % msg.experiment_run_id
        )
        query = select(dbm.Experiment).where(
            dbm.Experiment.name == experiment_name
        )
        experiment_record = self._dbh.execute(query).scalars().first()
        if not experiment_record:
            experiment_record = dbm.Experiment(name=experiment_name)
            # NOTE(review): this YAML instance is not used any further in
            # this method; the registration calls are kept in case they have
            # side effects beyond the instance -- confirm.
            yaml = yml.YAML(typ="safe")
            yaml.register_class(ExperimentRun)
            yaml.representer.add_representer(
                RandomState, ExperimentRun.repr_randomstate
            )
            yaml.constructor.add_constructor(
                "rng", ExperimentRun.repr_randomstate
            )
            self._dbh.add(experiment_record)

        # NOTE: Experiment runs are meant to be unique regarding their hash.
        # A former lookup by hash has been dropped here because its result
        # was always overwritten by the lookup by UID below.
        query = select(dbm.ExperimentRun).where(
            dbm.ExperimentRun.uid == msg.experiment_run.uid
        )
        result = self._dbh.execute(query).scalars().all()
        if len(result) > 1:
            LOG.warning(
                'Found %d entries for experiment run "%s" '
                "with hash %s "
                "when there should be at most one. "
                "I'm going to add your data to the existing one "
                "(ID in the database: %d), "
                "but if strange things happen, don't blame it on me.",
                len(result),
                msg.experiment_run.uid,
                msg.experiment_run.hash,
                result[0].id,
            )

        experiment_run_record = result[0] if result else None
        if (
            experiment_run_record is not None
            and experiment_run_record.hash != msg.experiment_run.hash
        ):
            now = datetime.datetime.now()
            oldname = f"{msg.experiment_run.uid} (before {now})"
            LOG.error(
                'Your experiment run "%s" is already recorded in '
                "the database, but with a different hash. I'm going "
                'to rename the old version to "%s", '
                "but you should really take care of that.",
                msg.experiment_run.uid,
                oldname,
            )
            experiment_run_record.uid = oldname
            self._dbh.add(experiment_run_record)
            experiment_run_record = None  # Forces a fresh record below
        if experiment_run_record is None:
            experiment_run_record = dbm.ExperimentRun(
                uid=msg.experiment_run.uid,
                document=msg.experiment_run,
                hash=msg.experiment_run.hash,
            )
            experiment_record.experiment_runs.append(experiment_run_record)

        # Every time we see an ExperimentRunStartRequest, it means that we
        # also create a new instance of this run.
        try:
            experiment_run_record.experiment_run_instances.append(
                dbm.ExperimentRunInstance(
                    uid=msg.experiment_run.instance_uid,
                )
            )
            await self._maybe_commit(force=True)
        except sqlalchemy.exc.IntegrityError as e:
            LOG.warning(
                "%s encountered a glitch in the Matrix: A record for "
                "ExperimentRunInstance(uid=%s) was already there! Perhaps "
                "your environment does not provide enough entropy, or we have "
                "a resend. I'm going to ignore this error and continue as "
                "best as I can. (%s)",
                self,
                msg.experiment_run.instance_uid,
                e,
            )
            self._dbh.rollback()
            raise  # Pass on to "write" to clear

    async def _write_experiment_run_phase(
        self, message: proto.SimulationStartRequest
    ):
        """Record a new experiment run phase unless it already exists.

        Raises
        ------
        sqlalchemy.exc.IntegrityError
            If the commit fails; the session is rolled back before
            re-raising.
        """
        mode = message.experiment_run_phase_configuration.get(
            "mode", "unknown"
        ).lower()
        query = select(dbm.ExperimentRunInstance).where(
            dbm.ExperimentRunInstance.uid == message.experiment_run_instance_id
        )
        try:
            experiment_run_instance_record = (
                self._dbh.execute(query).scalars().one()
            )
        except sqlalchemy.orm.exc.NoResultFound:
            LOG.exception(
                "%s received a %s, but could not find an instance of %s. "
                "I cannot store information about this phase; expect more "
                "errors ahead.",
                self,
                repr(message),
                message.experiment_run_instance_id,
            )
            return

        # Count matching phase records to detect duplicates:
        query = select(dbm.ExperimentRunPhase).where(
            dbm.ExperimentRunPhase.number == message.experiment_run_phase,
            dbm.ExperimentRunPhase.experiment_run_instance_id
            == experiment_run_instance_record.id,
            dbm.ExperimentRunPhase.uid == message.experiment_run_phase_id,
            dbm.ExperimentRunPhase.mode == mode,
        )
        query = select(func.count()).select_from(query.subquery())
        try:
            count = self._dbh.execute(query).scalar_one()
        except sqlalchemy.orm.exc.NoResultFound:
            LOG.exception(
                "%s received a %s, but could not count the run phase record."
                "I cannot determine if the run phase already exists; expect more "
                "errors ahead.",
                self,
                repr(message),
            )
            return
        if count > 0:
            LOG.debug(
                "%s saw run phase record already exists; not going to write it again.",
                self,
            )
            return

        LOG.debug(
            "%s writing new ExperimentRunPhase for "
            "ExperimentRun(uid=%s, instance_uid=%s).",
            self,
            message.experiment_run_id,
            message.experiment_run_instance_id,
        )
        try:
            experiment_run_instance_record.experiment_run_phases.append(
                dbm.ExperimentRunPhase(
                    number=message.experiment_run_phase,
                    experiment_run_instance_id=experiment_run_instance_record.id,
                    uid=message.experiment_run_phase_id,
                    configuration=message.experiment_run_phase_configuration,
                    mode=mode,
                )
            )
            await self._maybe_commit(force=True)
        except sqlalchemy.exc.IntegrityError as e:
            LOG.debug(
                "%s saw a %s, but got an IntegrityError from the DB (%s). "
                "I assume multi worker and will ignore this error.",
                self,
                repr(message),
                e,
            )
            self._dbh.rollback()
            raise  # Pass on to "write" to clear

    async def _write_environment(
        self, message: proto.EnvironmentSetupResponse
    ):
        """Record an environment for its run phase and seed the caches.

        Raises
        ------
        sqlalchemy.exc.IntegrityError
            If the commit fails; the session is rolled back before
            re-raising.
        """
        query = (
            sqlalchemy.select(
                dbm.ExperimentRunInstance, dbm.ExperimentRunPhase
            )
            .join(dbm.ExperimentRunInstance.experiment_run_phases)
            .where(
                dbm.ExperimentRunInstance.uid
                == message.experiment_run_instance_id,
                dbm.ExperimentRunPhase.number == message.experiment_run_phase,
                dbm.ExperimentRunPhase.mode == message.mode.name.lower(),
            )
        )
        try:
            result = self._dbh.execute(query).one()
        except sqlalchemy.exc.MultipleResultsFound:
            LOG.exception(
                "StoreReceiver(id=0x%x, uid=%s) "
                "encountered an EnvironmentSetupResponse("
                "experiment_run_id=%s, experiment_run_instance_id=%s, "
                "experiment_run_phase=%s), "
                "but there are duplicate entries for this run phase. "
                "I will not record this environment as I do not know to which "
                "phase it belongs. "
                "Expect more errors from the store ahead.",
                id(self),
                self.uid,
                message.experiment_run_id,
                message.experiment_run_instance_id,
                message.experiment_run_phase,
            )
            return
        except sqlalchemy.exc.NoResultFound:
            LOG.exception(
                "%s encountered an %s, "
                "but there is no record of this phase in the store. "
                "I will not record this environment as I cannot do it; "
                "expect more errors from the store ahead.",
                self,
                repr(message),
            )
            return

        environment_records = result[dbm.ExperimentRunPhase].environments
        index_key = (
            message.experiment_run_instance_id,
            message.experiment_run_phase,
            message.environment_id,
        )
        try:
            environment_record = dbm.Environment(
                uid=message.environment_name,
                worker_uid=message.environment_id,
                type=message.environment_type,
                parameters=message.environment_parameters,
                environment_conductor_uid=message.sender_environment_conductor,
            )
            environment_records.append(environment_record)
            await self._maybe_commit(force=True)
            self._known_environments[index_key] = environment_record.id
            self._environment_ticks[index_key] = 0
        except sqlalchemy.exc.IntegrityError:
            LOG.exception(
                "%s encountered multiple copies of "
                "Environment(uid=%s) in the database already present "
                "for experiment_run_instance=%s and "
                "experiment_run_phase=%s. "
                "I'm not going to add another one, because I assume "
                "a multi-worker setup. However, if there are strange "
                "errors ahead, you may have been warned...",
                self,
                message.environment_id,
                message.experiment_run_instance_id,
                message.experiment_run_phase,
            )
            self._dbh.rollback()
            raise  # Pass on to "write" to clear

    async def _write_static_state(
        self, message: proto.EnvironmentStartResponse
    ):
        """Store the static model of an already recorded environment.

        Raises
        ------
        sqlalchemy.exc.MultipleResultsFound, sqlalchemy.exc.NoResultFound
            If the environment record cannot be determined unambiguously.
        """
        environment_record_id = self._get_environment_id(
            experiment_run_instance_id=message.experiment_run_instance_id,
            experiment_run_phase=message.experiment_run_phase,
            environment_id=message.sender,
        )
        query = sqlalchemy.select(dbm.Environment).where(
            dbm.Environment.id == environment_record_id
        )
        try:
            result = self._dbh.execute(query).one()
            result[dbm.Environment].static_model = message.static_model
            await self._maybe_commit(force=True)
        except sqlalchemy.exc.MultipleResultsFound:
            LOG.exception(
                "%s encountered an EnvironmentStartResponse("
                "experiment_run_id=%s, experiment_run_instance_id=%s, "
                "experiment_run_phase=%s), "
                "but there are duplicate entries for this run phase. "
                "I will not record the static model of this environment "
                "as I do not know to which phase it belongs.",
                self,
                message.experiment_run_id,
                message.experiment_run_instance_id,
                message.experiment_run_phase,
            )
            raise
        except sqlalchemy.exc.NoResultFound:
            LOG.exception(
                "%s encountered an %s, "
                "but there is no record of this phase in the store. "
                "I will not record this environment as I cannot do it; "
                "expect more errors from the store ahead.",
                self,
                repr(message),
            )
            raise

    async def _reset_environment(
        self, message: proto.EnvironmentResetResponse
    ):
        """Reset the cached tick counter of the environment to zero."""
        self._environment_ticks[
            (
                message.experiment_run_instance_id,
                message.experiment_run_phase,
                message.sender_environment_id,
            )
        ] = 0

    def _get_environment_id(
        self,
        experiment_run_instance_id: str,
        experiment_run_phase: int,
        environment_id: str,
    ) -> int:
        """Retrieves a store record of an environment from cache or DB.

        Raises
        ------
        sqlalchemy.exc.MultipleResultsFound, sqlalchemy.exc.NoResultFound
            If the database lookup does not yield exactly one row.
        """
        index_key = (
            experiment_run_instance_id,
            experiment_run_phase,
            environment_id,
        )
        if index_key not in self._known_environments:
            query = (
                sqlalchemy.select(
                    dbm.ExperimentRunInstance,
                    dbm.ExperimentRunPhase,
                    dbm.Environment,
                )
                .join(dbm.ExperimentRunInstance.experiment_run_phases)
                .join(dbm.ExperimentRunPhase.environments)
                .where(
                    dbm.ExperimentRunInstance.uid
                    == experiment_run_instance_id,
                    dbm.ExperimentRunPhase.number == experiment_run_phase,
                    dbm.Environment.worker_uid == environment_id,
                )
            )
            result = self._dbh.execute(query).one()
            self._known_environments[index_key] = result[dbm.Environment].id
        return self._known_environments[index_key]

    async def _write_world_state(
        self, message: proto.EnvironmentUpdateResponse
    ):
        """Add a world state record for an environment update.

        The commit is forced when the message signals ``done``.

        Raises
        ------
        sqlalchemy.exc.MultipleResultsFound, sqlalchemy.exc.NoResultFound
            If the environment record cannot be determined unambiguously.
        """
        try:
            environment_record_id = self._get_environment_id(
                experiment_run_instance_id=message.experiment_run_instance_id,
                experiment_run_phase=message.experiment_run_phase,
                environment_id=message.sender_environment_id,
            )
        except sqlalchemy.exc.MultipleResultsFound:
            LOG.exception(
                "%s found multiple records for the same Environment(uid=%s) "
                "during %s. "
                "Duplicates should not occur here; expect more errors ahead.",
                self,
                message.sender_environment_id,
                repr(message),
            )
            raise
        except sqlalchemy.exc.NoResultFound:
            LOG.exception(
                "%s found no record for the Environment(uid=%s) "
                "during %s. "
                "Was there no environment setup? Expect more errors ahead.",
                self,
                message.sender_environment_id,
                repr(message),
            )
            raise

        # Add a new world state. We don't use parent.append() here, because
        # we don't want to end up with a big augmented list...
        index_key = (
            message.experiment_run_instance_id,
            message.experiment_run_phase,
            message.sender_environment_id,
        )
        # Prefer the ticks reported by the environment; otherwise count
        # updates ourselves, starting at zero.
        if message.simtime and message.simtime.simtime_ticks:
            self._environment_ticks[index_key] = message.simtime.simtime_ticks
        elif index_key not in self._environment_ticks:
            self._environment_ticks[index_key] = 0
        else:
            self._environment_ticks[index_key] += 1

        world_state_record = dbm.WorldState(
            simtime_ticks=self._environment_ticks[index_key],
            simtime_timestamp=(
                message.simtime.simtime_timestamp if message.simtime else None
            ),
            walltime=message.walltime,
            episode=message.episode,
            done=message.done,
            state_dump=message.sensors,
            setpoints=message.setpoints,
            environment_id=environment_record_id,
        )
        self._dbh.add(world_state_record)
        await self._maybe_commit(force=message.done)

    async def _write_agent(self, message: proto.AgentSetupRequest):
        """Record an agent for its run phase unless it is already known.

        Raises
        ------
        sqlalchemy.exc.MultipleResultsFound, sqlalchemy.exc.NoResultFound
            If the run phase cannot be determined unambiguously.
        sqlalchemy.exc.IntegrityError
            If the commit fails; the session is rolled back before
            re-raising.
        """
        query = (
            sqlalchemy.select(
                dbm.ExperimentRunInstance, dbm.ExperimentRunPhase
            )
            .join(dbm.ExperimentRunInstance.experiment_run_phases)
            .where(
                dbm.ExperimentRunInstance.uid
                == message.experiment_run_instance_id,
                dbm.ExperimentRunPhase.number == message.experiment_run_phase,
                dbm.ExperimentRunPhase.mode == message.mode.name.lower(),
            )
        )
        try:
            result = self._dbh.execute(query).one()
        except sqlalchemy.exc.MultipleResultsFound:
            LOG.exception(
                "StoreReceiver(id=0x%x, uid=%s) "
                "encountered an AgentSetupRequest("
                "experiment_run_id=%s, experiment_run_instance_id=%s, "
                "experiment_run_phase=%s), "
                "but there are duplicate entries for this run phase. "
                "I will not record this agent as I do not know to which "
                "phase it belongs. "
                "Expect more errors from the store ahead.",
                id(self),
                self.uid,
                message.experiment_run_id,
                message.experiment_run_instance_id,
                message.experiment_run_phase,
            )
            raise
        except sqlalchemy.orm.exc.NoResultFound:
            LOG.exception(
                "StoreReceiver(id=0x%x, uid=%s) "
                "encountered an AgentSetupRequest("
                "experiment_run_id=%s, experiment_run_instance_id=%s, "
                "experiment_run_phase=%s), "
                "but there is no record of this phase in the store. "
                "I will not record this agent as I cannot do it; "
                "expect more errors from the store ahead.",
                id(self),
                self.uid,
                message.experiment_run_id,
                message.experiment_run_instance_id,
                message.experiment_run_phase,
            )
            raise

        phase_record = result[dbm.ExperimentRunPhase]
        agent_records = phase_record.agents
        query = (
            sqlalchemy.select(dbm.Agent)
            .join(dbm.ExperimentRunPhase)
            .where(
                dbm.ExperimentRunPhase.id == phase_record.id,
                dbm.Agent.uid == message.receiver_agent_conductor,
            )
        )
        already_known = self._dbh.execute(query).scalars().all()
        if len(already_known) > 0:
            return  # Multiworker, we'll add the muscle later.

        try:
            agent_records.append(
                dbm.Agent(
                    uid=message.receiver_agent_conductor,
                    name=message.muscle_name,
                    configuration=message.configuration,
                    muscles=[],
                )
            )
            await self._maybe_commit(force=True)
        except sqlalchemy.exc.IntegrityError:
            LOG.exception(
                "StoreReceiver(id=0x%x, uid=%s) "
                "encountered multiple copies of Agent(uid=%s) in the database "
                "already present for experiment_run_instance=%s and "
                "experiment_run_phase=%s. I'm not going to add another one. "
                "Expect more strange errors ahead...",
                id(self),
                self.uid,
                message.muscle_name,
                message.experiment_run_instance_id,
                message.experiment_run_phase,
            )
            self._dbh.rollback()
            raise

    async def _write_muscles(self, message: proto.AgentSetupResponse):
        """Attach a rollout worker (muscle) to its agent record.

        Raises
        ------
        sqlalchemy.exc.MultipleResultsFound, sqlalchemy.exc.NoResultFound
            If the agent record cannot be determined unambiguously.
        """
        query = (
            sqlalchemy.select(
                dbm.ExperimentRunInstance,
                dbm.ExperimentRunPhase,
                dbm.Agent,
            )
            .join(dbm.ExperimentRunInstance.experiment_run_phases)
            .join(dbm.ExperimentRunPhase.agents)
            .where(
                dbm.Agent.uid == message.sender_agent_conductor,
                dbm.ExperimentRunPhase.number == message.experiment_run_phase,
                dbm.ExperimentRunPhase.mode == message.mode.name.lower(),
                dbm.ExperimentRunInstance.uid
                == message.experiment_run_instance_id,
            )
        )
        try:
            record = self._dbh.execute(query).one()
        except sqlalchemy.exc.MultipleResultsFound:
            LOG.exception(
                "StoreReceiver(id=0x%x, uid=%s) "
                "encountered an AgentSetupResponse("
                "agent_conductor_id=%s, rollout_worker_id=%s, "
                "experiment_run_id=%s, experiment_run_instance_id=%s, "
                "experiment_run_phase=%s), "
                "but there are duplicate entries for this run phase. "
                "I will not record this agent's muscles as I do not know to "
                "which agent it belongs. "
                "Expect more errors from the store ahead.",
                id(self),
                self.uid,
                message.sender_agent_conductor,
                message.rollout_worker_id,
                message.experiment_run_id,
                message.experiment_run_instance_id,
                message.experiment_run_phase,
            )
            raise
        except sqlalchemy.exc.NoResultFound:
            LOG.exception(
                "%s encountered an %s, "
                "but there is no record of this agent in the store. "
                "I will not record this agent's muscles as I cannot do it; "
                "expect more errors from the store ahead.",
                self,
                repr(message),
            )
            raise

        agent_record = record[dbm.Agent]
        agent_record.muscles.append(message.rollout_worker_id)
        flag_modified(agent_record, "muscles")  # Mutations are not autotracked
        await self._maybe_commit(force=True)
        self._known_agents[
            (
                message.experiment_run_instance_id,
                message.experiment_run_phase,
                message.rollout_worker_id,
            )
        ] = agent_record.id

    def _get_agent_id(
        self,
        experiment_run_instance_id: str,
        experiment_run_phase: int,
        mode: str,
        agent_id: str,
    ) -> int:
        """Retrieves the store record ID of an agent from cache or DB.

        The database lookup matches ``agent_id`` against the text form of
        the agent's ``muscles`` column.

        Raises
        ------
        sqlalchemy.exc.MultipleResultsFound, sqlalchemy.exc.NoResultFound
            If the database lookup does not yield exactly one row.
        """
        index_key = (
            experiment_run_instance_id,
            experiment_run_phase,
            agent_id,
        )
        if index_key not in self._known_agents:
            query = (
                sqlalchemy.select(
                    dbm.ExperimentRunInstance,
                    dbm.ExperimentRunPhase,
                    dbm.Agent,
                )
                .join(dbm.ExperimentRunInstance.experiment_run_phases)
                .join(dbm.ExperimentRunPhase.agents)
                .where(
                    dbm.Agent.muscles.cast(Text).contains(agent_id),
                    dbm.ExperimentRunPhase.number == experiment_run_phase,
                    dbm.ExperimentRunPhase.mode == mode,
                    dbm.ExperimentRunInstance.uid
                    == experiment_run_instance_id,
                )
            )
            self._known_agents[index_key] = (
                self._dbh.execute(query).one()[dbm.Agent].id
            )
        return self._known_agents[index_key]

    async def _write_muscle_actions(self, message: proto.MuscleUpdateRequest):
        """Add a muscle action record for a muscle update.

        Updates carrying neither sensor readings, setpoints, nor rewards
        are skipped. The commit is forced when the message signals ``done``.

        Raises
        ------
        sqlalchemy.exc.MultipleResultsFound, sqlalchemy.exc.NoResultFound
            If the agent record cannot be determined unambiguously.
        """
        if (
            not message.sensor_readings
            and not message.unfiltered_setpoints
            and not message.rewards
        ):
            return  # This might be the getter for the Brain model -- ignore.

        try:
            agent_record_id = self._get_agent_id(
                experiment_run_instance_id=message.experiment_run_instance_id,
                experiment_run_phase=message.experiment_run_phase,
                agent_id=message.sender_rollout_worker_id,
                mode=message.mode.name.lower(),
            )
        except sqlalchemy.exc.MultipleResultsFound:
            LOG.exception(
                "StoreReceiver(id=0x%x, uid=%s) "
                "encountered an %s, "
                "but there are duplicate entries for this agent/run phase. "
                "This agent's inputs will be ignored and not stored, because "
                "I do not know to which agent it belongs."
                "Expect more errors from the store ahead.",
                id(self),
                self.uid,
                repr(message),
            )
            raise
        except sqlalchemy.orm.exc.NoResultFound:
            LOG.exception(
                "StoreReceiver(id=0x%x, uid=%s) "
                "encountered an %s, "
                "but there is no record of this agent in the store. "
                "I will not record this agent's inputs as I do not know to "
                "which agent it might belong. "
                "Expect more errors from the store ahead.",
                id(self),
                self.uid,
                repr(message),
            )
            raise

        # Make sure the user only sees the environment's name, not the worker
        # as we log the rollout worker's internal UID anyways here, so we can
        # distinguish individual workers:
        simtimes = message.simtimes
        try:
            simtimes = {
                (
                    StoreReceiver._SIMTIMES_ENVKEY_RE.search(  # type: ignore[union-attr]
                        env_worker_id
                    ).group(1)
                ): simtime.__getstate__()
                for env_worker_id, simtime in message.simtimes.items()
            }
        except AttributeError as e:
            # E.g., a key the pattern does not match; keep the raw value.
            LOG.warning(
                "Could not convert simtimes (%s): %s. Dumping as-is.",
                message.simtimes,
                e,
            )

        muscle_action_record = dbm.MuscleAction(
            agent_id=agent_record_id,
            rollout_worker_uid=message.sender_rollout_worker_id,
            walltime=message.walltime,
            simtimes=simtimes,
            # filtered sensor readings (right before muscle's propose_actions)
            sensor_readings=message.sensor_readings,
            # unfiltered setpoints (right after muscle's propose_actions)
            actuator_setpoints=message.unfiltered_setpoints,
            rewards=message.rewards,
            objective=message.objective,
            done=message.done,
            mode=message.mode,
            episode=message.episode,
            statistics=message.statistics,
        )
        self._dbh.add(muscle_action_record)
        await self._maybe_commit(force=message.done)

    @staticmethod
    def _without_phase(cache, experiment_run_instance_id, experiment_run_phase):
        """Return a copy of ``cache`` without entries of the given phase.

        Keys are tuples whose first two items are the experiment run
        instance ID and the phase number. Only entries matching *both* are
        dropped; other phases and other instances are kept.
        """
        finished = (experiment_run_instance_id, experiment_run_phase)
        return {k: v for k, v in cache.items() if (k[0], k[1]) != finished}

    async def _invalidate_cache(
        self, message: proto.SimulationControllerTerminationResponse
    ):
        """Cleans the local cache after a experiment run phase has ended."""
        await self._maybe_commit(force=True)
        if message.flow_control.value < SimulationFlowControl.STOP_PHASE.value:
            return  # Don't clean on restarts!

        instance_id = message.experiment_run_instance_id
        phase = message.experiment_run_phase
        self._environment_ticks = StoreReceiver._without_phase(
            self._environment_ticks, instance_id, phase
        )
        self._known_environments = StoreReceiver._without_phase(
            self._known_environments, instance_id, phase
        )
        self._known_agents = StoreReceiver._without_phase(
            self._known_agents, instance_id, phase
        )

    @staticmethod
    def _read(msg):
        """Unpacks a message, filters ignores

        Parameters
        ----------
        msg : list
            The message frames; consumed (popped) in place. The second
            frame must be the empty delimiter ``b""``.

        Returns
        -------
        tuple
            ``(msg_type, uid, msg_obj)`` with ``msg_type`` being one of
            ``"ignore"``, ``"response"``, ``"request"``, or ``"error"``.
        """
        _ = msg.pop(0)
        empty = msg.pop(0)
        assert empty == b""
        _ = msg.pop(0)

        if len(msg) > 3:
            sender = msg.pop(0)
            empty = msg.pop(0)
            header = msg.pop(0)
            LOG.debug(
                "Ignored message parts: %s, %s, %s", sender, empty, header
            )

        # Guard against messages that carry nothing but the envelope.
        if msg and msg[0] in (MDP.W_HEARTBEAT, MDP.W_READY, MDP.W_DESTROY):
            return "ignore", None, None

        if len(msg) == 1:  # it is a response
            uid = ""
            msg_obj = StoreReceiver._deserialize(msg.pop(0))
            msg_type = "response"
        elif len(msg) == 2:
            uid = StoreReceiver._deserialize(msg.pop(0))
            msg_obj = StoreReceiver._deserialize(msg.pop(0))
            msg_type = "request"
        else:
            uid = ""
            msg_obj = None
            msg_type = "error"
        return msg_type, uid, msg_obj

    @staticmethod
    def _deserialize(msg):
        """Deserialize a frame, falling back to str-decoding or the raw value."""
        try:
            return deserialize([msg])
        except Exception as e:
            LOG.debug(
                "StoreReceiver received a message '%s', "
                "which could not be decompressed: %s",
                msg,
                e,
            )
        try:
            return str(msg.decode())
        except AttributeError:
            LOG.debug(
                "StoreReceiver received a message '%s', "
                "which could not be str-decoded. ",
                msg,
            )
        return msg

    def __str__(self):
        return "StoreReceiver(id=0x%x, uid=%s, uri=%s)" % (
            id(self),
            self.uid,
            self._store_uri,
        )