Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions taskiq/cli/async_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.cli.args import TaskiqArgs
from taskiq.cli.log_collector import log_collector
from taskiq.context import Context, context_updater
from taskiq.message import TaskiqMessage
from taskiq.result import TaskiqResult
from taskiq.utils import maybe_awaitable
Expand Down Expand Up @@ -224,11 +225,6 @@ async def async_listen_messages( # noqa: C901, WPS210, WPS213
exc_info=True,
)
continue
logger.info(
"Executing task %s with ID: %s",
taskiq_msg.task_name,
taskiq_msg.task_id,
)
for middleware in broker.middlewares:
if middleware.__class__.pre_execute != TaskiqMiddleware.pre_execute:
taskiq_msg = await maybe_awaitable(
Expand All @@ -237,14 +233,20 @@ async def async_listen_messages( # noqa: C901, WPS210, WPS213
),
)

result = await run_task(
target=broker.available_tasks[message.task_name].original_func,
signature=task_signatures.get(message.task_name),
message=taskiq_msg,
log_collector_format=cli_args.log_collector_format,
executor=executor,
middlewares=broker.middlewares,
logger.info(
"Executing task %s with ID: %s",
taskiq_msg.task_name,
taskiq_msg.task_id,
)
with context_updater(Context(taskiq_msg, broker)):
result = await run_task(
target=broker.available_tasks[message.task_name].original_func,
signature=task_signatures.get(message.task_name),
message=taskiq_msg,
log_collector_format=cli_args.log_collector_format,
executor=executor,
middlewares=broker.middlewares,
)
for middleware in broker.middlewares:
if middleware.__class__.post_execute != TaskiqMiddleware.post_execute:
await maybe_awaitable(middleware.post_execute(taskiq_msg, result))
Expand Down
54 changes: 54 additions & 0 deletions taskiq/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from contextlib import contextmanager
from typing import Generator

from taskiq.abc.broker import AsyncBroker
from taskiq.message import TaskiqMessage


class Context:
"""Context class."""

def __init__(self, message: TaskiqMessage, broker: AsyncBroker) -> None:
self.message = message
self.broker = broker


default_context = Context(None, None) # type: ignore
current_context = None


@contextmanager
def context_updater(new_context: Context) -> Generator[None, None, None]:
"""
Update context for some time.

:param new_context: new context to set.
:yield: nothing.
"""
global current_context # noqa: WPS420
current_context = new_context # noqa: WPS442

yield

current_context = None # noqa: WPS442


def get_context() -> Context:
"""
Get current context.

This function always return contexts,
but if you call this function inside tests,
or somewhere you have to be careful,
since if current_context is None it will
return default_context.

To override context please use context_updater
context manager.

:return: context.
"""
global current_context # noqa: WPS420
if current_context is None:
return default_context
return current_context