from __future__ import annotations
import io
from typing import TYPE_CHECKING
import ruamel.yaml
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql
from sqlalchemy import event, func, Enum, Index, UniqueConstraint
from sqlalchemy.engine import Engine
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship
from sqlite3 import Connection as SQLite3Connection
from palaestrai.store.database_base import Base
from palaestrai.types import Mode
from palaestrai.types import ExperimentRunInstanceStatus
if TYPE_CHECKING:
import palaestrai.experiment
# Module-wide YAML handle of the "safe" type. The ``ExperimentRun.document``
# setter in this module uses it to dump a run's state dictionary to text.
yaml = ruamel.yaml.YAML(typ="safe")
@event.listens_for(Engine, "connect")
def _set_sqlite_pragma(dbapi_connection, connection_record):
    """Apply the store's PRAGMA settings to every new SQLite connection.

    Registered as a listener for the ``connect`` event of
    :class:`sqlalchemy.engine.Engine`. Connections that are not ``sqlite3``
    connections are left untouched.

    :param dbapi_connection: The raw DBAPI connection that was just opened.
    :param connection_record: The pool's record for the connection; not used.
    """
    if not isinstance(dbapi_connection, SQLite3Connection):
        return

    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON;")
        # WAL lets a reader (e.g. the REST API) and a writer (e.g. the log
        # handler) work on the same database file concurrently; busy_timeout
        # makes a connection wait instead of immediately raising
        # "database is locked"; synchronous=NORMAL is the recommended,
        # crash-safe pairing with WAL. journal_mode=WAL is a harmless no-op on
        # a pure ``:memory:`` connection, so this is safe for tests.
        cursor.execute("PRAGMA journal_mode=WAL;")
        cursor.execute("PRAGMA busy_timeout=5000;")
        cursor.execute("PRAGMA synchronous=NORMAL;")
    finally:
        # Release the cursor even if one of the PRAGMA statements fails.
        cursor.close()
[docs]
class Experiment(Base):
    """Top-level study definition, including its Design of Experiments

    An experiment is the master object of the palaestrAI store and defines a
    study, including whatever parameter variations the user wishes. Any
    number of concrete :class:`ExperimentRun` objects can be spawned from a
    single experiment.
    """

    __tablename__ = "experiments"

    id = sa.Column(sa.INTEGER, primary_key=True, unique=True, index=True)
    # The experiment name (the experiment document's ``uid``) must be unique:
    # it is the natural key used by the REST API to address an experiment
    # (e.g., ``GET /experiments/{name}``). The pre-existing Alembic migration
    # ``728eceef1424_add_experiment_name`` already created this column as a
    # unique, indexed ``VARCHAR(255)``; this declaration brings the ORM in
    # line with it (and tightens it to ``NOT NULL``).
    name = sa.Column(sa.String(255), nullable=False, unique=True, index=True)
    # The document is kept twice: as plain text and as JSON (JSONB on
    # PostgreSQL). Both columns are written together by the ``document``
    # setter below.
    _document = sa.Column("document", sa.TEXT)
    _document_json = sa.Column(
        "document_json",
        sa.JSON().with_variant(
            sqlalchemy.dialects.postgresql.JSONB(), "postgresql"
        ),
    )

    # All runs spawned from this experiment.
    experiment_runs = relationship(
        "ExperimentRun",
        back_populates="experiment",
        cascade="all, delete",
        passive_deletes=True,
    )

    @hybrid_property
    def document(self):
        """The experiment document, read from the JSON column."""
        return self._document_json

    @document.setter  # type:ignore[no-redef]
    def document(self, experiment):
        """Store ``experiment`` as JSON and its ``repr()`` as plain text."""
        self._document = repr(experiment)
        self._document_json = experiment

    def __str__(self):
        return f'<Experiment(id={self.id}, name="{self.name}")>'
