From 83a19220a34f64154b49974034994c534b2d9ef7 Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Tue, 23 Aug 2022 18:10:46 +0400 Subject: [PATCH] Added context variable. Signed-off-by: Pavel Kirilin --- taskiq/cli/async_task_runner.py | 26 ++++++++-------- taskiq/context.py | 54 +++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 12 deletions(-) create mode 100644 taskiq/context.py diff --git a/taskiq/cli/async_task_runner.py b/taskiq/cli/async_task_runner.py index 1abd619a..a66195ed 100644 --- a/taskiq/cli/async_task_runner.py +++ b/taskiq/cli/async_task_runner.py @@ -12,6 +12,7 @@ from taskiq.abc.middleware import TaskiqMiddleware from taskiq.cli.args import TaskiqArgs from taskiq.cli.log_collector import log_collector +from taskiq.context import Context, context_updater from taskiq.message import TaskiqMessage from taskiq.result import TaskiqResult from taskiq.utils import maybe_awaitable @@ -224,11 +225,6 @@ async def async_listen_messages( # noqa: C901, WPS210, WPS213 exc_info=True, ) continue - logger.info( - "Executing task %s with ID: %s", - taskiq_msg.task_name, - taskiq_msg.task_id, - ) for middleware in broker.middlewares: if middleware.__class__.pre_execute != TaskiqMiddleware.pre_execute: taskiq_msg = await maybe_awaitable( @@ -237,14 +233,20 @@ async def async_listen_messages( # noqa: C901, WPS210, WPS213 ), ) - result = await run_task( - target=broker.available_tasks[message.task_name].original_func, - signature=task_signatures.get(message.task_name), - message=taskiq_msg, - log_collector_format=cli_args.log_collector_format, - executor=executor, - middlewares=broker.middlewares, + logger.info( + "Executing task %s with ID: %s", + taskiq_msg.task_name, + taskiq_msg.task_id, ) + with context_updater(Context(taskiq_msg, broker)): + result = await run_task( + target=broker.available_tasks[message.task_name].original_func, + signature=task_signatures.get(message.task_name), + message=taskiq_msg, + log_collector_format=cli_args.log_collector_format, + executor=executor, + middlewares=broker.middlewares, + ) for middleware in broker.middlewares: if middleware.__class__.post_execute != TaskiqMiddleware.post_execute: await maybe_awaitable(middleware.post_execute(taskiq_msg, result)) diff --git a/taskiq/context.py b/taskiq/context.py new file mode 100644 index 00000000..8ccbc8fc --- /dev/null +++ b/taskiq/context.py @@ -0,0 +1,54 @@ +from contextlib import contextmanager +from typing import Generator + +from taskiq.abc.broker import AsyncBroker +from taskiq.message import TaskiqMessage + + +class Context: + """Context class.""" + + def __init__(self, message: TaskiqMessage, broker: AsyncBroker) -> None: + self.message = message + self.broker = broker + + +default_context = Context(None, None) # type: ignore +current_context = None + + +@contextmanager +def context_updater(new_context: Context) -> Generator[None, None, None]: + """ + Update context for some time. + + :param new_context: new context to set. + :yield: nothing. + """ + global current_context # noqa: WPS420 + current_context = new_context # noqa: WPS442 + + yield + + current_context = None # noqa: WPS442 + + +def get_context() -> Context: + """ + Get current context. + + This function always return contexts, + but if you call this function inside tests, + or somewhere you have to be careful, + since if current_context is None it will + return default_context. + + To override context please use context_updater + context manager. + + :return: context. + """ + global current_context # noqa: WPS420 + if current_context is None: + return default_context + return current_context