Source code for palaestrai.util.spawn

import os
import signal
import asyncio
import logging
import logging.config
import logging.handlers
from pathlib import Path
from typing import Callable, Any, Union

import palaestrai.logging
from palaestrai.core import RuntimeConfig

LOG = logging.getLogger(__name__)


def _install_sighandlers():
    """Set the signal dispositions used by a freshly started subprocess.

    ``SIGINT`` is ignored, whereas ``SIGTERM`` and ``SIGCHLD`` are put back
    to their default actions.
    """
    dispositions = (
        (signal.SIGINT, signal.SIG_IGN),
        (signal.SIGTERM, signal.SIG_DFL),
        (signal.SIGCHLD, signal.SIG_DFL),
    )
    for signum, action in dispositions:
        signal.signal(signum, action)


def _set_proctitle(process_name: str):
    """Set the process title to ``palaestrAI[<process_name>]``.

    The ``setproctitle`` package is optional: when it cannot be imported,
    the title is left as it is and the function returns silently.

    Parameters
    ----------
    process_name : str
        Name that is placed between the square brackets of the title.
    """
    try:
        import setproctitle as proctitle_module

        title = f"palaestrAI[{process_name}]"
        proctitle_module.setproctitle(title)
    except ImportError:
        # Purely cosmetic feature; nothing to do without the package.
        return


def _restore_runtime_configuration(runtime_configuration_dict: dict) -> None:
    """Replace the runtime configuration with the given one.

    The :class:`RuntimeConfig` is reset first and then loaded from
    ``runtime_configuration_dict``.

    Parameters
    ----------
    runtime_configuration_dict : dict
        Configuration that is handed to ``RuntimeConfig().load`` unchanged.
    """
    # Reset before loading. Presumably this discards values the process
    # inherited from its parent -- TODO confirm against RuntimeConfig.reset().
    RuntimeConfig().reset()
    RuntimeConfig().load(runtime_configuration_dict)


def _get_parent_logger_name(name, logger_dict=None):
    """Return the name of the closest known ancestor of logger ``name``.

    Ancestors are the dotted prefixes of ``name``; they are tried from the
    longest to the shortest. ``name`` itself is never a candidate.

    Parameters
    ----------
    name : str
        Dotted logger name whose ancestor is wanted.
    logger_dict : container, optional
        Collection of known logger names, tested with ``in``. Defaults to
        ``logging.Logger.manager.loggerDict``.

    Returns
    -------
    str or None
        The longest proper prefix of ``name`` contained in ``logger_dict``,
        or ``None`` if ``name`` is empty or no prefix is known.
    """
    known = (
        logging.Logger.manager.loggerDict
        if logger_dict is None
        else logger_dict
    )
    if not name:
        return None

    # Strip one trailing component per step until no dot is left.
    candidate, dot, _ = name.rpartition(".")
    while dot:
        if candidate in known:
            return candidate
        candidate, dot, _ = candidate.rpartition(".")
    return None


# Helper to apply filters to all loggers
def _inherit_filters():
    """Add the filters of each logger's closest configured ancestor to it.

    For every :class:`logging.Logger` registered with the logging manager,
    the nearest ancestor named in ``RuntimeConfig().logging["loggers"]`` is
    determined. If that ancestor is itself a registered
    :class:`logging.Logger`, each of its filters is added to the logger.

    Raises
    ------
    KeyError
        If the logging configuration has no ``"loggers"`` entry (assuming
        the configuration is a mapping).
    """
    registry = logging.Logger.manager.loggerDict
    # The set of configured loggers does not change while we iterate, so it
    # is read once instead of once per registered logger.
    configured = RuntimeConfig().logging["loggers"]

    for logger in registry.values():
        if not isinstance(logger, logging.Logger):
            continue  # e.g. logging.PlaceHolder entries

        parent_name = _get_parent_logger_name(logger.name, configured)
        if parent_name is None:
            continue

        # The registry may hold non-Logger entries (or none at all) under
        # the parent's name; those carry no ``filters`` attribute.
        parent = registry.get(parent_name)
        if not isinstance(parent, logging.Logger):
            continue

        for inherited in parent.filters:
            logger.addFilter(inherited)