[docs]
class ExperimentRun(Base):
    """A concrete experiment run created from an experiment

    An experiment run is a concrete instance of an experiment. In it, any
    parameter variation is replaced by actual parameter settings. An
    experiment can spawn as many experiment runs as the user wishes. I.e., an
    experiment run is a concrete configuration.
    """

    __tablename__ = "experiment_runs"

    id = sa.Column(sa.Integer, primary_key=True, unique=True, index=True)
    uid = sa.Column(sa.String(255), index=True)
    hash = sa.Column(sa.String(255), index=True)
    experiment_id = sa.Column(
        sa.Integer,
        sa.ForeignKey("experiments.id", ondelete="CASCADE"),
        index=True,
    )
    # The document is kept twice: as YAML text and as JSON (JSONB on
    # PostgreSQL). Both columns are written together by the ``document``
    # setter below.
    _document = sa.Column("document", sa.TEXT)
    _document_json = sa.Column(
        "document_json",
        sa.JSON().with_variant(
            sqlalchemy.dialects.postgresql.JSONB(), "postgresql"
        ),
    )

    experiment = relationship("Experiment", back_populates="experiment_runs")
    experiment_run_instances = relationship(
        "ExperimentRunInstance",
        back_populates="experiment_run",
        cascade="all, delete",
        passive_deletes=True,
    )

    # A run's ``uid`` only has to be unique within its experiment.
    __table_args__ = (UniqueConstraint("experiment_id", "uid"),)

    @hybrid_property
    def document(self) -> palaestrai.experiment.ExperimentRun:
        """The stored content of the ``document_json`` column.

        NOTE(review): the annotation names ``ExperimentRun``, but the value
        returned is what the setter stored, i.e., the state dictionary.
        """
        return self._document_json

    @document.setter  # type:ignore[no-redef]
    def document(self, experiment_run: palaestrai.experiment.ExperimentRun):
        """Persist the state of ``experiment_run`` as JSON and as YAML text.

        :param experiment_run: The experiment run object whose
            ``__getstate__()`` dictionary is stored.
        """
        # Work on a shallow copy: removing keys must never alter the state
        # dictionary that belongs to the experiment run object itself.
        er_dict = dict(experiment_run.__getstate__())
        # These two entries are not persisted (presumably runtime-only
        # state -- TODO confirm). A missing key is not an error.
        er_dict.pop("_rng", None)
        er_dict.pop("_instance_uid", None)
        self._document_json = er_dict

        # NOTE: We don't need to register the class here anymore,
        #   since the ExperimentRun can now supply us with its dict state via
        #   __getstate__().
        sio = io.StringIO()
        yaml.dump(er_dict, sio)
        self._document = sio.getvalue()

    def __str__(self):
        return (
            '<ExperimentRun(id=%s, uid="%s", experiment_id=%s, '
            "document=%s)>"
            % (self.id, self.uid, self.experiment_id, self.document)
        )
[docs]
class ExperimentRunInstance(Base):
    """An execution of an experiment run

    Each experiment run can be executed as many times as a user wishes.
    This does not change its outcome, but for reproducibility, such re-runs
    are sensible. When an experiment run is actually executed - the experiment
    run being the blueprint of an actual execution -, an experiment run
    instance is created.
    """

    __tablename__ = "experiment_run_instances"

    id = sa.Column(sa.Integer, primary_key=True, unique=True, index=True)
    # Unlike ``ExperimentRun.uid``, the instance uid is declared unique on
    # its own.
    uid = sa.Column(sa.String(196), unique=True, index=True)
    # Defaults to the SQL ``now()`` function when no value is supplied.
    created_at = sa.Column(sa.DateTime, default=func.now())
    # Lifecycle status of this concrete execution. Owned/written by the
    # Executor while the ``palaestrai serve`` service runs the instance, and
    # read by the REST API (``GET /experiment_run_instances/{uid}``). A freshly
    # created instance starts out as ``SCHEDULED``. There is deliberately no
    # ``UNKNOWN`` value: that is the API's response when no row exists at all.
    status = sa.Column(
        Enum(
            ExperimentRunInstanceStatus,
            name="experiment_run_instance_status_enum",
            # Persist the members' values rather than their names.
            values_callable=lambda e: [m.value for m in e],
            # Coerce to a native ENUM in PostgreSQL and a VARCHAR in SQLite,
            # mirroring how ``MuscleAction.mode`` is handled.
            native_enum=True,
            create_type=True,
        ),
        nullable=False,
        # Both a Python-side and a server-side default are declared, each
        # resolving to ``SCHEDULED``.
        default=ExperimentRunInstanceStatus.SCHEDULED,
        server_default=ExperimentRunInstanceStatus.SCHEDULED.value,
    )
    experiment_run_id = sa.Column(
        sa.Integer,
        sa.ForeignKey(ExperimentRun.id, ondelete="CASCADE"),
        index=True,
    )

    # The run this instance executes.
    experiment_run = relationship(
        "ExperimentRun", back_populates="experiment_run_instances"
    )
    # The phases that belong to this instance.
    experiment_run_phases = relationship(
        "ExperimentRunPhase",
        back_populates="experiment_run_instance",
        cascade="all, delete",
        passive_deletes=True,
    )
[docs]
class ExperimentRunPhase(Base):
    """A phase of an experiment run instance

    Rows live in the ``experiment_run_phases`` table. Each phase belongs to
    one :class:`ExperimentRunInstance` and is the parent of the
    :class:`Environment` and :class:`Agent` rows that reference it.
    """

    __tablename__ = "experiment_run_phases"

    id = sa.Column(sa.INTEGER, primary_key=True, unique=True, index=True)
    uid = sa.Column(sa.String(255), index=True, nullable=False)
    number = sa.Column(sa.INTEGER, nullable=False)
    # Stored as a plain string here, not as an enum type.
    mode = sa.Column(sa.String(128), nullable=True)
    # JSON on most backends, JSONB on PostgreSQL.
    configuration = sa.Column(
        "configuration",
        sa.JSON().with_variant(
            sqlalchemy.dialects.postgresql.JSONB(), "postgresql"
        ),
        nullable=True,
    )
    experiment_run_instance_id = sa.Column(
        sa.Integer,
        sa.ForeignKey(ExperimentRunInstance.id, ondelete="CASCADE"),
        index=True,
    )

    experiment_run_instance = relationship(
        "ExperimentRunInstance", back_populates="experiment_run_phases"
    )
    environments = relationship(
        "Environment",
        back_populates="experiment_run_phase",
        cascade="all, delete",
        passive_deletes=True,
    )
    agents = relationship(
        "Agent",
        back_populates="experiment_run_phase",
        cascade="all, delete",
        passive_deletes=True,
    )

    # The combination of uid, number, and mode must be unique per experiment
    # run instance.
    __table_args__ = (
        sa.UniqueConstraint(
            "uid", "number", "mode", "experiment_run_instance_id"
        ),
    )
[docs]
class Environment(Base):
    """An environment that takes part in an experiment run phase

    Rows live in the ``environments`` table. Each environment belongs to one
    :class:`ExperimentRunPhase` and is the parent of the :class:`WorldState`
    rows that reference it.
    """

    __tablename__ = "environments"

    id = sa.Column(sa.Integer, primary_key=True, unique=True, index=True)
    uid = sa.Column(sa.String(255), nullable=False, index=True)
    worker_uid = sa.Column(sa.String(255), nullable=False, index=True)
    environment_conductor_uid = sa.Column(sa.String(255), nullable=False)
    type = sa.Column(sa.String(255), nullable=True)
    parameters = sa.Column("parameters", sa.JSON, nullable=True)
    static_model = sa.Column(
        "static_model",
        sa.JSON().with_variant(
            sqlalchemy.dialects.postgresql.JSONB(), "postgresql"
        ),
        nullable=True,
    )
    experiment_run_phase_id = sa.Column(
        sa.Integer,
        sa.ForeignKey(ExperimentRunPhase.id, ondelete="CASCADE"),
        index=True,
    )

    experiment_run_phase = relationship(
        "ExperimentRunPhase", back_populates="environments"
    )
    world_states = relationship(
        "WorldState",
        back_populates="environment",
        cascade="all, delete",
        passive_deletes=True,
    )

    # An environment uid/worker uid pair may occur only once per phase.
    __table_args__ = (
        sa.UniqueConstraint("uid", "worker_uid", "experiment_run_phase_id"),
    )

    def __str__(self):
        # ``parameters`` is a nullable column: report a size of 0 instead of
        # failing with ``TypeError`` on ``len(None)``.
        parameters_size = (
            0 if self.parameters is None else len(self.parameters)
        )
        return (
            f'<Environment(id={self.id}, uid="{self.uid}", type="'
            f'{self.type}", parameters=({parameters_size} chars))>'
        )
[docs]
class WorldState(Base):
    """A recorded state of an environment

    Rows live in the ``world_states`` table; each one references the
    :class:`Environment` it belongs to.
    """

    __tablename__ = "world_states"

    id = sa.Column(
        sa.Integer,
        autoincrement=True,
        primary_key=True,
        unique=True,
        index=True,
    )
    # Timezone-aware; defaults to the SQL ``now()`` function.
    walltime = sa.Column(
        sa.TIMESTAMP(timezone=True),
        default=func.now(),
        primary_key=False,
        nullable=False,
    )
    simtime_ticks = sa.Column(sa.Integer)
    simtime_timestamp = sa.Column(sa.TIMESTAMP)
    episode = sa.Column("episode", sa.Integer, default=1)
    state_dump = sa.Column(
        "state_dump",
        sa.JSON().with_variant(
            sqlalchemy.dialects.postgresql.JSONB(), "postgresql"
        ),
    )
    setpoints = sa.Column(
        "setpoints",
        sa.JSON().with_variant(
            sqlalchemy.dialects.postgresql.JSONB(), "postgresql"
        ),
        nullable=True,
    )
    done = sa.Column(
        sa.Boolean,
        unique=False,
        nullable=False,
        default=False,
    )
    environment_id = sa.Column(
        sa.Integer,
        sa.ForeignKey(Environment.id, ondelete="CASCADE"),
        index=True,
    )

    environment = relationship("Environment", back_populates="world_states")

    def __str__(self):
        # The last two fields are separated by a blank only, not a comma.
        fields = ", ".join(
            (
                f"id={self.id}",
                f"walltime={self.walltime}",
                f"simtime_ticks={self.simtime_ticks}",
                f"simtime_timestamp={self.simtime_timestamp} "
                f"done={self.done}",
            )
        )
        return f"<WorldState {fields}>"
[docs]
class Agent(Base):
    """An agent that takes part in an experiment run phase

    Rows live in the ``agents`` table. Each agent belongs to one
    :class:`ExperimentRunPhase` and is the parent of the :class:`BrainState`
    and :class:`MuscleAction` rows that reference it.
    """

    __tablename__ = "agents"

    id = sa.Column(sa.Integer, primary_key=True, unique=True, index=True)
    uid = sa.Column(sa.String(255), nullable=False, index=True)
    name = sa.Column(sa.String(255), nullable=True)
    muscles = sa.Column(
        "muscles",
        sa.JSON().with_variant(
            sqlalchemy.dialects.postgresql.JSONB(), "postgresql"
        ),
        nullable=False,
        # Pass the callable, not ``list()``: a single list object would
        # otherwise be shared as the default value of every row.
        default=list,
    )
    configuration = sa.Column(
        "configuration",
        sa.JSON().with_variant(
            sqlalchemy.dialects.postgresql.JSONB(), "postgresql"
        ),
        nullable=True,
    )
    experiment_run_phase_id = sa.Column(
        sa.Integer,
        sa.ForeignKey(ExperimentRunPhase.id, ondelete="CASCADE"),
        nullable=False,
    )

    experiment_run_phase = relationship(
        "ExperimentRunPhase", back_populates="agents"
    )
    brain_states = relationship(
        "BrainState",
        back_populates="agent",
        cascade="all, delete",
        passive_deletes=True,
    )
    # Muscle actions are loaded ordered by their primary key.
    muscle_actions = relationship(
        "MuscleAction",
        order_by="MuscleAction.id",
        back_populates="agent",
        cascade="all, delete",
        passive_deletes=True,
    )

    # An agent uid may occur only once per phase.
    __table_args__ = (sa.UniqueConstraint("uid", "experiment_run_phase_id"),)
[docs]
class BrainState(Base):
    """A stored state of an agent's brain

    Rows live in the ``brain_states`` table; each one references the
    :class:`Agent` it belongs to and carries the state as a binary blob.
    """

    __tablename__ = "brain_states"

    id = sa.Column(
        sa.Integer,
        autoincrement=True,
        primary_key=True,
        unique=True,
        index=True,
    )
    # Timezone-aware; defaults to the SQL ``now()`` function.
    walltime = sa.Column(
        sa.TIMESTAMP(timezone=True),
        default=func.now(),
        primary_key=False,
        nullable=False,
    )
    state = sa.Column(sa.LargeBinary, nullable=True)
    tag = sa.Column(sa.String(96), index=True)
    simtime_ticks = sa.Column(sa.Integer, nullable=True)
    simtime_timestamp = sa.Column(sa.TIMESTAMP, nullable=True)
    agent_id = sa.Column(
        sa.Integer, sa.ForeignKey(Agent.id, ondelete="CASCADE"), index=True
    )

    agent = relationship("Agent", back_populates="brain_states")

    # Composite index definition
    # NOTE(review): ``postgresql_ops`` is normally used for operator classes;
    # here it appears to be used to get ``id DESC`` into the index
    # definition -- confirm against the generated DDL.
    __table_args__ = (
        Index(
            "ix_brain_states_agent_tag_id",  # index name
            "agent_id",
            "tag",
            "id",
            postgresql_using="btree",
            postgresql_ops={"id": "DESC"},  # important: descending index on id
        ),
    )
[docs]
class MuscleAction(Base):
    """A recorded action of an agent

    Rows live in the ``muscle_actions`` table; each one references the
    :class:`Agent` it belongs to.
    """

    __tablename__ = "muscle_actions"

    id = sa.Column(
        sa.Integer,
        autoincrement=True,
        primary_key=True,
        unique=True,
        index=True,
    )
    # Timezone-aware; defaults to the SQL ``now()`` function.
    walltime = sa.Column(
        sa.TIMESTAMP(timezone=True),
        default=func.now(),
        primary_key=False,
        nullable=False,
    )
    agent_id = sa.Column(
        sa.Integer, sa.ForeignKey(Agent.id, ondelete="CASCADE"), index=True
    )
    episode = sa.Column("episode", sa.Integer, default=1)
    simtimes = sa.Column(
        "simtimes",
        sa.JSON().with_variant(
            sqlalchemy.dialects.postgresql.JSONB(), "postgresql"
        ),
        nullable=False,
        # Pass the callable, not ``list()``: a single list object would
        # otherwise be shared as the default value of every row.
        default=list,
    )
    sensor_readings = sa.Column(
        "sensor_readings",
        sa.JSON().with_variant(
            sqlalchemy.dialects.postgresql.JSONB(), "postgresql"
        ),
        nullable=True,
    )
    actuator_setpoints = sa.Column(
        "actuator_setpoints",
        sa.JSON().with_variant(
            sqlalchemy.dialects.postgresql.JSONB(), "postgresql"
        ),
        nullable=True,
    )
    rewards = sa.Column(
        "rewards",
        sa.JSON().with_variant(
            sqlalchemy.dialects.postgresql.JSONB(), "postgresql"
        ),
        nullable=True,
    )
    objective = sa.Column("objective", sa.Float, default=0.0)
    done = sa.Column("done", sa.Boolean, default=False)
    mode = sa.Column(
        Enum(
            Mode,
            name="mode_enum",
            # Persist the members' values rather than their names.
            values_callable=lambda e: [m.value for m in e],
            # By setting native_enum=True, sqlalchemy coerces the enum type
            # to an appropriate type in the backend, i.e., to an actual ENUM
            # type in postgres and a VARCHAR in sqlite3
            native_enum=True,
            create_type=True,
        )
    )
    statistics = sa.Column(
        "statistics",
        sa.JSON().with_variant(
            sqlalchemy.dialects.postgresql.JSONB(), "postgresql"
        ),
        nullable=True,
    )
    rollout_worker_uid = sa.Column(
        "rollout_worker_uid", sa.String(255), default=None
    )

    agent = relationship("Agent", back_populates="muscle_actions")

    def __str__(self):
        return (
            f"<MuscleAction(id={self.id}, "
            f"agent_id={self.agent_id}, "
            f"walltime={self.walltime}, "
            f"episode={self.episode}, "
            f"done={self.done}, "
            f"mode={self.mode}, "
            f"simtimes={self.simtimes}, "
            f"sensor_readings={self.sensor_readings}, "
            f"actuator_setpoints={self.actuator_setpoints}, "
            f"rewards={self.rewards}>"
        )
# ``Model`` is a second name for the declarative ``Base`` class.
Model = Base