"""This module contains the class :class:`ExperimentRun` that defines
an experiment run and contains all the information needed to execute
it.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, List, IO, Union, Optional, Any
import collections.abc
import importlib.resources
import io
import logging
import pprint
import uuid
from io import StringIO
from os import PathLike
from pathlib import Path
import copy
from importlib.metadata import (
version as importlib_version,
) # had to be renamed because else it would clash with the ExperimentRun class version
import ruamel.yaml as yml
import sqlalchemy
from numpy.random import RandomState
from ruamel.yaml.constructor import ConstructorError
from ..agent import AgentConductor
from ..environment import EnvironmentConductor
from ..types.mode import Mode
from ..util import seeding
from ..util.dynaloader import load_with_params
from ..util.exception import UnknownModeError, EnvironmentHasNoUIDError
from ..util.syntax_validation import (
SyntaxValidationResult,
SyntaxValidationError,
)
from palaestrai.store import database_model as dbm
from palaestrai.store.session import Session
from sqlalchemy import select
if TYPE_CHECKING:
from palaestrai.simulation import SimulationController
from palaestrai.experiment import TerminationCondition
import sqlalchemy.orm
LOG = logging.getLogger(__name__)


class RunDefinitionError(RuntimeError):
    """Raised when an experiment run definition is inconsistent.

    Parameters
    ----------
    run : ExperimentRun
        The run whose definition is faulty; kept as ``self.run`` and shown
        in ``str(error)``.
    message : str
        Human-readable description of the problem; kept as ``self.message``.
    """

    def __init__(self, run: ExperimentRun, message):
        super().__init__(message)
        self.message = message
        self.run = run

    def __str__(self):
        # Render as "<message> (<run>)" so logs show which run is affected.
        return "%s (%s)" % (self.message, self.run)
class ExperimentRun:
    """Defines an experiment run and stores information.

    The experiment run class defines a run in palaestrAI. It contains
    all information needed to execute the run. With the setup function
    the experiment run can be built.

    Parameters
    ----------
    uid : Union[str, None]
        Identifier of the run. If ``None``, a UID of the form
        ``ExperimentRun-<uuid4>`` is generated and a warning is logged.
    seed : Union[int, None]
        Seed of the run's random number generator. If ``None``, one is
        created via ``seeding.create_seed(max_bytes=4)``.
    version : Union[str, None]
        palaestrAI version the run file was written for. A warning is
        logged if it is missing or differs from the installed version.
    schedule : List[Dict]
        Raw schedule: a list of phase definitions, each expected to be a
        dict with exactly one key (the phase name).
    run_config : dict
        Run-wide configuration; ``run_config["condition"]`` describes the
        RunGovernor's termination condition (``name`` and ``params``).
    """

    SCHEMA_FILE = "run_schema.yaml"

    def __init__(
        self,
        uid: Union[str, None],
        seed: Union[int, None],
        version: Union[str, None],
        schedule: List[Dict],
        run_config: dict,
    ):
        if seed is None:
            # numpy expects a seed between 0 and 2**32 - 1
            self.seed: int = seeding.create_seed(max_bytes=4)
        else:
            self.seed = seed
        # Run-level RNG; all sub-seeds are drawn from it (create_subseed).
        self.rng: RandomState = seeding.np_random(self.seed)[0]

        if uid is None:
            self.uid = f"ExperimentRun-{uuid.uuid4()}"
            LOG.warning(
                "Experiment run has no uid, please set one to "
                "identify it (assign the 'uid' key). Generated: "
                "'%s', so that you can find it in the store.",
                self.uid,
            )
        else:
            self.uid = uid

        palaestrai_version = importlib_version("palaestrai")
        if version is None:
            self.version = palaestrai_version
            LOG.warning(
                "No version has been specified. There is no guarantee "
                "that this run will be executed without errors. Please "
                "set the version (assign the 'version' key) in the run "
                "file. Current palaestrAI version is '%s'.",
                self.version,
            )
        elif version != palaestrai_version:
            self.version = version
            LOG.warning(
                "Your palaestrAI installation has version %s but your "
                "run file uses version %s, which may be incompatible.",
                palaestrai_version,
                version,
            )
        else:
            self.version = version

        # The local ``yaml`` object is discarded immediately; these calls
        # are kept for their registration side effect.
        # NOTE(review): in ruamel.yaml add_representer/add_constructor are
        # classmethods, so this presumably registers the handlers for all
        # "safe" YAML instances — verify before removing.
        yaml = yml.YAML(typ="safe")
        yaml.representer.add_representer(
            RandomState, ExperimentRun.repr_randomstate
        )
        yaml.constructor.add_constructor(
            "rng", ExperimentRun.constr_randomstate
        )

        self.schedule_config = schedule
        self.run_config = run_config
        # Declared here, assigned by setup():
        self.run_governor_termination_condition: TerminationCondition
        self.schedule: List
        self._instance_uid = str(uuid.uuid4())
        # Lazily computed by the ``canonical_config`` property.
        self._canonical_config: Optional[Dict] = None

    @property
    def instance_uid(self):
        """The unique ID of this particular experiment run instance

        As an ::`ExperimentRun` object is transferred via network, stored in
        the DB, etc., it still remains the same instance, but it becomes
        different objects in memory. This UID identifies it even if it travels
        over the network.

        Returns
        -------
        str
            The instance's unique ID
        """
        return self._instance_uid

    @property
    def canonical_config(self):
        """The fully expanded run configuration (computed once, then cached).

        Returns
        -------
        Dict
            Keys ``uid``, ``seed``, ``schedule`` (cascaded phases) and
            ``run_config``; see ``_expand_config``.
        """
        if self._canonical_config is None:
            self._canonical_config = self._expand_config()
        return self._canonical_config

    def create_subseed(self) -> int:
        """uses the seeded random number generator to create reproducible sub-seeds"""
        # number 5000 is arbitrary, for the numpy RandomState, could be any integer between 0 and 2**32 - 1
        return self.rng.randint(0, 5000)

    @staticmethod
    def repr_randomstate(representer, data):
        """YAML representer: dump a RandomState as an ``rng``-tagged scalar.

        The scalar is ``str(data)``.

        NOTE(review): for numpy's RandomState, ``str()`` looks like
        ``RandomState(MT19937)``, which ``constr_randomstate`` cannot parse
        back — confirm whether a round trip is actually required.
        """
        serialized_data = str(data)
        return representer.represent_scalar("rng", serialized_data)

    @staticmethod
    def constr_randomstate(constructor, node):
        """YAML constructor for ``rng``-tagged scalars.

        Splits the scalar on single spaces and returns a lazy ``map``
        yielding one ``RandomState`` seeded with each integer.
        """
        # Use the constructor that ruamel.yaml hands us instead of
        # instantiating a fresh one for every node.
        value = constructor.construct_scalar(node)
        seeds = map(int, value.split(" "))
        return map(RandomState, seeds)

    def _expand_config(self) -> Dict:
        """Expanding the experiment run phases.

        Experiment run definition implements a cascading hierarchy: every
        phase starts from the merged settings of all previous phases and
        overrides them with its own (via ``update_dict``). Empty phase
        definitions are skipped with a warning.
        """
        canonical_config: Dict[Any, Any] = {"uid": self.uid, "seed": self.seed}
        expanded_schedule = []
        cascaded: Dict[Any, Any] = {}
        for num, phase in enumerate(self.schedule_config):
            if not phase:
                # Previously crashed with IndexError on ``list(...)[0]``.
                LOG.warning(
                    "ExperimentRun(id=0x%x, uid=%s) found empty phase: "
                    "%d, skipping this one.",
                    id(self),
                    self.uid,
                    num,
                )
                continue
            phase_name = next(iter(phase))
            # Deep copy so later phases' updates do not leak backwards.
            expanded_schedule.append(
                {
                    phase_name: copy.deepcopy(
                        update_dict(cascaded, phase[phase_name])
                    )
                }
            )
        canonical_config["schedule"] = expanded_schedule
        canonical_config["run_config"] = self.run_config
        return canonical_config

    def _setup_termination_condition(self):
        """Set up the RunGovernor's termination condition.

        Loads ``run_config["condition"]["name"]`` with its ``params`` via
        ``load_with_params``; failures are logged and re-raised.
        """
        rgtc = self.run_config["condition"]
        LOG.debug(
            "ExperimentRun(id=0x%x, uid=%s) loading RunGovernor "
            "TerminationCondition: %s.",
            id(self),
            self.uid,
            rgtc["name"],
        )
        try:
            condition = load_with_params(rgtc["name"], rgtc["params"])
        except Exception as err:
            LOG.critical(
                "Could not load termination condition '%s' with params "
                "%s for RunGovernor: %s",
                rgtc["name"],
                rgtc["params"],
                err,
            )
            raise
        self.run_governor_termination_condition = condition

    def _setup_schedule(self, broker_uri: str):
        """Initialize the run time schedule.

        Builds one ``self.schedule`` entry per phase of the canonical
        schedule, holding a copy of the phase's ``phase_config`` plus its
        environment conductors, agent conductors and simulation controllers.

        Raises
        ------
        RunDefinitionError
            If a raw phase definition has more than one key, or a phase ends
            up without environments, agents or simulation controllers.
        """
        # Validate the *raw* definitions: the canonical schedule keeps only
        # the first key of each phase, so a multi-key phase would otherwise
        # be silently truncated.
        for raw_phase in self.schedule_config:
            if len(raw_phase) > 1:
                raise RunDefinitionError(
                    self,
                    (
                        "Only one phase per phase allowed but "
                        f"found {len(raw_phase)} phases."
                    ),
                )
        self.schedule = []
        # Empty phases were dropped by _expand_config, so ``num`` always
        # equals the position of the new entry in ``self.schedule``.
        for num, phase in enumerate(self.canonical_config["schedule"]):
            phase_name = next(iter(phase))
            config = phase[phase_name]
            agent_configs: Dict[Any, Any] = {}
            self.schedule.append(
                {"phase_config": config["phase_config"].copy()}
            )
            self._setup_environment_conductor(
                num, phase_name, config, broker_uri
            )
            self._setup_agent_conductor(num, phase_name, config, agent_configs)
            self._setup_simulation_controller(
                num, phase_name, config, agent_configs, broker_uri
            )

    def _setup_environment_conductor(
        self, phase_num: int, phase_name: str, config, broker_uri
    ):
        """Initialize an :class:`~EnvironmentConductor` for current phase.

        Raises
        ------
        EnvironmentHasNoUIDError
            If an environment has no (or an empty) ``uid``.
        RunDefinitionError
            If the phase defines no environments.
        """
        # Create the mapping up front so an empty environment list reaches
        # the RunDefinitionError below instead of a KeyError.
        conductors = self.schedule[phase_num].setdefault(
            "environment_conductors", dict()
        )
        for env_config in config["environments"]:
            env_uid = env_config["environment"].get("uid", None)
            if env_uid is None or env_uid == "":
                LOG.critical(
                    "ExperimentRun(id=0x%x, uid=%s): One of your "
                    "environments has no UID configured. Please "
                    "provide UIDs for all of your environments. "
                    "PalaestrAI, over and out!",
                    id(self),
                    self.uid,
                )
                raise EnvironmentHasNoUIDError()
            ec = EnvironmentConductor(
                env_config,
                self.create_subseed(),
            )
            conductors[ec.uid] = ec
        LOG.debug(
            "ExperimentRun(id=0x%x, uid=%s) set up %d "
            "EnvironmentConductor object(s) for phase %d: '%s'",
            id(self),
            self.uid,
            len(conductors),
            phase_num,
            phase_name,
        )
        if len(conductors) == 0:
            raise RunDefinitionError(
                self, f"No environments defined for phase {phase_num}."
            )

    def _setup_agent_conductor(
        self,
        phase_num: int,
        phase_name: str,
        config,
        agent_configs: Dict[Any, Any],
    ):
        """Initialize an :class:`~AgentConductor` for current phase.

        Fills ``agent_configs`` (conductor uid -> shallow copy of the agent
        definition) for use by the simulation controllers.

        Raises
        ------
        RunDefinitionError
            If the phase defines no agents or agent names are not unique.
        """
        conductors = self.schedule[phase_num].setdefault(
            "agent_conductors", dict()
        )
        for agent_config in config["agents"]:
            ac_conf = dict(agent_config)
            ac = AgentConductor(
                agent_config=ac_conf,
                seed=self.create_subseed(),
                uid=agent_config["name"],
            )
            conductors[ac.uid] = ac
            agent_configs[ac.uid] = ac_conf
        num_agent_definitions = len(config["agents"])
        num_agent_conductors = len(conductors)
        LOG.debug(
            "ExperimentRun(id=0x%x, uid=%s) set up %d AgentConductor "
            "object(s) for phase %d: '%s'.",
            id(self),
            self.uid,
            num_agent_conductors,
            phase_num,
            phase_name,
        )
        if num_agent_conductors == 0:
            raise RunDefinitionError(
                self, f"No agents defined for phase {phase_num}."
            )
        # Duplicate names collapse onto the same dict key.
        if num_agent_conductors != num_agent_definitions:
            raise RunDefinitionError(
                self,
                f"Your experiment run configuration for phase {phase_num} "
                f"contains ambiguities: "
                f"{num_agent_definitions} agent definitions spawned "
                f"{num_agent_conductors} unique agents. "
                f"Please check that all agent names are unique.",
            )

    def _setup_simulation_controller(
        self,
        phase_num: int,
        phase_name: str,
        config,
        agent_configs: Dict[Any, Any],
        broker_uri,
    ):
        """Initialize a :class:`~SimulationController` for current phase.

        Creates ``phase_config["worker"]`` (default 1) controllers.

        Raises
        ------
        UnknownModeError
            If ``phase_config["mode"]`` is not a :class:`Mode` member.
        RunDefinitionError
            If no simulation controller was created.
        """
        controllers = self.schedule[phase_num].setdefault(
            "simulation_controllers", dict()
        )
        for _ in range(int(config["phase_config"].get("worker", 1))):
            try:
                mode = Mode[
                    config["phase_config"].get("mode", "train").upper()
                ]
            except KeyError as err:
                raise UnknownModeError(err) from err
            if not config["simulation"]["name"].endswith(
                "SimulationController"
            ):
                config["simulation"]["name"] += "SimulationController"
            sc: SimulationController = load_with_params(
                config["simulation"]["name"],
                {
                    "sim_connection": broker_uri,
                    "rungov_connection": broker_uri,
                    "agent_conductor_ids": list(
                        self.schedule[phase_num]["agent_conductors"].keys()
                    ),
                    "environment_conductor_ids": list(
                        self.schedule[phase_num][
                            "environment_conductors"
                        ].keys()
                    ),
                    "termination_conditions": config["simulation"][
                        "conditions"
                    ],
                    "agents": agent_configs,
                    "mode": mode,
                },
            )
            controllers[sc.uid] = sc
        LOG.debug(
            "ExperimentRun(id=0x%x, uid=%s) set up %d "
            "SimulationController object(s) for phase %d: '%s'.",
            id(self),
            self.uid,
            len(controllers),
            phase_num,
            phase_name,
        )
        if len(controllers) == 0:
            # NOTE(review): the key read above is 'worker'; the message
            # says 'workers'.
            raise RunDefinitionError(
                self,
                "No simulation controller defined. Either "
                "'workers' < 1 or 'name' of key 'simulation' is "
                "not available.",
            )

    def setup(self, broker_uri):
        """Set up an experiment run.

        Creates and configures relevant actors: the RunGovernor's
        termination condition and, per phase, the environment conductors,
        agent conductors and simulation controllers.
        """
        LOG.debug("ExperimentRun(id=0x%x, uid=%s) setup.", id(self), self.uid)
        self._setup_termination_condition()
        self._setup_schedule(broker_uri)
        LOG.info(
            "ExperimentRun(id=0x%x, uid=%s) setup complete.",
            id(self),
            self.uid,
        )

    def environment_conductors(
        self, phase=0
    ) -> Dict[str, EnvironmentConductor]:
        """Environment conductors of ``phase``, keyed by uid."""
        return self.schedule[phase]["environment_conductors"]

    def agent_conductors(self, phase=0):
        """Agent conductors of ``phase``, keyed by uid."""
        return self.schedule[phase]["agent_conductors"]

    def simulation_controllers(self, phase=0):
        """Simulation controllers of ``phase``, keyed by uid."""
        return self.schedule[phase]["simulation_controllers"]

    def get_phase_name(self, phase: int):
        """Name of the ``phase``-th phase.

        Read from the canonical schedule, whose indices match
        ``self.schedule`` (empty phases are skipped in both).
        """
        return next(iter(self.canonical_config["schedule"][phase]))

    def get_episodes(self, phase: int):
        """Number of episodes of ``phase`` (``phase_config["episodes"]``, default 1)."""
        return self.schedule[phase]["phase_config"].get("episodes", 1)

    def phase_configuration(self, phase: int):
        """The ``phase_config`` dict of ``phase``."""
        return self.schedule[phase]["phase_config"]

    @property
    def num_phases(self):
        """The number of phases in this experiment run's schedule."""
        return len(self.schedule)

    def has_next_phase(self, current_phase):
        """Return if this run has a subsequent phase.

        Parameters
        ----------
        current_phase: int
            Index of the phase that is being executed.

        Returns
        -------
        bool
            True if at least one phase is taking place after
            the current phase.
        """
        return current_phase + 1 < self.num_phases

    @staticmethod
    def check_syntax(
        path_or_stream: Union[str, IO[str], PathLike]
    ) -> SyntaxValidationResult:
        """Checks if the provided experiment configuration conforms
        with our syntax.

        Parameters
        ----------
        path_or_stream:
            1. str - Path to an experiment configuration file
            2. Path - Same as above
            3. Any text stream

        Returns
        -------
        SyntaxValidationResult:
            Custom object that contains the following information:

            1. SyntaxValidationResult.is_valid: Whether the provided experiment
               is valid or not (::`bool`).
            2. SyntaxValidationResult.error_message: Contains ::`None` if the
               experiment is valid or the corresponding error message
               if it is invalid.
        """
        with importlib.resources.path(
            __package__, ExperimentRun.SCHEMA_FILE
        ) as path:
            validation_result = SyntaxValidationResult.validate_syntax(
                path_or_stream, path
            )
        return validation_result

    @staticmethod
    def load(str_path_stream_or_dict: Union[str, Path, Dict, IO[str]]):
        """Load an ::`ExperimentRun` object from a serialized representation.

        This method serves as deserializing constructor. It takes a
        path to a file, a dictionary representation, or a stream and creates
        a new ::`ExperimentRun` object from it.

        This method also validates the string/stream representation.

        Parameters
        ----------
        str_path_stream_or_dict : Union[str, Path, Dict, IO[str]]
            * If `str`, it is interpreted as a file path, and the file is
              resolved and loaded;
            * if `Path`, the same happens as above;
            * if `Dict`, it is dumped to YAML and then schema-checked and
              loaded like a stream;
            * if `TextIO`, the method assumes that it is a serialized
              representation of the ::`ExperimentRun` object (e.g., from an
              open file stream) and interprets it as YAML (with a prior
              syntax/schema check).

            The stream used for loading is closed afterwards (also when
            validation or parsing fails).

        Returns
        -------
        ExperimentRun
            An initialized, de-serialized ::`ExperimentRun` object

        Raises
        ------
        OSError
            If a given path cannot be opened.
        SyntaxValidationError
            If the configuration does not pass the schema check.
        ConstructorError
            If the YAML cannot be constructed.
        """
        LOG.debug("Loading configuration from %s.", str_path_stream_or_dict)
        stream = str_path_stream_or_dict
        # If we get a dict directly, we syntax check nevertheless.
        if isinstance(stream, dict):
            sio = StringIO()
            yml.YAML(typ="safe", pure=True).dump(stream, sio)
            # After dump() the position is at the end of the buffer; rewind
            # so the schema check gets the full document.
            sio.seek(0)
            stream = sio
        if isinstance(stream, (str, Path)):
            try:
                stream = open(stream, "r")
            except OSError as err:
                LOG.error("Could not open run configuration: %s.", err)
                raise err
        try:
            # Load from YAML + schema check:
            validation_result = ExperimentRun.check_syntax(stream)
            if not validation_result:
                LOG.error(
                    "ExperimentRun definition did not schema validate: %s",
                    validation_result.error_message,
                )
                raise SyntaxValidationError(validation_result)
            stream.seek(0)
            try:
                conf = yml.YAML(typ="safe", pure=True).load(stream)
            except ConstructorError as err:
                LOG.error("Could not load run configuration: %s.", err)
                raise err
        finally:
            # Close on every path; previously a failed schema check leaked
            # the file opened above.
            stream.close()
        LOG.debug("Loaded configuration: %s.", conf)
        return ExperimentRun(
            uid=conf.get("uid", conf.get("id", None)),
            seed=conf.get("seed", None),
            version=conf.get("version", None),
            schedule=conf["schedule"],
            run_config=conf["run_config"],
        )

    def save(
        self,
        experiment_uid: Optional[str] = None,
        session: Optional[sqlalchemy.orm.Session] = None,
    ):
        """Save an ::`ExperimentRun` object to the store.

        Looks up (or creates) the ``Experiment`` record named
        ``experiment_uid``. If no ``ExperimentRun`` record with this run's
        ``uid`` exists yet, one is created (with this object as its
        ``document``) and attached to that experiment. Then the session is
        committed. A new session's connection credentials are taken from the
        runtime config.

        Parameters
        ----------
        experiment_uid : Optional[str]
            Name of the experiment record to associate this run with. If
            ``None``, a dummy experiment named
            ``"Dummy Experiment record for ExperimentRun <uid>"`` is used.
        session : Optional[sqlalchemy.orm.Session]
            An open session to reuse (it is committed, but left open). If
            ``None``, a new ``Session()`` is opened and closed again before
            this method returns.
        """
        owns_session = session is None
        _session = Session() if owns_session else session
        if experiment_uid is None:
            experiment_uid = (
                "Dummy Experiment record " "for ExperimentRun %s" % self.uid
            )
        try:
            query = select(dbm.Experiment).where(
                dbm.Experiment.name == experiment_uid
            )
            experiment_hack_record = _session.execute(query).scalars().first()
            if not experiment_hack_record:
                experiment_hack_record = dbm.Experiment(name=experiment_uid)

            # Kept for the registration side effect (see __init__).
            yaml = yml.YAML(typ="safe")
            yaml.register_class(ExperimentRun)
            yaml.representer.add_representer(
                RandomState, ExperimentRun.repr_randomstate
            )
            # Was wrongly registering ``repr_randomstate`` as constructor.
            yaml.constructor.add_constructor(
                "rng", ExperimentRun.constr_randomstate
            )

            _session.add(experiment_hack_record)
            query = select(dbm.ExperimentRun).where(
                dbm.ExperimentRun.uid == self.uid
            )
            result = _session.execute(query).scalars().all()
            if len(result) > 1:
                LOG.warning(
                    "Found %d entries for ExperimentRun(uid=%s) "
                    "when there should be at most one. I'm going to use the first "
                    "one, but if strange things happen, don't blame it on me.",
                    len(result),
                    self.uid,
                )
            if not result:
                experiment_hack_record.experiment_runs.append(
                    dbm.ExperimentRun(uid=self.uid, document=self)
                )

            # Commit BEFORE closing: Session.close() discards pending
            # objects, which previously made save() a no-op without a
            # caller-supplied session.
            try:
                _session.commit()
            except sqlalchemy.exc.IntegrityError as e:
                _session.rollback()
                LOG.exception(
                    "%s was not possible to run properly: It "
                    "was not possible to save the runs of Experiment (uid=%s)! Perhaps "
                    "your environment does not provide enough entropy, or we have "
                    "a resend. I'm going to ignore this error and continue as "
                    "best as I can. (%s)",
                    self.uid,
                    experiment_uid,
                    e,
                )
        finally:
            if owns_session:
                _session.close()
