diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index a0b89182..4903c580 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -1,6 +1,6 @@ import inspect from collections import OrderedDict -from typing import Any, Callable, Coroutine, Optional, TypeVar +from typing import Any, Callable, Coroutine, Optional, TypeVar, get_type_hints from taskiq.abc.broker import AsyncBroker from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult @@ -128,6 +128,10 @@ async def kick(self, message: BrokerMessage) -> None: self.receiver.task_signatures[target_task.task_name] = inspect.signature( target_task.original_func, ) + if not self.receiver.task_hints.get(target_task.task_name): + self.receiver.task_hints[target_task.task_name] = get_type_hints( + target_task.original_func, + ) await self.receiver.callback(message=message) diff --git a/taskiq/cli/params_parser.py b/taskiq/cli/params_parser.py index e74e7164..3f862fc1 100644 --- a/taskiq/cli/params_parser.py +++ b/taskiq/cli/params_parser.py @@ -1,6 +1,6 @@ import inspect from logging import getLogger -from typing import Optional +from typing import Any, Dict, Optional from pydantic import parse_obj_as @@ -11,6 +11,7 @@ def parse_params( # noqa: C901 signature: Optional[inspect.Signature], + type_hints: Dict[str, Any], message: TaskiqMessage, ) -> None: """ @@ -42,25 +43,30 @@ def parse_params( # noqa: C901 or you can make some of parameters untyped, or use Any. + Why do we need type_hints separate with + Signature. The reason is simple. + If some variable doesn't have a type hint + it won't be added in the dict of type hints. + :param signature: original function's signature. + :param type_hints: function's type hints. :param message: incoming message. """ if signature is None: return argnum = -1 # Iterate over function's params. - for param_name, params_type in signature.parameters.items(): + for param_name in signature.parameters.keys(): # If parameter doesn't have an annotation. - if params_type.annotation is params_type.empty: + annot = type_hints.get(param_name) + if annot is None: continue # Increment argument numbers. This is # for positional arguments. argnum += 1 - # Shortland for params_type.annotation - annot = params_type.annotation # Value from incoming message. value = None - logger.debug("Trying to parse %s as %s", param_name, params_type.annotation) + logger.debug("Trying to parse %s as %s", param_name, annot) # Check if we have positional arguments in passed message. if argnum < len(message.args): # Get positional argument. diff --git a/taskiq/cli/receiver.py b/taskiq/cli/receiver.py index dd89d2dc..70ed265f 100644 --- a/taskiq/cli/receiver.py +++ b/taskiq/cli/receiver.py @@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor from logging import getLogger from time import time -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, get_type_hints from taskiq.abc.broker import AsyncBroker from taskiq.abc.middleware import TaskiqMiddleware @@ -20,7 +20,7 @@ def inject_context( - signature: Optional[inspect.Signature], + type_hints: Dict[str, Any], message: TaskiqMessage, broker: AsyncBroker, ) -> None: @@ -33,16 +33,14 @@ def inject_context( If at least one parameter has the Context type, it will add current context as kwarg. - :param signature: function's signature. + :param type_hints: function's type hints. :param message: current taskiq message. :param broker: current broker. """ - if signature is None: + if not type_hints: return - for param_name, param in signature.parameters.items(): - if param.annotation is param.empty: - continue - if param.annotation is Context: + for param_name, param_type in type_hints.items(): + if param_type is Context: message.kwargs[param_name] = Context(message.copy(), broker) @@ -67,8 +65,10 @@ def __init__(self, broker: AsyncBroker, cli_args: TaskiqArgs) -> None: self.broker = broker self.cli_args = cli_args self.task_signatures: Dict[str, inspect.Signature] = {} + self.task_hints: Dict[str, Dict[str, Any]] = {} for task in self.broker.available_tasks.values(): self.task_signatures[task.task_name] = inspect.signature(task.original_func) + self.task_hints[task.task_name] = get_type_hints(task.original_func) self.executor = ThreadPoolExecutor( max_workers=cli_args.max_threadpool_threads, ) @@ -173,9 +173,9 @@ async def run_task( # noqa: C901, WPS210 signature = self.task_signatures.get(message.task_name) if self.cli_args.no_parse: signature = None - parse_params(signature, message) + parse_params(signature, self.task_hints.get(message.task_name) or {}, message) inject_context( - self.task_signatures.get(message.task_name), + self.task_hints.get(message.task_name) or {}, message, self.broker, ) diff --git a/taskiq/cli/tests/test_context.py b/taskiq/cli/tests/test_context.py index f38a91a3..1157afd1 100644 --- a/taskiq/cli/tests/test_context.py +++ b/taskiq/cli/tests/test_context.py @@ -1,4 +1,4 @@ -import inspect +from typing import get_type_hints from taskiq.cli.receiver import inject_context from taskiq.context import Context @@ -20,7 +20,36 @@ def func(param1: int, ctx: Context) -> int: ) inject_context( - inspect.signature(func), + get_type_hints(func), + message=message, + broker=None, # type: ignore + ) + + assert message.kwargs.get("ctx") + assert isinstance(message.kwargs["ctx"], Context) + + +def test_inject_context_success_string_annotation() -> None: + """ + Test that context variable is injected as expected. + + This test checks that if Context was provided as + string, then everything is work as expected. + """ + + def func(param1: int, ctx: "Context") -> int: + return param1 + + message = TaskiqMessage( + task_id="", + task_name="", + labels={}, + args=[1], + kwargs={}, + ) + + inject_context( + get_type_hints(func), message=message, broker=None, # type: ignore ) @@ -44,7 +73,7 @@ def func(param1: int, ctx) -> int: # type: ignore ) inject_context( - inspect.signature(func), + get_type_hints(func), message=message, broker=None, # type: ignore ) @@ -71,7 +100,7 @@ def func(param1: int) -> int: ) inject_context( - inspect.signature(func), + get_type_hints(func), message=message, broker=None, # type: ignore ) diff --git a/taskiq/cli/tests/test_parameters_parsing.py b/taskiq/cli/tests/test_parameters_parsing.py index a05a9888..4f45b32a 100644 --- a/taskiq/cli/tests/test_parameters_parsing.py +++ b/taskiq/cli/tests/test_parameters_parsing.py @@ -1,6 +1,6 @@ import inspect from dataclasses import dataclass -from typing import Any, Type +from typing import Any, Type, get_type_hints import pytest from pydantic import BaseModel @@ -30,6 +30,7 @@ def test_parse_params_no_signature() -> None: modify_msg = src_msg.copy(deep=True) parse_params( signature=None, + type_hints={}, message=modify_msg, ) @@ -51,7 +52,11 @@ def test_func(param: test_class) -> test_class: # type: ignore kwargs={}, ) - parse_params(inspect.signature(test_func), msg_with_args) + parse_params( + inspect.signature(test_func), + get_type_hints(test_func), + msg_with_args, + ) assert isinstance(msg_with_args.args[0], test_class) assert msg_with_args.args[0].field == "test_val" @@ -64,7 +69,11 @@ def test_func(param: test_class) -> test_class: # type: ignore kwargs={"param": {"field": "test_val"}}, ) - parse_params(inspect.signature(test_func), msg_with_kwargs) + parse_params( + inspect.signature(test_func), + get_type_hints(test_func), + msg_with_kwargs, + ) assert isinstance(msg_with_kwargs.kwargs["param"], test_class) assert msg_with_kwargs.kwargs["param"].field == "test_val" @@ -85,7 +94,11 @@ def test_func(param: test_class) -> test_class: # type: ignore kwargs={}, ) - parse_params(inspect.signature(test_func), msg_with_args) + parse_params( + inspect.signature(test_func), + get_type_hints(test_func), + msg_with_args, + ) assert isinstance(msg_with_args.args[0], dict) @@ -97,7 +110,11 @@ def test_func(param: test_class) -> test_class: # type: ignore kwargs={"param": {"unknown": "unknown"}}, ) - parse_params(inspect.signature(test_func), msg_with_kwargs) + parse_params( + inspect.signature(test_func), + get_type_hints(test_func), + msg_with_kwargs, + ) assert isinstance(msg_with_kwargs.kwargs["param"], dict) @@ -117,7 +134,7 @@ def test_func(param: test_class) -> test_class: # type: ignore kwargs={}, ) - parse_params(inspect.signature(test_func), msg_with_args) + parse_params(inspect.signature(test_func), get_type_hints(test_func), msg_with_args) assert msg_with_args.args[0] is None @@ -129,6 +146,10 @@ def test_func(param: test_class) -> test_class: # type: ignore kwargs={"param": None}, ) - parse_params(inspect.signature(test_func), msg_with_kwargs) + parse_params( + inspect.signature(test_func), + get_type_hints(test_func), + msg_with_kwargs, + ) assert msg_with_kwargs.kwargs["param"] is None