diff --git a/.flake8 b/.flake8 index 4df40a5a..d01801b9 100644 --- a/.flake8 +++ b/.flake8 @@ -97,6 +97,8 @@ per-file-ignores = S101, ; Found magic number WPS432, + ; Missing parameter(s) in Docstring + DAR101, exclude = ./.git, diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4bf73bbe..fc0ab611 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -76,3 +76,7 @@ jobs: - name: Upload coverage reports to Codecov with GitHub Action uses: codecov/codecov-action@v3 if: matrix.os == 'ubuntu-latest' && matrix.py_version == '3.9' + with: + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: true + verbose: true diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index 65a03ade..431013b9 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -7,12 +7,10 @@ from typing import ( # noqa: WPS235 TYPE_CHECKING, Any, - AsyncGenerator, Callable, Coroutine, Dict, List, - NoReturn, Optional, TypeVar, Union, diff --git a/taskiq/abc/middleware.py b/taskiq/abc/middleware.py index 2bfaabd0..d09e4fdf 100644 --- a/taskiq/abc/middleware.py +++ b/taskiq/abc/middleware.py @@ -1,12 +1,12 @@ from typing import TYPE_CHECKING, Any, Coroutine, Union -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from taskiq.abc.broker import AsyncBroker from taskiq.message import TaskiqMessage from taskiq.result import TaskiqResult -class TaskiqMiddleware: +class TaskiqMiddleware: # pragma: no cover """Base class for middlewares.""" def __init__(self) -> None: diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index 9eb9b89c..43bce37f 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -1,12 +1,12 @@ import inspect from collections import OrderedDict -from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Coroutine, Optional, TypeVar from taskiq.abc.broker import AsyncBroker from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult -from taskiq.cli.async_task_runner import run_task -from taskiq.exceptions import ResultSetError, TaskiqError +from taskiq.cli.args import TaskiqArgs +from taskiq.cli.receiver import Receiver +from taskiq.exceptions import TaskiqError from taskiq.message import BrokerMessage _ReturnType = TypeVar("_ReturnType") @@ -100,16 +100,16 @@ def __init__( # noqa: WPS211 result_backend=result_backend, task_id_generator=task_id_generator, ) - self.executor = ThreadPoolExecutor(max_workers=sync_tasks_pool_size) - self.cast_types = cast_types - if logs_format is None: - logs_format = ( - "[%(asctime)s]" - "[%(levelname)-7s]" - "[%(module)s:%(funcName)s:%(lineno)d] " - "%(message)s" - ) - self.logs_format = logs_format + self.receiver = Receiver( + self, + TaskiqArgs( + broker="", + modules=[], + max_threadpool_threads=sync_tasks_pool_size, + no_parse=not cast_types, + log_collector_format=logs_format or TaskiqArgs.log_collector_format, + ), + ) async def kick(self, message: BrokerMessage) -> None: """ @@ -119,25 +119,20 @@ async def kick(self, message: BrokerMessage) -> None: :param message: incomming message. - :raises ResultSetError: if cannot save results in result backend. :raises TaskiqError: if someone wants to kick unknown task. """ target_task = self.available_tasks.get(message.task_name) - taskiq_message = self.formatter.loads(message=message) if target_task is None: raise TaskiqError("Unknown task.") - result = await run_task( - target=target_task.original_func, - signature=inspect.signature(target_task.original_func), - message=taskiq_message, - log_collector_format=self.logs_format, - executor=self.executor, - middlewares=self.middlewares, - ) - try: - await self.result_backend.set_result(message.task_id, result) - except Exception as exc: - raise ResultSetError("Cannot set result.") from exc + if self.receiver.task_signatures: + if not self.receiver.task_signatures.get(target_task.task_name): + self.receiver.task_signatures[ + target_task.task_name + ] = inspect.signature( + target_task.original_func, + ) + + await self.receiver.callback(message=message) async def listen( self, diff --git a/taskiq/brokers/zmq_broker.py b/taskiq/brokers/zmq_broker.py index 69fcd593..6713f353 100644 --- a/taskiq/brokers/zmq_broker.py +++ b/taskiq/brokers/zmq_broker.py @@ -1,3 +1,5 @@ +import asyncio +from logging import getLogger from typing import Any, Callable, Coroutine, Optional, TypeVar from taskiq.abc.broker import AsyncBroker @@ -12,6 +14,8 @@ _T = TypeVar("_T") # noqa: WPS111 +logger = getLogger(__name__) + class ZeroMQBroker(AsyncBroker): """ @@ -67,6 +71,13 @@ async def listen( :param callback: function to call when message received. """ - while True: # noqa: WPS457 + loop = asyncio.get_event_loop() + while True: with self.socket.connect(self.sub_host) as sock: - await callback(BrokerMessage.parse_raw(await sock.recv_string())) + received_str = await sock.recv_string() + try: + broker_msg = BrokerMessage.parse_raw(received_str) + except ValueError: + logger.warning("Cannot parse received message %s", received_str) + continue + loop.create_task(callback(broker_msg)) diff --git a/taskiq/cli/args.py b/taskiq/cli/args.py index 4e115618..aacc2e07 100644 --- a/taskiq/cli/args.py +++ b/taskiq/cli/args.py @@ -19,17 +19,19 @@ class TaskiqArgs: """Taskiq worker CLI arguments.""" broker: str - tasks_pattern: str modules: List[str] - fs_discover: bool - log_level: str - workers: int - log_collector_format: str - max_threadpool_threads: int - no_parse: bool - shutdown_timeout: float - reload: bool - no_gitignore: bool + tasks_pattern: str = "tasks.py" + fs_discover: bool = False + log_level: str = "INFO" + workers: int = 2 + log_collector_format: str = ( + "[%(asctime)s][%(levelname)-7s][%(module)s:%(funcName)s:%(lineno)d] %(message)s" + ) + max_threadpool_threads: int = 10 + no_parse: bool = False + shutdown_timeout: float = 5 + reload: bool = False + no_gitignore: bool = False @classmethod def from_cli(cls, args: Optional[List[str]] = None) -> "TaskiqArgs": # noqa: WPS213 @@ -128,8 +130,5 @@ def from_cli(cls, args: Optional[List[str]] = None) -> "TaskiqArgs": # noqa: WP help="Do not use gitignore to check for updated files.", ) - if args is None: - namespace = parser.parse_args(args) - else: - namespace = parser.parse_args() + namespace = parser.parse_args(args) return TaskiqArgs(**namespace.__dict__) diff --git a/taskiq/cli/async_task_runner.py b/taskiq/cli/async_task_runner.py index fe58c28b..8096d641 100644 --- a/taskiq/cli/async_task_runner.py +++ b/taskiq/cli/async_task_runner.py @@ -1,270 +1,16 @@ -import asyncio -import inspect -import io -from concurrent.futures import Executor, ThreadPoolExecutor from logging import getLogger -from time import time -from typing import Any, Callable, Dict, List, Optional - -from pydantic import parse_obj_as from taskiq.abc.broker import AsyncBroker -from taskiq.abc.middleware import TaskiqMiddleware from taskiq.cli.args import TaskiqArgs -from taskiq.cli.log_collector import log_collector -from taskiq.context import Context, context_updater -from taskiq.message import BrokerMessage, TaskiqMessage -from taskiq.result import TaskiqResult -from taskiq.utils import maybe_awaitable +from taskiq.cli.receiver import Receiver logger = getLogger("taskiq.worker") -def parse_params( # noqa: C901 - signature: Optional[inspect.Signature], - message: TaskiqMessage, -) -> None: - """ - Parses incoming parameters. - - This function uses signature to get - expected types of parameters. - - If the parameter from TaskiqMessage - has different type it will try to parse - it. But if parsing fails this function - doesn't modify incoming parameter. - - For example - - you have task like this: - - >>> def my_task(a: int) -> str - >>> ... - - If you will kall my_task.kiq("11") - - You'll receive parsed 11 (int). - But, if you call it with mytask.kiq("str"), - you get the same value. - - If you want to skip parsing completely, - you can pass --no-parse to worker, - or you can make some of parameters untyped, - or use Any. - - :param signature: original function's signature. - :param message: incoming message. - """ - if signature is None: - return - argnum = -1 - for param_name, params_type in signature.parameters.items(): - if params_type.annotation is params_type.empty: - continue - argnum += 1 - annot = params_type.annotation - value = None - logger.debug("Trying to parse %s as %s", param_name, params_type.annotation) - if argnum >= len(message.args): - value = message.kwargs.get(param_name) - if value is None: - continue - try: - message.kwargs[param_name] = parse_obj_as(annot, value) - except (ValueError, RuntimeError) as exc: - logger.debug(exc, exc_info=True) - else: - value = message.args[argnum] - if value is None: - continue - try: - message.args[argnum] = parse_obj_as(annot, value) - except (ValueError, RuntimeError) as exc: - logger.debug(exc, exc_info=True) - - -def run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any: - """ - Runs function synchronously. - - We use this function, because - we cannot pass kwargs in loop.run_with_executor(). - - :param target: function to execute. - :param message: received message from broker. - :return: result of function's execution. - """ - return target(*message.args, **message.kwargs) - - -async def run_task( # noqa: C901, WPS210, WPS211 - target: Callable[..., Any], - signature: Optional[inspect.Signature], - message: TaskiqMessage, - log_collector_format: str, - executor: Optional[Executor] = None, - middlewares: Optional[List[TaskiqMiddleware]] = None, -) -> TaskiqResult[Any]: - """ - This function actually executes functions. - - It has all needed parameters in - message. - - If the target function is async - it awaits it, if it's sync - it wraps it in run_sync and executes in - threadpool executor. - - Also it uses LogsCollector to - collect logs. - - :param target: function to execute. - :param signature: signature of an original function. - :param message: received message. - :param log_collector_format: Log format in wich logs are collected. - :param executor: executor to run sync tasks. - :param middlewares: list of broker's middlewares in case of errors. - :return: result of execution. - """ - if middlewares is None: - middlewares = [] - - loop = asyncio.get_running_loop() - logs = io.StringIO() - returned = None - found_exception = None - # Captures function's logs. - parse_params(signature, message) - with log_collector(logs, log_collector_format): - start_time = time() - try: - if asyncio.iscoroutinefunction(target): - returned = await target(*message.args, **message.kwargs) - else: - returned = await loop.run_in_executor( - executor, - run_sync, - target, - message, - ) - except Exception as exc: - found_exception = exc - logger.error( - "Exception found while executing function: %s", - exc, - exc_info=True, - ) - execution_time = time() - start_time - - raw_logs = logs.getvalue() - logs.close() - result: "TaskiqResult[Any]" = TaskiqResult( - is_err=found_exception is not None, - log=raw_logs, - return_value=returned, - execution_time=execution_time, - ) - if found_exception is not None: - for middleware in middlewares: - if middleware.__class__.on_error != TaskiqMiddleware.on_error: - await maybe_awaitable( - middleware.on_error( - message, - result, - found_exception, - ), - ) - - return result - - -class Receiver: - """Class that uses as a callback handler.""" - - def __init__(self, broker: AsyncBroker, cli_args: TaskiqArgs) -> None: - self.broker = broker - self.cli_args = cli_args - self.task_signatures: Dict[str, inspect.Signature] = {} - if not cli_args.no_parse: - for task in self.broker.available_tasks.values(): - self.task_signatures[task.task_name] = inspect.signature( - task.original_func, - ) - self.executor = ThreadPoolExecutor( - max_workers=cli_args.max_threadpool_threads, - ) - - async def callback(self, message: BrokerMessage) -> None: # noqa: C901 - """ - Receive new message and execute tasks. - - This method is used to process message, - that came from brokers. - - :param message: received message. - """ - logger.debug(f"Received message: {message}") - if message.task_name not in self.broker.available_tasks: - logger.warning( - 'task "%s" is not found. Maybe you forgot to import it?', - message.task_name, - ) - return - logger.debug( - "Function for task %s is resolved. Executing...", - message.task_name, - ) - try: - taskiq_msg = self.broker.formatter.loads(message=message) - except Exception as exc: - logger.warning( - "Cannot parse message: %s. Skipping execution.\n %s", - message, - exc, - exc_info=True, - ) - return - for middleware in self.broker.middlewares: - if middleware.__class__.pre_execute != TaskiqMiddleware.pre_execute: - taskiq_msg = await maybe_awaitable( - middleware.pre_execute( - taskiq_msg, - ), - ) - - logger.info( - "Executing task %s with ID: %s", - taskiq_msg.task_name, - taskiq_msg.task_id, - ) - with context_updater(Context(taskiq_msg, self.broker)): - result = await run_task( - target=self.broker.available_tasks[message.task_name].original_func, - signature=self.task_signatures.get(message.task_name), - message=taskiq_msg, - log_collector_format=self.cli_args.log_collector_format, - executor=self.executor, - middlewares=self.broker.middlewares, - ) - for middleware in self.broker.middlewares: - if middleware.__class__.post_execute != TaskiqMiddleware.post_execute: - await maybe_awaitable(middleware.post_execute(taskiq_msg, result)) - try: - await self.broker.result_backend.set_result(message.task_id, result) - except Exception as exc: - logger.exception( - "Can't set result in result backend. Cause: %s", - exc, - exc_info=True, - ) - - async def async_listen_messages( broker: AsyncBroker, cli_args: TaskiqArgs, -) -> None: +) -> None: # pragma: no cover """ This function iterates over tasks asynchronously. diff --git a/taskiq/cli/params_parser.py b/taskiq/cli/params_parser.py new file mode 100644 index 00000000..e74e7164 --- /dev/null +++ b/taskiq/cli/params_parser.py @@ -0,0 +1,84 @@ +import inspect +from logging import getLogger +from typing import Optional + +from pydantic import parse_obj_as + +from taskiq.message import TaskiqMessage + +logger = getLogger(__name__) + + +def parse_params( # noqa: C901 + signature: Optional[inspect.Signature], + message: TaskiqMessage, +) -> None: + """ + Parses incoming parameters. + + This function uses signature to get + expected types of parameters. + + If the parameter from TaskiqMessage + has different type it will try to parse + it. But if parsing fails this function + doesn't modify incoming parameter. + + For example + + you have task like this: + + >>> def my_task(a: int) -> str + >>> ... + + If you will kall my_task.kiq("11") + + You'll receive parsed 11 (int). + But, if you call it with mytask.kiq("str"), + you get the same value. + + If you want to skip parsing completely, + you can pass --no-parse to worker, + or you can make some of parameters untyped, + or use Any. + + :param signature: original function's signature. + :param message: incoming message. + """ + if signature is None: + return + argnum = -1 + # Iterate over function's params. + for param_name, params_type in signature.parameters.items(): + # If parameter doesn't have an annotation. + if params_type.annotation is params_type.empty: + continue + # Increment argument numbers. This is + # for positional arguments. + argnum += 1 + # Shortland for params_type.annotation + annot = params_type.annotation + # Value from incoming message. + value = None + logger.debug("Trying to parse %s as %s", param_name, params_type.annotation) + # Check if we have positional arguments in passed message. + if argnum < len(message.args): + # Get positional argument. + value = message.args[argnum] + if value is None: + continue + try: + # trying to parse found value as in type annotation. + message.args[argnum] = parse_obj_as(annot, value) + except (ValueError, RuntimeError) as exc: + logger.debug(exc, exc_info=True) + else: + # We try to get this parameter from kwargs. + value = message.kwargs.get(param_name) + if value is None: + continue + try: + # trying to parse found value as in type annotation. + message.kwargs[param_name] = parse_obj_as(annot, value) + except (ValueError, RuntimeError) as exc: + logger.debug(exc, exc_info=True) diff --git a/taskiq/cli/receiver.py b/taskiq/cli/receiver.py new file mode 100644 index 00000000..46461624 --- /dev/null +++ b/taskiq/cli/receiver.py @@ -0,0 +1,199 @@ +import asyncio +import inspect +import io +from concurrent.futures import ThreadPoolExecutor +from logging import getLogger +from time import time +from typing import Any, Callable, Dict + +from taskiq.abc.broker import AsyncBroker +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.cli.args import TaskiqArgs +from taskiq.cli.log_collector import log_collector +from taskiq.cli.params_parser import parse_params +from taskiq.context import Context, context_updater +from taskiq.message import BrokerMessage, TaskiqMessage +from taskiq.result import TaskiqResult +from taskiq.utils import maybe_awaitable + +logger = getLogger(__name__) + + +def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any: + """ + Runs function synchronously. + + We use this function, because + we cannot pass kwargs in loop.run_with_executor(). + + :param target: function to execute. + :param message: received message from broker. + :return: result of function's execution. + """ + return target(*message.args, **message.kwargs) + + +class Receiver: + """Class that uses as a callback handler.""" + + def __init__(self, broker: AsyncBroker, cli_args: TaskiqArgs) -> None: + self.broker = broker + self.cli_args = cli_args + self.task_signatures: Dict[str, inspect.Signature] = {} + if not cli_args.no_parse: + for task in self.broker.available_tasks.values(): + self.task_signatures[task.task_name] = inspect.signature( + task.original_func, + ) + self.executor = ThreadPoolExecutor( + max_workers=cli_args.max_threadpool_threads, + ) + + async def callback( # noqa: C901 + self, + message: BrokerMessage, + raise_err: bool = False, + ) -> None: + """ + Receive new message and execute tasks. + + This method is used to process message, + that came from brokers. + + :raises Exception: if raise_err is true, + and excpetion were found while saving result. + :param message: received message. + :param raise_err: raise an error if cannot save result in + result_backend. + """ + logger.debug(f"Received message: {message}") + if message.task_name not in self.broker.available_tasks: + logger.warning( + 'task "%s" is not found. Maybe you forgot to import it?', + message.task_name, + ) + return + logger.debug( + "Function for task %s is resolved. Executing...", + message.task_name, + ) + try: + taskiq_msg = self.broker.formatter.loads(message=message) + except Exception as exc: + logger.warning( + "Cannot parse message: %s. Skipping execution.\n %s", + message, + exc, + exc_info=True, + ) + return + for middleware in self.broker.middlewares: + if middleware.__class__.pre_execute != TaskiqMiddleware.pre_execute: + taskiq_msg = await maybe_awaitable( + middleware.pre_execute( + taskiq_msg, + ), + ) + + logger.info( + "Executing task %s with ID: %s", + taskiq_msg.task_name, + taskiq_msg.task_id, + ) + with context_updater(Context(taskiq_msg, self.broker)): + result = await self.run_task( + target=self.broker.available_tasks[message.task_name].original_func, + message=taskiq_msg, + ) + for middleware in self.broker.middlewares: + if middleware.__class__.post_execute != TaskiqMiddleware.post_execute: + await maybe_awaitable(middleware.post_execute(taskiq_msg, result)) + try: + await self.broker.result_backend.set_result(message.task_id, result) + except Exception as exc: + logger.exception( + "Can't set result in result backend. Cause: %s", + exc, + exc_info=True, + ) + if raise_err: + raise exc + + async def run_task( # noqa: C901, WPS210 + self, + target: Callable[..., Any], + message: TaskiqMessage, + ) -> TaskiqResult[Any]: + """ + This function actually executes functions. + + It has all needed parameters in + message. + + If the target function is async + it awaits it, if it's sync + it wraps it in run_sync and executes in + threadpool executor. + + Also it uses LogsCollector to + collect logs. + + :param target: function to execute. + :param message: received message. + :return: result of execution. + """ + loop = asyncio.get_running_loop() + # Buffer to capture logs. + logs = io.StringIO() + returned = None + found_exception = None + parse_params(self.task_signatures.get(message.task_name), message) + # Captures function's logs. + with log_collector(logs, self.cli_args.log_collector_format): + # Start a timer. + start_time = time() + try: + # If the function is a coroutine we await it. + if asyncio.iscoroutinefunction(target): + returned = await target(*message.args, **message.kwargs) + else: + # If this is a synchronous function we + # run it in executor. + returned = await loop.run_in_executor( + self.executor, + _run_sync, + target, + message, + ) + except Exception as exc: + found_exception = exc + logger.error( + "Exception found while executing function: %s", + exc, + exc_info=True, + ) + # Stop the timer. + execution_time = time() - start_time + + raw_logs = logs.getvalue() + logs.close() + # Assemble result. + result: "TaskiqResult[Any]" = TaskiqResult( + is_err=found_exception is not None, + log=raw_logs, + return_value=returned, + execution_time=execution_time, + ) + # If exception is found we execute middlewares. + if found_exception is not None: + for middleware in self.broker.middlewares: + if middleware.__class__.on_error != TaskiqMiddleware.on_error: + await maybe_awaitable( + middleware.on_error( + message, + result, + found_exception, + ), + ) + + return result diff --git a/taskiq/cli/tests/test_log_collector.py b/taskiq/cli/tests/test_log_collector.py new file mode 100644 index 00000000..90ab335f --- /dev/null +++ b/taskiq/cli/tests/test_log_collector.py @@ -0,0 +1,26 @@ +import logging +import sys +from io import StringIO + +from taskiq.cli.log_collector import log_collector + + +def test_log_collector_std_success() -> None: + """Tests that stdout and stderr calls are collected correctly.""" + log = StringIO() + with log_collector(log, "%(message)s"): + print("log1") # noqa: WPS421 + print("log2", file=sys.stderr) # noqa: WPS421 + assert log.getvalue() == "log1\nlog2\n" + + +def test_log_collector_logging_success() -> None: + """Tests that logging calls are collected correctly.""" + log = StringIO() + with log_collector(log, "%(levelname)s %(message)s"): + logger = logging.getLogger(__name__) + logger.setLevel(logging.DEBUG) + logger.info("log1") + logger.warning("log2") + logger.debug("log3") + assert log.getvalue() == "INFO log1\nWARNING log2\nDEBUG log3\n" diff --git a/taskiq/cli/tests/test_parameters_parsing.py b/taskiq/cli/tests/test_parameters_parsing.py new file mode 100644 index 00000000..a05a9888 --- /dev/null +++ b/taskiq/cli/tests/test_parameters_parsing.py @@ -0,0 +1,134 @@ +import inspect +from dataclasses import dataclass +from typing import Any, Type + +import pytest +from pydantic import BaseModel + +from taskiq.cli.params_parser import parse_params +from taskiq.message import TaskiqMessage + + +class _TestPydanticClass(BaseModel): + field: str + + +@dataclass +class _TestDataclass: + field: str + + +def test_parse_params_no_signature() -> None: + """Test that params aren't parsed if no annotation is supplied.""" + src_msg = TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[1, 2], + kwargs={"a": 1}, + ) + modify_msg = src_msg.copy(deep=True) + parse_params( + signature=None, + message=modify_msg, + ) + + assert modify_msg == src_msg + + +@pytest.mark.parametrize("test_class", [_TestPydanticClass, _TestDataclass]) +def test_parse_params_classes(test_class: Type[Any]) -> None: + """Test that dataclasses are parsed correctly.""" + + def test_func(param: test_class) -> test_class: # type: ignore + return param + + msg_with_args = TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[{"field": "test_val"}], + kwargs={}, + ) + + parse_params(inspect.signature(test_func), msg_with_args) + + assert isinstance(msg_with_args.args[0], test_class) + assert msg_with_args.args[0].field == "test_val" + + msg_with_kwargs = TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[], + kwargs={"param": {"field": "test_val"}}, + ) + + parse_params(inspect.signature(test_func), msg_with_kwargs) + + assert isinstance(msg_with_kwargs.kwargs["param"], test_class) + assert msg_with_kwargs.kwargs["param"].field == "test_val" + + +@pytest.mark.parametrize("test_class", [_TestPydanticClass, _TestDataclass]) +def test_parse_params_wrong_data(test_class: Type[Any]) -> None: + """Tests that wrong data isn't parsed and doesn't throw errors.""" + + def test_func(param: test_class) -> test_class: # type: ignore + return param + + msg_with_args = TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[{"unknown": "unknown"}], + kwargs={}, + ) + + parse_params(inspect.signature(test_func), msg_with_args) + + assert isinstance(msg_with_args.args[0], dict) + + msg_with_kwargs = TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[], + kwargs={"param": {"unknown": "unknown"}}, + ) + + parse_params(inspect.signature(test_func), msg_with_kwargs) + + assert isinstance(msg_with_kwargs.kwargs["param"], dict) + + +@pytest.mark.parametrize("test_class", [_TestPydanticClass, _TestDataclass]) +def test_parse_params_nones(test_class: Type[Any]) -> None: + """Tests that None values are not parsed.""" + + def test_func(param: test_class) -> test_class: # type: ignore + return param + + msg_with_args = TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[None], + kwargs={}, + ) + + parse_params(inspect.signature(test_func), msg_with_args) + + assert msg_with_args.args[0] is None + + msg_with_kwargs = TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[], + kwargs={"param": None}, + ) + + parse_params(inspect.signature(test_func), msg_with_kwargs) + + assert msg_with_kwargs.kwargs["param"] is None diff --git a/taskiq/cli/tests/test_receiver.py b/taskiq/cli/tests/test_receiver.py new file mode 100644 index 00000000..ae460377 --- /dev/null +++ b/taskiq/cli/tests/test_receiver.py @@ -0,0 +1,202 @@ +from typing import Any, Optional + +import pytest + +from taskiq.abc.broker import AsyncBroker +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.brokers.inmemory_broker import InMemoryBroker +from taskiq.cli.args import TaskiqArgs +from taskiq.cli.receiver import Receiver +from taskiq.message import BrokerMessage, TaskiqMessage +from taskiq.result import TaskiqResult + + +def get_receiver( + broker: Optional[AsyncBroker] = None, + no_parse: bool = False, +) -> Receiver: + """ + Returns receiver with custom broker and args. + + :param broker: broker, defaults to None + :param no_parse: parameter to taskiq_args, defaults to False + :return: new receiver. + """ + if broker is None: + broker = InMemoryBroker() + return Receiver( + broker, + TaskiqArgs( + broker="", + modules=[], + no_parse=no_parse, + ), + ) + + +@pytest.mark.anyio +async def test_run_task_successfull_async() -> None: + """Tests that run_task can run async tasks.""" + + async def test_func(param: int) -> int: + return param + + receiver = get_receiver() + + result = await receiver.run_task( + test_func, + TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[1], + kwargs={}, + ), + ) + + assert result.return_value == 1 + + +@pytest.mark.anyio +async def test_run_task_successfull_sync() -> None: + """Tests that run_task can run sync tasks.""" + + def test_func(param: int) -> int: + return param + + receiver = get_receiver() + + result = await receiver.run_task( + test_func, + TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[1], + kwargs={}, + ), + ) + assert result.return_value == 1 + + +@pytest.mark.anyio +async def test_run_task_exception() -> None: + """Tests that run_task can run sync tasks.""" + + def test_func() -> None: + raise ValueError() + + receiver = get_receiver() + + result = await receiver.run_task( + test_func, + TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[], + kwargs={}, + ), + ) + assert result.return_value is None + assert result.is_err + + +@pytest.mark.anyio +async def test_run_task_exception_middlewares() -> None: + """Tests that run_task can run sync tasks.""" + + class _TestMiddleware(TaskiqMiddleware): + found_exceptions = [] + + def on_error( + self, + message: "TaskiqMessage", + result: "TaskiqResult[Any]", + exception: Exception, + ) -> None: + self.found_exceptions.append(exception) + + def test_func() -> None: + raise ValueError() + + broker = InMemoryBroker() + broker.add_middlewares([_TestMiddleware()]) + receiver = get_receiver(broker) + + result = await receiver.run_task( + test_func, + TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[], + kwargs={}, + ), + ) + assert result.return_value is None + assert result.is_err + assert len(_TestMiddleware.found_exceptions) == 1 + assert _TestMiddleware.found_exceptions[0].__class__ == ValueError + + +@pytest.mark.anyio +async def test_callback_success() -> None: + """Test that callback funcion works well.""" + broker = InMemoryBroker() + called_times = 0 + + @broker.task + async def my_task() -> int: + nonlocal called_times # noqa: WPS420 + called_times += 1 + return 1 + + receiver = get_receiver(broker) + + broker_message = broker.formatter.dumps( + TaskiqMessage( + task_id="task_id", + task_name=my_task.task_name, + labels={}, + args=[], + kwargs=[], + ), + ) + + await receiver.callback(broker_message) + assert called_times == 1 + + +@pytest.mark.anyio +async def test_callback_wrong_format() -> None: + """Test that wrong format of a message won't thow an error.""" + receiver = get_receiver() + + await receiver.callback( + BrokerMessage( + task_id="", + task_name="my_task.task_name", + message='{"aaaa": "bbb"}', + labels={}, + ), + ) + + +@pytest.mark.anyio +async def test_callback_unknown_task() -> None: + """Tests that running an unknown task won't throw an error.""" + broker = InMemoryBroker() + receiver = get_receiver(broker) + + broker_message = broker.formatter.dumps( + TaskiqMessage( + task_id="task_id", + task_name="unknown", + labels={}, + args=[], + kwargs=[], + ), + ) + + await receiver.callback(broker_message) diff --git a/taskiq/cli/watcher.py b/taskiq/cli/watcher.py index 2fc7d5e6..30bafe3f 100644 --- a/taskiq/cli/watcher.py +++ b/taskiq/cli/watcher.py @@ -5,7 +5,7 @@ from watchdog.events import FileSystemEvent -class FileWatcher: +class FileWatcher: # pragma: no cover """Filewatcher that watchs for filesystem changes.""" def __init__( diff --git a/taskiq/tests/conftest.py b/taskiq/conftest.py similarity index 100% rename from taskiq/tests/conftest.py rename to taskiq/conftest.py diff --git a/taskiq/result_backends/dummy.py b/taskiq/result_backends/dummy.py index 74edf6e4..dc5cc7f3 100644 --- a/taskiq/result_backends/dummy.py +++ b/taskiq/result_backends/dummy.py @@ -6,7 +6,7 @@ _ReturnType = TypeVar("_ReturnType") -class DummyResultBackend(AsyncResultBackend[_ReturnType]): +class DummyResultBackend(AsyncResultBackend[_ReturnType]): # pragma: no cover """Default result backend, that does nothing.""" async def set_result(self, task_id: str, result: Any) -> None: