diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index 3f91859f..65a03ade 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -9,8 +9,10 @@ Any, AsyncGenerator, Callable, + Coroutine, Dict, List, + NoReturn, Optional, TypeVar, Union, @@ -120,14 +122,17 @@ async def kick( """ @abstractmethod - def listen(self) -> AsyncGenerator[BrokerMessage, None]: + async def listen( + self, + callback: Callable[[BrokerMessage], Coroutine[Any, Any, None]], + ) -> None: """ This function listens to new messages and yields them. This it the main point for workers. This function is used to get new tasks from the network. - :yields: taskiq messages. + :param callback: function to call when message received. :return: nothing. """ diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index 131ff469..9eb9b89c 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -1,7 +1,7 @@ import inspect from collections import OrderedDict from concurrent.futures import ThreadPoolExecutor -from typing import Any, AsyncGenerator, Callable, Optional, TypeVar +from typing import Any, Callable, Coroutine, Optional, TypeVar from taskiq.abc.broker import AsyncBroker from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult @@ -139,13 +139,17 @@ async def kick(self, message: BrokerMessage) -> None: except Exception as exc: raise ResultSetError("Cannot set result.") from exc - async def listen(self) -> AsyncGenerator[BrokerMessage, None]: # type: ignore + async def listen( + self, + callback: Callable[[BrokerMessage], Coroutine[Any, Any, None]], + ) -> None: """ Inmemory broker cannot listen. This method throws RuntimeError if you call it. Because inmemory broker cannot really listen to any of tasks. + :param callback: message callback. :raises RuntimeError: if this method is called. """ raise RuntimeError("Inmemory brokers cannot listen.") diff --git a/taskiq/brokers/shared_broker.py b/taskiq/brokers/shared_broker.py index f6a6a5a9..859c1b60 100644 --- a/taskiq/brokers/shared_broker.py +++ b/taskiq/brokers/shared_broker.py @@ -1,4 +1,4 @@ -from typing import AsyncGenerator, Optional, TypeVar +from typing import Any, Callable, Coroutine, Optional, TypeVar from taskiq.abc.broker import AsyncBroker from taskiq.decor import AsyncTaskiqDecoratedTask @@ -59,12 +59,16 @@ async def kick(self, message: BrokerMessage) -> None: "without setting the default_broker.", ) - async def listen(self) -> AsyncGenerator[BrokerMessage, None]: # type: ignore + async def listen( + self, + callback: Callable[[BrokerMessage], Coroutine[Any, Any, None]], + ) -> None: # type: ignore """ Shared broker cannot listen to tasks. This method will throw an exception. + :param callback: message callback. :raises TaskiqError: if called. """ raise TaskiqError("Shared broker cannot listen") diff --git a/taskiq/brokers/zmq_broker.py b/taskiq/brokers/zmq_broker.py index 9f05b4e9..69fcd593 100644 --- a/taskiq/brokers/zmq_broker.py +++ b/taskiq/brokers/zmq_broker.py @@ -1,4 +1,4 @@ -from typing import AsyncGenerator, Callable, Optional, TypeVar +from typing import Any, Callable, Coroutine, Optional, TypeVar from taskiq.abc.broker import AsyncBroker from taskiq.abc.result_backend import AsyncResultBackend @@ -58,12 +58,15 @@ async def kick(self, message: BrokerMessage) -> None: with self.socket.connect(self.sub_host) as sock: await sock.send_string(message.json()) - async def listen(self) -> AsyncGenerator[BrokerMessage, None]: + async def listen( + self, + callback: Callable[[BrokerMessage], Coroutine[Any, Any, None]], + ) -> None: """ Start accepting new messages. - :yield: received broker message + :param callback: function to call when message received. """ while True: # noqa: WPS457 with self.socket.connect(self.sub_host) as sock: - yield BrokerMessage.parse_raw(await sock.recv_string()) + await callback(BrokerMessage.parse_raw(await sock.recv_string())) diff --git a/taskiq/cli/async_task_runner.py b/taskiq/cli/async_task_runner.py index a66195ed..fe58c28b 100644 --- a/taskiq/cli/async_task_runner.py +++ b/taskiq/cli/async_task_runner.py @@ -13,7 +13,7 @@ from taskiq.cli.args import TaskiqArgs from taskiq.cli.log_collector import log_collector from taskiq.context import Context, context_updater -from taskiq.message import TaskiqMessage +from taskiq.message import BrokerMessage, TaskiqMessage from taskiq.result import TaskiqResult from taskiq.utils import maybe_awaitable @@ -180,43 +180,44 @@ async def run_task( # noqa: C901, WPS210, WPS211 return result -async def async_listen_messages( # noqa: C901, WPS210, WPS213 - broker: AsyncBroker, - cli_args: TaskiqArgs, -) -> None: - """ - This function iterates over tasks asynchronously. +class Receiver: + """Class that uses as a callback handler.""" - It uses listen() method of an AsyncBroker - to get new messages from queues. - - :param broker: broker to listen to. - :param cli_args: CLI arguments for worker. - """ - logger.info("Runing startup event.") - await broker.startup() - executor = ThreadPoolExecutor( - max_workers=cli_args.max_threadpool_threads, - ) - logger.info("Listening started.") - task_signatures: Dict[str, inspect.Signature] = {} - for task in broker.available_tasks.values(): + def __init__(self, broker: AsyncBroker, cli_args: TaskiqArgs) -> None: + self.broker = broker + self.cli_args = cli_args + self.task_signatures: Dict[str, inspect.Signature] = {} if not cli_args.no_parse: - task_signatures[task.task_name] = inspect.signature(task.original_func) - async for message in broker.listen(): + for task in self.broker.available_tasks.values(): + self.task_signatures[task.task_name] = inspect.signature( + task.original_func, + ) + self.executor = ThreadPoolExecutor( + max_workers=cli_args.max_threadpool_threads, + ) + + async def callback(self, message: BrokerMessage) -> None: # noqa: C901 + """ + Receive new message and execute tasks. + + This method is used to process message, + that came from brokers. + + :param message: received message. + """ logger.debug(f"Received message: {message}") - if message.task_name not in broker.available_tasks: + if message.task_name not in self.broker.available_tasks: logger.warning( 'task "%s" is not found. Maybe you forgot to import it?', message.task_name, ) - continue + return logger.debug( "Function for task %s is resolved. Executing...", message.task_name, ) try: - taskiq_msg = broker.formatter.loads(message=message) + taskiq_msg = self.broker.formatter.loads(message=message) except Exception as exc: logger.warning( "Cannot parse message: %s. Skipping execution.\n %s", @@ -224,8 +225,8 @@ async def async_listen_messages( # noqa: C901, WPS210, WPS213 exc, exc_info=True, ) - continue - for middleware in broker.middlewares: + return + for middleware in self.broker.middlewares: if middleware.__class__.pre_execute != TaskiqMiddleware.pre_execute: taskiq_msg = await maybe_awaitable( middleware.pre_execute( @@ -238,23 +239,44 @@ async def async_listen_messages( # noqa: C901, WPS210, WPS213 taskiq_msg.task_name, taskiq_msg.task_id, ) - with context_updater(Context(taskiq_msg, broker)): + with context_updater(Context(taskiq_msg, self.broker)): result = await run_task( - target=broker.available_tasks[message.task_name].original_func, - signature=task_signatures.get(message.task_name), + target=self.broker.available_tasks[message.task_name].original_func, + signature=self.task_signatures.get(message.task_name), message=taskiq_msg, - log_collector_format=cli_args.log_collector_format, - executor=executor, - middlewares=broker.middlewares, + log_collector_format=self.cli_args.log_collector_format, + executor=self.executor, + middlewares=self.broker.middlewares, ) - for middleware in broker.middlewares: + for middleware in self.broker.middlewares: if middleware.__class__.post_execute != TaskiqMiddleware.post_execute: await maybe_awaitable(middleware.post_execute(taskiq_msg, result)) try: - await broker.result_backend.set_result(message.task_id, result) + await self.broker.result_backend.set_result(message.task_id, result) except Exception as exc: logger.exception( "Can't set result in result backend. Cause: %s", exc, exc_info=True, ) + + +async def async_listen_messages( + broker: AsyncBroker, + cli_args: TaskiqArgs, +) -> None: + """ + This function iterates over tasks asynchronously. + + It uses listen() method of an AsyncBroker + to get new messages from queues. + + :param broker: broker to listen to. + :param cli_args: CLI arguments for worker. + """ + logger.info("Runing startup event.") + await broker.startup() + logger.info("Inicialized receiver.") + receiver = Receiver(broker, cli_args) + logger.info("Listening started.") + await broker.listen(receiver.callback)