Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion taskiq/brokers/inmemory_broker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from collections import OrderedDict
from typing import Any, Callable, Coroutine, Optional, TypeVar
from typing import Any, Callable, Coroutine, Optional, TypeVar, get_type_hints

from taskiq.abc.broker import AsyncBroker
from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult
Expand Down Expand Up @@ -128,6 +128,10 @@ async def kick(self, message: BrokerMessage) -> None:
self.receiver.task_signatures[target_task.task_name] = inspect.signature(
target_task.original_func,
)
if not self.receiver.task_hints.get(target_task.task_name):
self.receiver.task_hints[target_task.task_name] = get_type_hints(
target_task.original_func,
)

await self.receiver.callback(message=message)

Expand Down
18 changes: 12 additions & 6 deletions taskiq/cli/params_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from logging import getLogger
from typing import Optional
from typing import Any, Dict, Optional

from pydantic import parse_obj_as

Expand All @@ -11,6 +11,7 @@

def parse_params( # noqa: C901
signature: Optional[inspect.Signature],
type_hints: Dict[str, Any],
message: TaskiqMessage,
) -> None:
"""
Expand Down Expand Up @@ -42,25 +43,30 @@ def parse_params( # noqa: C901
or you can make some of parameters untyped,
or use Any.

Why do we need type_hints separate with
Signature. The reason is simple.
If some variable doesn't have a type hint
it won't be added in the dict of type hints.

:param signature: original function's signature.
:param type_hints: function's type hints.
:param message: incoming message.
"""
if signature is None:
return
argnum = -1
# Iterate over function's params.
for param_name, params_type in signature.parameters.items():
for param_name in signature.parameters.keys():
# If parameter doesn't have an annotation.
if params_type.annotation is params_type.empty:
annot = type_hints.get(param_name)
if annot is None:
continue
# Increment argument numbers. This is
# for positional arguments.
argnum += 1
# Shortland for params_type.annotation
annot = params_type.annotation
# Value from incoming message.
value = None
logger.debug("Trying to parse %s as %s", param_name, params_type.annotation)
logger.debug("Trying to parse %s as %s", param_name, annot)
# Check if we have positional arguments in passed message.
if argnum < len(message.args):
# Get positional argument.
Expand Down
20 changes: 10 additions & 10 deletions taskiq/cli/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from time import time
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, get_type_hints

from taskiq.abc.broker import AsyncBroker
from taskiq.abc.middleware import TaskiqMiddleware
Expand All @@ -20,7 +20,7 @@


def inject_context(
signature: Optional[inspect.Signature],
type_hints: Dict[str, Any],
message: TaskiqMessage,
broker: AsyncBroker,
) -> None:
Expand All @@ -33,16 +33,14 @@ def inject_context(
If at least one parameter has the Context
type, it will add current context as kwarg.

:param signature: function's signature.
:param type_hints: function's type hints.
:param message: current taskiq message.
:param broker: current broker.
"""
if signature is None:
if not type_hints:
return
for param_name, param in signature.parameters.items():
if param.annotation is param.empty:
continue
if param.annotation is Context:
for param_name, param_type in type_hints.items():
if param_type is Context:
message.kwargs[param_name] = Context(message.copy(), broker)


Expand All @@ -67,8 +65,10 @@ def __init__(self, broker: AsyncBroker, cli_args: TaskiqArgs) -> None:
self.broker = broker
self.cli_args = cli_args
self.task_signatures: Dict[str, inspect.Signature] = {}
self.task_hints: Dict[str, Dict[str, Any]] = {}
for task in self.broker.available_tasks.values():
self.task_signatures[task.task_name] = inspect.signature(task.original_func)
self.task_hints[task.task_name] = get_type_hints(task.original_func)
self.executor = ThreadPoolExecutor(
max_workers=cli_args.max_threadpool_threads,
)
Expand Down Expand Up @@ -173,9 +173,9 @@ async def run_task( # noqa: C901, WPS210
signature = self.task_signatures.get(message.task_name)
if self.cli_args.no_parse:
signature = None
parse_params(signature, message)
parse_params(signature, self.task_hints.get(message.task_name) or {}, message)
inject_context(
self.task_signatures.get(message.task_name),
self.task_hints.get(message.task_name) or {},
message,
self.broker,
)
Expand Down
37 changes: 33 additions & 4 deletions taskiq/cli/tests/test_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import inspect
from typing import get_type_hints

from taskiq.cli.receiver import inject_context
from taskiq.context import Context
Expand All @@ -20,7 +20,36 @@ def func(param1: int, ctx: Context) -> int:
)

inject_context(
inspect.signature(func),
get_type_hints(func),
message=message,
broker=None, # type: ignore
)

assert message.kwargs.get("ctx")
assert isinstance(message.kwargs["ctx"], Context)


def test_inject_context_success_string_annotation() -> None:
"""
Test that context variable is injected as expected.

This test checks that if Context was provided as
string, then everything is work as expected.
"""

def func(param1: int, ctx: "Context") -> int:
return param1

message = TaskiqMessage(
task_id="",
task_name="",
labels={},
args=[1],
kwargs={},
)

inject_context(
get_type_hints(func),
message=message,
broker=None, # type: ignore
)
Expand All @@ -44,7 +73,7 @@ def func(param1: int, ctx) -> int: # type: ignore
)

inject_context(
inspect.signature(func),
get_type_hints(func),
message=message,
broker=None, # type: ignore
)
Expand All @@ -71,7 +100,7 @@ def func(param1: int) -> int:
)

inject_context(
inspect.signature(func),
get_type_hints(func),
message=message,
broker=None, # type: ignore
)
Expand Down
35 changes: 28 additions & 7 deletions taskiq/cli/tests/test_parameters_parsing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from dataclasses import dataclass
from typing import Any, Type
from typing import Any, Type, get_type_hints

import pytest
from pydantic import BaseModel
Expand Down Expand Up @@ -30,6 +30,7 @@ def test_parse_params_no_signature() -> None:
modify_msg = src_msg.copy(deep=True)
parse_params(
signature=None,
type_hints={},
message=modify_msg,
)

Expand All @@ -51,7 +52,11 @@ def test_func(param: test_class) -> test_class: # type: ignore
kwargs={},
)

parse_params(inspect.signature(test_func), msg_with_args)
parse_params(
inspect.signature(test_func),
get_type_hints(test_func),
msg_with_args,
)

assert isinstance(msg_with_args.args[0], test_class)
assert msg_with_args.args[0].field == "test_val"
Expand All @@ -64,7 +69,11 @@ def test_func(param: test_class) -> test_class: # type: ignore
kwargs={"param": {"field": "test_val"}},
)

parse_params(inspect.signature(test_func), msg_with_kwargs)
parse_params(
inspect.signature(test_func),
get_type_hints(test_func),
msg_with_kwargs,
)

assert isinstance(msg_with_kwargs.kwargs["param"], test_class)
assert msg_with_kwargs.kwargs["param"].field == "test_val"
Expand All @@ -85,7 +94,11 @@ def test_func(param: test_class) -> test_class: # type: ignore
kwargs={},
)

parse_params(inspect.signature(test_func), msg_with_args)
parse_params(
inspect.signature(test_func),
get_type_hints(test_func),
msg_with_args,
)

assert isinstance(msg_with_args.args[0], dict)

Expand All @@ -97,7 +110,11 @@ def test_func(param: test_class) -> test_class: # type: ignore
kwargs={"param": {"unknown": "unknown"}},
)

parse_params(inspect.signature(test_func), msg_with_kwargs)
parse_params(
inspect.signature(test_func),
get_type_hints(test_func),
msg_with_kwargs,
)

assert isinstance(msg_with_kwargs.kwargs["param"], dict)

Expand All @@ -117,7 +134,7 @@ def test_func(param: test_class) -> test_class: # type: ignore
kwargs={},
)

parse_params(inspect.signature(test_func), msg_with_args)
parse_params(inspect.signature(test_func), get_type_hints(test_func), msg_with_args)

assert msg_with_args.args[0] is None

Expand All @@ -129,6 +146,10 @@ def test_func(param: test_class) -> test_class: # type: ignore
kwargs={"param": None},
)

parse_params(inspect.signature(test_func), msg_with_kwargs)
parse_params(
inspect.signature(test_func),
get_type_hints(test_func),
msg_with_kwargs,
)

assert msg_with_kwargs.kwargs["param"] is None