from __future__ import annotations
import os
import signal
import weakref
import asyncio
import logging
import inspect
import multiprocessing
import aiomultiprocess
from collections.abc import Iterable
from typing import Dict, Set, Tuple, Any, Union, Callable, Optional
from palaestrai.core import RuntimeConfig, MajorDomoClient, MajorDomoWorker
LOG = logging.getLogger(__name__)
class EventStateMachine:
"""An event-triggered state machine
The EventStateMachine (ESM) can be used to transparently handle events
within palaestrAI. An ESM wraps another class; callbacks for events are
defined via method decorators. Events are:
* A message received,
* a signal received (SIGCHLD, SIGTERM, etc.)
* setup
* enter (the initial event)
* teardown
The initial event *enter* is issued immediately after the main event/state
loop commences in order to provide an entrypoint for operation.
The *enter* event can be used to, e.g., send out the first request.
For example::
@ESM.monitor()
class Foo:
@ESM.enter
async def _enter(self):
_ = await self._request_initialization()
@ESM.requests
async def _request_initialization(self):
# ...
return InitRequest(
# ...
)
It is not strictly necessary to provide an *enter* event.
If the monitored class is exclusively an MDP worker, then there is no need
for the *enter* event, because the worker reacts on the first request it
receives and not on its own volition.
In order to make a class use the ESM, you must decorate it with
::`~.monitor`. The ``monitor`` decorator can also inject all necessary
code to handle ZMQ MDP workers.
If the monitored class does not have a ``run`` method, the ESM will also
inject it. The ``run`` method then serves as an event/state loop that
continues until it is stopped. At the start of the ``run`` method, the
target object's ``setup`` method is called if it exists. Likewise,
a ``teardown`` method will be called immediately after the loop ends.
The ESM also adds a ``stop`` method to the target object. It serves to
terminate the event/state loop.
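For example, a class without its own ``run`` method can be driven entirely
by the injected one and stopped from any handler (a minimal sketch;
``SomeShutdownRequest`` is a placeholder message class, not part of
palaestrAI)::
    import asyncio
    @ESM.monitor()
    class Foo:
        @ESM.on(SomeShutdownRequest)
        async def _handle_shutdown(self, request):
            self.stop()  # ends the injected event/state loop
    asyncio.run(Foo().run())  # starts the injected event/state loop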
In order to react to a specific event, users of the ESM can decorate their
methods with ``on(event)``. The ::`~.on` decorator takes as its parameter
the class of the event being handled, e.g., the class of a particular
message, or ``signal.SIGCHLD`` to react to a child process that has ended.
For example::
from palaestrai.core import EventStateMachine as ESM
import signal
@ESM.monitor()
class Foo:
@ESM.on(SomeRequest)
async def handle_some_request(self, request):
# ...
pass
@ESM.on(signal.SIGCHLD)
async def handle_process_termination(self, process):
# ...
pass
Spawning processes is also handled through a decorator: ``spawns``. If a
method decorated with ``spawns`` returns a ::`Process` object, this
process will automatically be monitored. E.g.,::
# ...
@ESM.spawns
def start_some_fancy_process(self):
p = multiprocessing.Process(target=somefunc)
p.start()
return p
The ESM also handles the sending of requests. ESM-monitored classes do not
need to instantiate and monitor MDP client objects themselves. Instead,
they simply decorate the relevant methods with ``requests``. The
decorated method must return a message object that has a ``receiver``
property, so that ::`~.requests` can handle sending. E.g.,::
# ...
@ESM.requests
def get_something_from_a_worker(self):
req = SomeRequest()
req.receiver = "Foo"
return req
@ESM.on(SomeResponse) # also handle the response!
def handle_response_from_worker(self, response):
# ...
pass
The ESM also supports classes that act as workers. For this, the ESM's
``monitor`` decorator needs the flag ``is_mdp_worker=True``. Then, the
ESM injects the property ``mdp_service``. Setting this property connects
the MDP worker, and ``ESM.on`` can be used to handle requests from clients.
For example::
@ESM.monitor(is_mdp_worker=True)
class Foo:
async def setup(self):
self.mdp_service = "Foo"
@ESM.on(SomeRequest)
def handle_request_from_client(self, req):
do_something_with(req)
rsp = SomeResponse()
rsp.receiver = req.sender
return rsp
"""
_decorated_methods: Dict[Callable, Any] = dict()
_monitored_objects: Dict[Tuple[int, weakref.ref], EventStateMachine] = (
dict()
)
@staticmethod
def _cleanup_monitored_objects(ref: weakref.ref):
pid = os.getpid()
LOG.debug("EventStateMachine cleaning %s", (pid, ref))
del EventStateMachine._monitored_objects[(pid, ref)]
@staticmethod
def esm_for(monitored: Any) -> EventStateMachine:
"""Returns the ESM instance for any monitored object.
This method retrieves the ESM instance responsible for a monitored
object. It does not check whether the object has been decorated with
::`~.monitor`, though.
Parameters
----------
monitored : Any
A monitored object
Returns
-------
EventStateMachine
The ESM instance responsible for the monitored object. A new
instance will be created if it does not already exist.
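A short usage sketch (``some_monitored_object`` stands for any instance
of a ::`~.monitor`-decorated class)::
    esm = EventStateMachine.esm_for(some_monitored_object)
    # Subsequent calls for the same object return the cached instance:
    assert esm is EventStateMachine.esm_for(some_monitored_object)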
"""
pid = os.getpid()
ref = weakref.ref(
monitored, EventStateMachine._cleanup_monitored_objects
)
try:
esm = EventStateMachine._monitored_objects[(pid, ref)]
except KeyError:
esm = EventStateMachine(monitored)
EventStateMachine._monitored_objects[(pid, ref)] = esm
return esm
@staticmethod
def _make_mdp_service_property():
@property
def mdp_service(self) -> str:
return self.__esm__._mdp_worker_service
@mdp_service.setter
def mdp_service(self, value: str):
self.__esm__._mdp_worker_service = value
self.__esm__._connect_worker_and_listen()
return mdp_service
@staticmethod
def monitor(is_mdp_worker=False):
"""Decorates a class to monitor instances of it with the ESM.
This decorator is the minimal required decoration of any class that
makes use of the ESM. It injects the ESM instance into new objects of
that class, and also adds relevant methods. The usage is::
from palaestrai.core import EventStateMachine as ESM
@ESM.monitor()
class Foo:
pass
@ESM.monitor(is_mdp_worker=True)
class Bar:
pass
The ``@monitor`` decorator injects methods into the target class, namely:
* ``run()``: The default run method that kicks off the event/state
loop of the target class.
* ``stop()``: Stops the event/state loop of the class and can be
called from any handler.
If ``is_mdp_worker=True`` was given, then the ESM also takes care of
handling MDP requests for the target class. Then, another property
is injected: ``mdp_service``. This is then the service name the worker
will listen on. Setting the property instantiates a ::`MajorDomoWorker`
and connects it to the broker.
Parameters
----------
is_mdp_worker : bool
If ``True``, the monitored class will act as MDP worker. The ESM
will inject a property ``mdp_service``. Setting the property will
create a ::`MajorDomoWorker` instance and connect it to the broker.
"""
def _wraps(clazz):
attrs = dir(clazz)
setattr(
clazz,
"__esm__",
property(lambda self: EventStateMachine.esm_for(self)),
)
if is_mdp_worker:
setattr(
clazz,
"mdp_service",
EventStateMachine._make_mdp_service_property(),
)
if "run" not in attrs:
setattr(clazz, "run", EventStateMachine.run)
setattr(clazz, "stop", EventStateMachine.stop)
return clazz
return _wraps
@staticmethod
def on(sig_or_msg_or_str):
"""Register an event/state transition handler.
``on`` is a decorator used to register handlers for any kind of event.
Typical usage is::
from palaestrai.core import EventStateMachine as ESM
@ESM.monitor()
class Foo:
@ESM.on(SomeRequest)
def bar(self, req):
pass # ...
Typical arguments to ``on`` are:
* A class: When the ESM receives an MDP request or response, it will
check whether the message's class has a handler registered. The
registered method is then called and the message passed.
* An exception class: The handler is triggered when the exception is
raised (see the sketch below).
* A signal: Handles signals such as ``SIGCHLD`` (when a child process
terminates), ``SIGINT``, or ``SIGTERM``.
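For instance, an exception raised inside another handler or a background
task can be routed to its own handler (a minimal sketch; ``SomeError`` is
a placeholder exception class)::
    @ESM.monitor()
    class Foo:
        @ESM.on(SomeError)
        async def handle_some_error(self, error):
            # Recover here, or call self.stop() to end the loop.
            ...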
"""
try:
sig_or_msg_or_str = sig_or_msg_or_str.__name__
except AttributeError:
pass
def _register_func(func):
EventStateMachine._decorated_methods[func] = sig_or_msg_or_str
return func
return _register_func
@staticmethod
def enter(func):
"""Decorates a method to be the very first state
Each state machine needs an initial state; the ESM is no exception. A
method decorated with ``enter`` is called immediately at the beginning
of the event/state loop.
Usage example::
@ESM.monitor()
class Foo:
@ESM.enter
def _enter(self):
pass # Do something, like launching a process.
``enter`` is not used for setup purposes: If the target class has a
``setup`` method, this one is called immediately *before* the
event/state loop commences. Thus, the enter method is optional. E.g.,
a class that simply acts as MDP worker does not need it; it is
sufficient to set the MDP service name in the ``setup`` method.
"""
EventStateMachine._decorated_methods[func] = "ENTER"
return func
def _handle_enter(self):
pass # Intentionally a no-op to suppress a warning for the ENTER event.
def _handle_terminated_child(
self, process: Union[aiomultiprocess.Process, multiprocessing.Process]
):
LOG.debug(
"%s saw termination of process: %s. No other handler is "
"installed.",
self,
process,
)
if process.exitcode != 0:
LOG.error(
"Process %s died with exit code %d",
process,
process.exitcode,
)
def _handle_sigint(self):
LOG.debug("%s handles SIGINT for %s", self, self._monitored)
self.stop(self._monitored)
def _handle_sigterm(self):
LOG.debug("%s handles SIGTERM for %s", self, self._monitored)
self.stop(self._monitored)
@staticmethod
def spawns(func):
"""Signify that a method creates (spawns) new sub-processes
Child processes are also monitored by the ESM. In order to find out
which processes to monitor, the ESM checks the return values of all
methods that are decorated with ``@spawns``. For example::
from palaestrai.core import EventStateMachine as ESM
@ESM.monitor()
class Foo:
@ESM.spawns
async def some_method(self):
p = multiprocessing.Process(target=foofunc)
p.start()
return p
@ESM.enter
async def _enter(self):
_ = await self.some_method()
**Note:** Processes that are returned from a spawning function are not
automatically started, just monitored.
"""
def _wraps(self, *args, **kwargs):
ret = func(self, *args, **kwargs)
for process in [
x
for x in (ret if isinstance(ret, Iterable) else [ret])
if isinstance(x, multiprocessing.Process)
or isinstance(x, aiomultiprocess.Process)
]:
self.__esm__.monitor_process(process)
return ret
return _wraps
@staticmethod
def requests(func):
"""Signify that the returned request message object awaits an answer.
When the object monitored by the ESM sends out requests, it will
also want to react to responses. In order to manage tracking of
requests and responses, the ESM uses the ``requests`` method decorator.
For example,::
from palaestrai.core import EventStateMachine as ESM
@ESM.monitor()
class Foo:
@ESM.requests
def send_some_request(self):
# ...
return SomeRequest(receiver="SomeWorker")
@ESM.on(SomeResponse)
def handle_some_response(self, response):
# ...
pass
The message object's class name must end with ``Request``. The request
may also be returned together with other objects: if the method returns
a tuple or a list, the ESM inspects every element and tracks each request
object it finds (see the sketch below).
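For example (a minimal sketch; ``SomeRequest`` and the bookkeeping value
are placeholders)::
    @ESM.requests
    def send_and_remember(self):
        req = SomeRequest(receiver="SomeWorker")
        # Both values are returned; the ESM only tracks the *Request object:
        return req, {"last_request": req}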
"""
def _wraps(self, *args, **kwargs):
ret = func(self, *args, **kwargs)
for mdp_request in [
x
for x in (ret if isinstance(ret, Iterable) else [ret])
if type(x).__name__.endswith("Request")
]:
self.__esm__._tasks.add(
asyncio.create_task(self.__esm__.send_request(mdp_request))
)
return ret
return _wraps
@staticmethod
async def run(monitored):
"""Main event/state loop of the ESM
This ``run`` method is injected into monitored classes if they do not
have one already. The structure of ``run`` is as follows:
1. It resets the handlers for SIGCHLD, SIGINT, and SIGTERM to the OS'
default.
2. It calls ``monitored.setup()``, if it exists.
3. It creates an ESM instance for the monitored object and adds signal
handlers for SIGCHLD, SIGINT, and SIGTERM according to what the
monitored class defines (via ``@ESM.on(signal.SIGINT)``, etc.)
4. It transitions to the first state, defined by ``@ESM.enter``. It then
waits for state changes/events until ``monitored.stop()`` is called.
5. Finally, once the main event/state loop concludes,
``monitored.teardown()`` is called (if present).
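Both coroutine and plain methods are supported for ``setup`` and
``teardown``; ``run`` detects which kind it is dealing with. A minimal
sketch (``Foo`` is a placeholder class, driven by ``await Foo().run()``)::
    @ESM.monitor()
    class Foo:
        async def setup(self):
            ...  # runs once, before the event/state loop starts
        def teardown(self):
            ...  # plain (non-async) methods work as well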
"""
signal.signal(signal.SIGCHLD, signal.SIG_DFL)
signal.signal(signal.SIGINT, signal.SIG_DFL)
signal.signal(signal.SIGTERM, signal.SIG_DFL)
if "setup" in dir(monitored):
LOG.debug("Running %s.setup()...", monitored)
try:
if asyncio.iscoroutinefunction(monitored.setup):
await monitored.setup()
else:
monitored.setup()
except Exception as e:
LOG.exception("%s.setup() failed with %s", monitored, e)
return
esm = EventStateMachine.esm_for(monitored)
loop = asyncio.get_running_loop()
# add_signal_handler() invokes the callback without arguments, so the
# lambdas must not expect any:
loop.add_signal_handler(
signal.SIGINT,
lambda: loop.create_task(esm._handle_event(signal.SIGINT)),
)
loop.add_signal_handler(
signal.SIGTERM,
lambda: loop.create_task(esm._handle_event(signal.SIGTERM)),
)
esm._future = asyncio.get_running_loop().create_future()
LOG.debug("%s commencing loop: Waiting for my own futureā¦", esm)
asyncio.create_task(esm._handle_event("ENTER"))
await esm._future
LOG.debug("%s: The future is now!", esm)
if esm._future.exception() is not None:
try:
raise esm._future.exception()
except Exception as e:
LOG.exception("%s terminated with an exception: %s", esm, e)
if "teardown" in dir(monitored):
try:
if asyncio.iscoroutinefunction(monitored.teardown):
await monitored.teardown()
else:
monitored.teardown()
except Exception as e:
LOG.exception("%s.teardown() failed with %s", monitored, e)
await esm._cleanup()
@staticmethod
def stop(monitored):
"""Stops the ESM.
Stopping the ESM also means shutting down all running processes and
cancelling all outstanding tasks (e.g., request monitors).
"""
esm = EventStateMachine.esm_for(monitored)
esm._stop()
def __init__(self, monitored: Any):
self._future: asyncio.Future
self._monitored = monitored
self._monitored_processes: Dict[
Union[aiomultiprocess.Process, multiprocessing.Process],
asyncio.Task,
] = dict()
self._tasks: Set[asyncio.Task] = {
asyncio.create_task(self._watch_tasks(), name="Tasks Watcher")
}
self._handlers = {
signal.SIGCHLD: self._handle_terminated_child,
signal.SIGINT: self._handle_sigint,
signal.SIGTERM: self._handle_sigterm,
"ENTER": self._handle_enter,
Any: None,
}
self._mdp_worker_service: Optional[str] = None
self._mdp_worker: Optional[MajorDomoWorker] = None
self._mdp_client: Optional[MajorDomoClient] = None
self._mdp_client_lock = asyncio.Lock()
# Update handlers, match to methods of the _monitored object:
injected_methods = [ # These are injected by us, ignore them here:
"__esm__",
"mdp_service",
"run",
]
directory = dir(self._monitored)
all_attributes = [
getattr(self._monitored, x)
for x in directory
if x not in injected_methods
]
self._handlers.update(
{
EventStateMachine._decorated_methods[x.__func__]: x
for x in all_attributes
if inspect.ismethod(x)
and x.__func__ in EventStateMachine._decorated_methods
}
)
self._handlers.update(
{
EventStateMachine._decorated_methods[x]: x
for x in all_attributes
if inspect.isfunction(x)
and x in EventStateMachine._decorated_methods
}
)
async def _watch_tasks(self):
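# Supervise all background tasks: whenever a task raises, turn the
# exception into an ESM event so a matching @ESM.on handler can react.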
while self._tasks:
done, pending = await asyncio.wait(
self._tasks, return_when=asyncio.FIRST_COMPLETED
)
exceptionals = [t for t in done if t.exception() is not None]
for e in exceptionals:
LOG.error(
"%s saw task %s raise exception: %s",
self,
e,
e.exception(),
)
await self._handle_event(
type(e.exception()).__name__, e.exception()
)
self._tasks = pending
async def _handle_event(self, event: Any, *args, **kwargs) -> Any:
try:
handler: Any = self._handlers[event]
except KeyError:
handler = self._handlers[Any]
if handler is None: # Default handler
LOG.warning("%s has no handler for %s", self._monitored, event)
return
try:
if asyncio.iscoroutinefunction(handler):
return await handler(*args, **kwargs)
else:
return handler(*args, **kwargs)
except Exception as e:
LOG.exception(
"%s encountered exception from the handler for %s", self, event
)
# Perhaps there is a handler for the exception...?
# Except, of course, we're already trying to handle the
# exception...
if (
not isinstance(event, Exception)
and type(e).__name__ in self._handlers
):
await self._handle_event(type(e).__name__, e)
else:
self._future.set_exception(e)
def monitor_process(
self, process: Union[aiomultiprocess.Process, multiprocessing.Process]
):
task = asyncio.create_task(
self._watch_process(process),
name=f"Process watcher for child {process.pid}",
)
self._monitored_processes[process] = task
LOG.debug("%s now monitors process %s", self, process)
async def _watch_process(
self, process: Union[aiomultiprocess.Process, multiprocessing.Process]
):
LOG.debug("%s starts to watch process: %s", self, process)
if isinstance(process, aiomultiprocess.Process):
await process.join()
else:
process.join()
LOG.debug(
"%s saw a process end: %s, calling handler...", self, process
)
await self._handle_event(signal.SIGCHLD, process)
del self._monitored_processes[process] # Cleanup.
@property
def mdp_client(self):
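# Lazily create a single MajorDomoClient connected to the local broker.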
if self._mdp_client is None:
self._mdp_client = MajorDomoClient(
f"tcp://127.0.0.1:{RuntimeConfig().executor_bus_port}"
)
return self._mdp_client
async def _wait_for_response(self, service: str, request: Any):
try:
await self._mdp_client_lock.acquire()
resp = await self.mdp_client.send(service, request)
except Exception as e:
LOG.exception("Sending request failed: %s", e)
return  # No response to dispatch; the lock is still released below.
finally:
self._mdp_client_lock.release()
await self._handle_event(type(resp).__name__, resp)
async def send_request(self, request: Any):
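# Send the request to the service named by ``request.receiver`` in a
# background task; _wait_for_response() later routes the response to the
# matching @ESM.on handler.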
try:
service = request.receiver
self._tasks.add(
asyncio.create_task(self._wait_for_response(service, request))
)
except (AttributeError, ValueError):
LOG.error(
"%s cannot determine target service for %s. Please "
"extend %s to provide the 'receiver' property.",
self,
request,
type(request),
)
async def _mdp_worker_transceive(self):
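# Worker loop: receive requests from the broker, dispatch each one to the
# matching @ESM.on handler, and send the handler's return value as reply.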
reply = None
while True:
req = await self._mdp_worker.transceive(
reply,
skip_recv=type(reply).__name__.endswith("ShutdownResponse"),
)
if req is None:
break
reply = await self._handle_event(type(req).__name__, req)
def _connect_worker_and_listen(self):
if self._mdp_worker_service is None:
raise ValueError(f"{self}._mdp_worker_service string unset")
if self._mdp_worker is not None:
raise RuntimeError(f"{self} already has an MDP Worker")
self._mdp_worker = MajorDomoWorker(
f"tcp://127.0.0.1:{RuntimeConfig().executor_bus_port}",
self._mdp_worker_service,
)
self._tasks.add(
asyncio.create_task(
self._mdp_worker_transceive(), name="Transceiver"
)
)
def _stop(self, reason: Any = True):
try:
self._future.set_result(reason)
except asyncio.exceptions.InvalidStateError:
# Doubly-stop.
pass
async def _cleanup(self):
await self._stop_all_processes()
for task in self._tasks: # First task is the watcher
task.cancel()
await asyncio.wait(self._tasks)
async def _stop_all_processes(self):
all_processes = list(self._monitored_processes.keys()) # Dict changes
for process in all_processes:
# First see whether the process exits all by itself:
if process.is_alive():
if asyncio.iscoroutinefunction(process.join):
try:
await process.join(15.0)
except asyncio.CancelledError:
pass # This is okay
except asyncio.TimeoutError:
# The process seems unwilling to terminate, but no
# worries, we're not done yet...
pass
else:
process.join(15.0)
if not process.is_alive():
continue # Yay, it ended as we wished.
# Then, send SIGTERM and wait for it to finish:
LOG.warning(
"Process %s did not exit by itself, sending SIGTERM.", process
)
process.terminate()
if asyncio.iscoroutinefunction(process.join):
try:
await process.join(5.0)
except asyncio.CancelledError:
pass
except asyncio.TimeoutError:
pass # Same as above, but now we kill
else:
process.join(5.0)
if not process.is_alive():
continue # Okay, not as nice as it could be, but still...
if all(
not process.is_alive() for process in self._monitored_processes
):
return # Don't wait
# Still someone here? Let's draw the big friggin' gun:
for process in all_processes:
if process.is_alive():
LOG.error("Process %s is still there, killing it.", process)
process.kill()
if asyncio.iscoroutinefunction(process.join):
try:
await process.join() # This has to terminate.
except asyncio.CancelledError:
pass
except asyncio.TimeoutError:
pass # Yeah, well, we tried. Hand it to the reaper.
else:
process.join()
for task in self._monitored_processes.values():
if not task.done():
task.cancel()
def __str__(self):
return (
f"EventStateMachine(pid={os.getpid()}, "
f"monitored={self._monitored})"
)
def __del__(self):
if not hasattr(self, "_tasks"):
return
for t in self._tasks:
t.cancel()