def update_dict(src, upd):
    """Recursive update of dictionaries.

    Merges ``upd`` into ``src`` in place and returns ``src``. Mapping
    values are merged recursively. Any other value replaces the existing
    entry. If ``src`` holds a non-mapping value (scalar, ``None``, list) at
    a key where ``upd`` has a mapping, the mapping replaces it. Previously
    this case raised ``TypeError``.

    See stackoverflow:

        https://stackoverflow.com/questions/3232943/
        update-value-of-a-nested-dictionary-of-varying-depth

    Parameters
    ----------
    src : MutableMapping
        Target mapping; modified in place.
    upd : Mapping
        Values to merge into ``src``.

    Returns
    -------
    MutableMapping
        ``src`` itself, after the update.
    """
    for key, val in upd.items():
        if isinstance(val, collections.abc.Mapping):
            existing = src.get(key)
            # Only recurse into an existing *mapping*; anything else is
            # overridden by a fresh dict built from ``val``.
            base = (
                existing
                if isinstance(existing, collections.abc.Mapping)
                else {}
            )
            src[key] = update_dict(base, val)
        else:
            src[key] = val
    return src
def __repr__(self):
    """Return ``self.canonical_config`` pretty-printed (``pprint``, indent 4).

    NOTE(review): as placed here (after ``update_dict``, outside the
    ``ExperimentRun`` class body), this is a plain module-level function and
    is not used as ``ExperimentRun.__repr__``. It looks like it was meant to
    be a method of ``ExperimentRun`` — confirm and move it if so.
    """
    expanded = self.canonical_config
    rendered = pprint.pformat(expanded, indent=4)
    return rendered