def _reinitialize_logging() -> None:
    """Set up logging in this process from the runtime configuration.

    Applies ``RuntimeConfig().logging`` through
    :func:`logging.config.dictConfig`, lets loggers inherit filters via
    ``_inherit_filters``, registers the ``ASYNCIO`` level name, and replaces
    all handlers of the root logger with a single
    :class:`logging.handlers.SocketHandler` that targets ``127.0.0.1`` on
    port ``RuntimeConfig().logger_port``.

    If a :class:`KeyError` or :class:`ValueError` is raised along the way,
    :func:`logging.basicConfig` is called with level ``INFO`` instead and a
    warning is logged. Other exception types propagate to the caller.
    """
    try:
        logging.config.dictConfig(RuntimeConfig().logging)
        _inherit_filters()

        # Give the custom numeric level a readable name in log records.
        logging.addLevelName(palaestrai.logging.ASYNCIO_LOG_LEVEL, "ASYNCIO")

        # Drop whatever handlers the root logger has at this point, so that
        # the socket handler below is the only one attached to it.
        logging.root.handlers.clear()
        # Presumably a log receiver of the parent process listens on this
        # port -- TODO confirm where ``logger_port`` is served.
        logging.root.addHandler(
            logging.handlers.SocketHandler(
                "127.0.0.1", RuntimeConfig().logger_port
            )
        )
        logging.debug(
            "Reinitialized logging from RuntimeConfig(%s)", RuntimeConfig()
        )

    except (KeyError, ValueError) as e:
        # NOTE(review): basicConfig() is a no-op when the root logger already
        # has handlers, so the INFO level only takes effect if none are
        # attached at this point.
        logging.basicConfig(level=logging.INFO)
        logging.warning(
            "Could not load logging config (%s), continuing with defaults",
            e,
        )


async def spawn_wrapper(
    name: str,
    runtime_config: dict,
    callee: Callable,
    args: Union[list, None] = None,
    kwargs: Union[dict, None] = None,
) -> Any:
    """Wraps a target for fork/spawn and takes care of initialization.

    Whenever a new subprocess is created (regardless of whether spawn, fork,
    or forkserver is used), some caretaking needs to be done:

    * The runtime configuration needs to be transferred, and the
      :class:`RuntimeConfig` properly reinitialized
    * Logging is reinitialized/rewired to send messages to the parent process
    * A proctitle is set

    If ``RuntimeConfig().profile`` is truthy after the configuration has been
    restored, the call is profiled with ``yappi`` and the function statistics
    are saved to ``<name>.yappi`` (pstat format) in the current directory.

    Parameters
    ----------
    name : str
        Name of the process; will lead to a proctitle in the form of
        ``palaestrAI[<name>]``. No proctitle is set if it is empty.
    runtime_config : dict
        Runtime configuration dict, normally obtained from
        ``RuntimeConfig.to_dict``
    callee : Callable
        The target method. Coroutine functions are awaited, everything else
        is called synchronously.
    args : list, optional
        Positional arguments of ``callee``.
    kwargs : dict, optional
        Keyword arguments of ``callee``

    Returns
    -------
    Any
        Whatever the target function returns.

    Raises
    ------
    Exception
        Any exception raised by ``callee`` is logged with level CRITICAL
        and then re-raised unchanged.
    """
    _install_sighandlers()
    if name:
        _set_proctitle(name)
    if not args:  # [] as default arg is mutable, workaround with None:
        args = []
    if not kwargs:  # {} as default arg is mutable, workaround with None:
        kwargs = {}

    _restore_runtime_configuration(runtime_config)
    _reinitialize_logging()

    # We assume that we're using aiomultiprocess, so we do not need to
    # initialize a new asyncio event loop here. Should we ever change the
    # underlying libraries (i.e., not use aiomultiprocess any longer), this
    # method must become sync instead of async, and then we must also re-init
    # the event loop properly.

    # Read the flag exactly once: the callee may change the runtime
    # configuration, and starting/stopping the profiler must stay paired
    # (and ``yappi`` must be bound whenever the ``finally`` branch uses it).
    profiling = bool(RuntimeConfig().profile)
    if profiling:
        import yappi  # type: ignore[import-not-found,import-untyped]

        # Start profiling before running the target
        yappi.set_clock_type("cpu")  # or "wall" if you want wall time
        yappi.start(profile_threads=True)

    try:
        if asyncio.iscoroutinefunction(callee):
            return await callee(*args, **kwargs)
        return callee(*args, **kwargs)
    except Exception as e:
        LOG.critical("Running %s failed: %s", str(callee), e, exc_info=e)
        raise
    finally:
        if profiling:
            yappi.stop()

            stats = yappi.get_func_stats()
            stats.sort("ttot")  # sort by total time

            # Save to a file for later inspection
            stats.save(
                Path(os.curdir) / f"{name}.yappi", type="pstat"
            )  # can be opened by snakeviz, gprof2dot