diff --git a/docs/examples/introduction/aio_pika_broker.py b/docs/examples/introduction/aio_pika_broker.py index 452340ea..980361db 100644 --- a/docs/examples/introduction/aio_pika_broker.py +++ b/docs/examples/introduction/aio_pika_broker.py @@ -22,6 +22,7 @@ async def main() -> None: print(f"Returned value: {result.return_value}") else: print("Error found while executing task.") + await broker.shutdown() if __name__ == "__main__": diff --git a/docs/examples/introduction/full_example.py b/docs/examples/introduction/full_example.py index 47ed3f2e..1a7e515e 100644 --- a/docs/examples/introduction/full_example.py +++ b/docs/examples/introduction/full_example.py @@ -26,6 +26,7 @@ async def main() -> None: print(f"Returned value: {result.return_value}") else: print("Error found while executing task.") + await broker.shutdown() if __name__ == "__main__": diff --git a/docs/examples/introduction/inmemory_run.py b/docs/examples/introduction/inmemory_run.py index 2bff4954..1e814563 100644 --- a/docs/examples/introduction/inmemory_run.py +++ b/docs/examples/introduction/inmemory_run.py @@ -12,6 +12,7 @@ async def add_one(value: int) -> int: async def main() -> None: + await broker.startup() # Send the task to the broker. task = await add_one.kiq(1) # Wait for the result. @@ -21,6 +22,7 @@ async def main() -> None: print(f"Returned value: {result.return_value}") else: print("Error found while executing task.") + await broker.shutdown() if __name__ == "__main__": diff --git a/docs/examples/state/events_example.py b/docs/examples/state/events_example.py new file mode 100644 index 00000000..50c45abd --- /dev/null +++ b/docs/examples/state/events_example.py @@ -0,0 +1,66 @@ +import asyncio +from typing import Optional + +from redis.asyncio import ConnectionPool, Redis # type: ignore +from taskiq_aio_pika import AioPikaBroker +from taskiq_redis import RedisAsyncResultBackend + +from taskiq import Context, TaskiqEvents, TaskiqState +from taskiq.context import default_context + +# To run this example, please install: +# * taskiq +# * taskiq-redis +# * taskiq-aio-pika + +broker = AioPikaBroker( + "amqp://localhost", + result_backend=RedisAsyncResultBackend( + "redis://localhost/0", + ), +) + + +@broker.on_event(TaskiqEvents.WORKER_STARTUP) +async def startup(state: TaskiqState) -> None: + # Here we store connection pool on startup for later use. + state.redis = ConnectionPool.from_url("redis://localhost/1") + + +@broker.on_event(TaskiqEvents.WORKER_SHUTDOWN) +async def shutdown(state: TaskiqState) -> None: + # Here we close our pool on shutdown event. + await state.redis.disconnect() + + +@broker.task +async def get_val(key: str, context: Context = default_context) -> Optional[str]: + # Now we can use our pool. + redis = Redis(connection_pool=context.state.redis, decode_responses=True) + return await redis.get(key) + + +@broker.task +async def set_val(key: str, value: str, context: Context = default_context) -> None: + # Now we can use our pool to set value. + await Redis(connection_pool=context.state.redis).set(key, value) + + +async def main() -> None: + await broker.startup() + + set_task = await set_val.kiq("key", "value") + set_result = await set_task.wait_result(with_logs=True) + if set_result.is_err: + print(set_result.log) + raise ValueError("Cannot set value in redis. See logs.") + + get_task = await get_val.kiq("key") + get_res = await get_task.wait_result() + print(f"Got redis value: {get_res.return_value}") + + await broker.shutdown() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/docs/guide/getting-started.md b/docs/guide/getting-started.md index d5a4c5cd..0c2abdfc 100644 --- a/docs/guide/getting-started.md +++ b/docs/guide/getting-started.md @@ -54,10 +54,16 @@ from taskiq import InMemoryBroker broker = InMemoryBroker() ``` -And that's it. Now let's add some tasks and the main function. You can add tasks in separate modules. You can find more information about that further. +And that's it. Now let's add some tasks and the main function. You can add tasks in separate modules. You can find more information about that further. Also, we call the `startup` method at the beginning of the `main` function. @[code python](../examples/introduction/inmemory_run.py) +::: tip Cool tip! + +Calling the `startup` method is not required, but we strongly recommend you do so. + +::: + If you run this code, you will get this in your terminal: ```bash:no-line-numbers diff --git a/docs/guide/scheduling-tasks.md b/docs/guide/scheduling-tasks.md index 571eae0b..3857a263 100644 --- a/docs/guide/scheduling-tasks.md +++ b/docs/guide/scheduling-tasks.md @@ -1,5 +1,5 @@ --- -order: 7 +order: 8 --- # Scheduling tasks diff --git a/docs/guide/state-and-events.md b/docs/guide/state-and-events.md new file mode 100644 index 00000000..e5609cd9 --- /dev/null +++ b/docs/guide/state-and-events.md @@ -0,0 +1,32 @@ +--- +order: 7 +--- + +# State and events + +The `TaskiqState` is a global variable where you can keep the variables you want to use later. +For example, you want to open a database connection pool at a broker's startup. + +This can be acieved by adding event handlers. + +You can use one of these events: +* `WORKER_STARTUP` +* `CLIENT_STARTUP` +* `WORKER_SHUTDOWN` +* `CLIENT_SHUTDOWN` + +Worker events are called when you start listening to the broker messages using taskiq. +Client events are called when you call the `startup` method of your broker from your code. + +This is an example of code using event handlers: + +@[code python](../examples/state/events_example.py) + +::: tip Cool tip! + +If you want to add handlers programmatically, you can use the `broker.add_event_handler` function. + +::: + +As you can see in this example, this worker will initialize the Redis pool at the startup. +You can access the state from the context. diff --git a/pyproject.toml b/pyproject.toml index 737c7b50..3774b8f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,8 @@ authors = ["Pavel Kirilin "] maintainers = ["Pavel Kirilin "] readme = "README.md" repository = "https://github.com/taskiq-python/taskiq" +homepage = "https://taskiq-python.github.io/" +documentation = "https://taskiq-python.github.io/" license = "LICENSE" classifiers = [ "Typing :: Typed", @@ -21,7 +23,6 @@ classifiers = [ "Topic :: System :: Networking", "Development Status :: 3 - Alpha", ] -homepage = "https://github.com/taskiq-python/taskiq" keywords = ["taskiq", "tasks", "distributed", "async"] [tool.poetry.dependencies] diff --git a/taskiq/__init__.py b/taskiq/__init__.py index 4a12354c..9199d049 100644 --- a/taskiq/__init__.py +++ b/taskiq/__init__.py @@ -8,11 +8,13 @@ from taskiq.brokers.shared_broker import async_shared_broker from taskiq.brokers.zmq_broker import ZeroMQBroker from taskiq.context import Context +from taskiq.events import TaskiqEvents from taskiq.exceptions import TaskiqError from taskiq.funcs import gather from taskiq.message import BrokerMessage, TaskiqMessage from taskiq.result import TaskiqResult from taskiq.scheduler import ScheduledTask, TaskiqScheduler +from taskiq.state import TaskiqState from taskiq.task import AsyncTaskiqTask __all__ = [ @@ -20,8 +22,10 @@ "Context", "AsyncBroker", "TaskiqError", + "TaskiqState", "TaskiqResult", "ZeroMQBroker", + "TaskiqEvents", "TaskiqMessage", "BrokerMessage", "InMemoryBroker", diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index 431013b9..41813f46 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -1,14 +1,16 @@ -import inspect import os import sys from abc import ABC, abstractmethod +from collections import defaultdict from functools import wraps from logging import getLogger from typing import ( # noqa: WPS235 TYPE_CHECKING, Any, + Awaitable, Callable, Coroutine, + DefaultDict, Dict, List, Optional, @@ -18,22 +20,27 @@ ) from uuid import uuid4 -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, TypeAlias +from taskiq.abc.middleware import TaskiqMiddleware from taskiq.decor import AsyncTaskiqDecoratedTask +from taskiq.events import TaskiqEvents from taskiq.formatters.json_formatter import JSONFormatter from taskiq.message import BrokerMessage from taskiq.result_backends.dummy import DummyResultBackend +from taskiq.state import TaskiqState +from taskiq.utils import maybe_awaitable -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from taskiq.abc.formatter import TaskiqFormatter - from taskiq.abc.middleware import TaskiqMiddleware from taskiq.abc.result_backend import AsyncResultBackend _T = TypeVar("_T") # noqa: WPS111 _FuncParams = ParamSpec("_FuncParams") _ReturnType = TypeVar("_ReturnType") +EventHandler: TypeAlias = Callable[[TaskiqState], Optional[Awaitable[None]]] + logger = getLogger("taskiq") @@ -49,7 +56,7 @@ def default_id_generator() -> str: return uuid4().hex -class AsyncBroker(ABC): +class AsyncBroker(ABC): # noqa: WPS230 """ Async broker. @@ -75,8 +82,16 @@ def __init__( self.decorator_class = AsyncTaskiqDecoratedTask self.formatter: "TaskiqFormatter" = JSONFormatter() self.id_generator = task_id_generator - - def add_middlewares(self, middlewares: "List[TaskiqMiddleware]") -> None: + # Every event has a list of handlers. + # Every handler is a function which takes state as a first argument. + # And handler can be either sync or async. + self.event_handlers: DefaultDict[ # noqa: WPS234 + TaskiqEvents, + List[Callable[[TaskiqState], Optional[Awaitable[None]]]], + ] = defaultdict(list) + self.state = TaskiqState() + + def add_middlewares(self, *middlewares: "TaskiqMiddleware") -> None: """ Add a list of middlewares. @@ -86,11 +101,23 @@ def add_middlewares(self, middlewares: "List[TaskiqMiddleware]") -> None: :param middlewares: list of middlewares. """ for middleware in middlewares: + if not isinstance(middleware, TaskiqMiddleware): + logger.warning( + f"Middleware {middleware} is not an instance of TaskiqMiddleware. " + "Skipping...", + ) + continue middleware.set_broker(self) self.middlewares.append(middleware) async def startup(self) -> None: """Do something when starting broker.""" + event = TaskiqEvents.CLIENT_STARTUP + if self.is_worker_process: + event = TaskiqEvents.WORKER_STARTUP + + for handler in self.event_handlers[event]: + await maybe_awaitable(handler(self.state)) async def shutdown(self) -> None: """ @@ -99,11 +126,13 @@ async def shutdown(self) -> None: This method is called, when broker is closig. """ - for middleware in self.middlewares: - middleware_shutdown = middleware.shutdown() - if inspect.isawaitable(middleware_shutdown): - await middleware_shutdown - await self.result_backend.shutdown() + event = TaskiqEvents.CLIENT_SHUTDOWN + if self.is_worker_process: + event = TaskiqEvents.WORKER_SHUTDOWN + + # Call all shutdown events. + for handler in self.event_handlers[event]: + await maybe_awaitable(handler(self.state)) @abstractmethod async def kick( @@ -232,3 +261,43 @@ def inner( inner_task_name=task_name, inner_labels=labels or {}, ) + + def on_event(self, *events: TaskiqEvents) -> Callable[[EventHandler], EventHandler]: + """ + Adds event handler. + + This function adds function to call when event occurs. + + :param events: events to react to. + :return: a decorator function. + """ + + def handler(function: EventHandler) -> EventHandler: + for event in events: + self.event_handlers[event].append(function) + return function + + return handler + + def add_event_handler( + self, + event: TaskiqEvents, + handler: EventHandler, + ) -> None: + """ + Adds event handler. + + this function is the same as on_event. + + >>> broker.add_event_handler(TaskiqEvents.WORKER_STARTUP, my_startup) + + if similar to: + + >>> @broker.on_event(TaskiqEvents.WORKER_STARTUP) + >>> async def my_startup(context: Context) -> None: + >>> ... + + :param event: Event to react to. + :param handler: handler to call when event is started. + """ + self.event_handlers[event].append(handler) diff --git a/taskiq/abc/middleware.py b/taskiq/abc/middleware.py index 5e8df6a4..ffff7d88 100644 --- a/taskiq/abc/middleware.py +++ b/taskiq/abc/middleware.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Any, Coroutine, Union -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover # pragma: no cover from taskiq.abc.broker import AsyncBroker from taskiq.message import TaskiqMessage from taskiq.result import TaskiqResult @@ -20,9 +20,6 @@ def set_broker(self, broker: "AsyncBroker") -> None: """ self.broker = broker - def shutdown(self) -> Union[None, Coroutine[Any, Any, None]]: - """This function is used to do some work on shutdown.""" - def pre_send( self, message: "TaskiqMessage", diff --git a/taskiq/abc/schedule_source.py b/taskiq/abc/schedule_source.py index f4b35505..7b732e49 100644 --- a/taskiq/abc/schedule_source.py +++ b/taskiq/abc/schedule_source.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, List -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from taskiq.scheduler.scheduler import ScheduledTask diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index ae1a9fcc..b83e4ef1 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -6,8 +6,10 @@ from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult from taskiq.cli.worker.args import WorkerArgs from taskiq.cli.worker.receiver import Receiver +from taskiq.events import TaskiqEvents from taskiq.exceptions import TaskiqError from taskiq.message import BrokerMessage +from taskiq.utils import maybe_awaitable _ReturnType = TypeVar("_ReturnType") @@ -149,3 +151,15 @@ async def listen( :raises RuntimeError: if this method is called. """ raise RuntimeError("Inmemory brokers cannot listen.") + + async def startup(self) -> None: + """Runs startup events for client and worker side.""" + for event in (TaskiqEvents.CLIENT_STARTUP, TaskiqEvents.WORKER_STARTUP): + for handler in self.event_handlers.get(event, []): + await maybe_awaitable(handler(self.state)) + + async def shutdown(self) -> None: + """Runs shutdown events for client and worker side.""" + for event in (TaskiqEvents.CLIENT_SHUTDOWN, TaskiqEvents.WORKER_SHUTDOWN): + for handler in self.event_handlers.get(event, []): + await maybe_awaitable(handler(self.state)) diff --git a/taskiq/cli/worker/tests/test_receiver.py b/taskiq/cli/worker/tests/test_receiver.py index 271d38e2..f917c599 100644 --- a/taskiq/cli/worker/tests/test_receiver.py +++ b/taskiq/cli/worker/tests/test_receiver.py @@ -121,7 +121,7 @@ def test_func() -> None: raise ValueError() broker = InMemoryBroker() - broker.add_middlewares([_TestMiddleware()]) + broker.add_middlewares(_TestMiddleware()) receiver = get_receiver(broker) result = await receiver.run_task( diff --git a/taskiq/context.py b/taskiq/context.py index 05f189b1..1251da3c 100644 --- a/taskiq/context.py +++ b/taskiq/context.py @@ -1,6 +1,11 @@ +from typing import TYPE_CHECKING + from taskiq.abc.broker import AsyncBroker from taskiq.message import TaskiqMessage +if TYPE_CHECKING: # pragma: no cover + from taskiq.state import TaskiqState + class Context: """Context class.""" @@ -8,6 +13,9 @@ class Context: def __init__(self, message: TaskiqMessage, broker: AsyncBroker) -> None: self.message = message self.broker = broker + self.state: "TaskiqState" = None # type: ignore + if broker: + self.state = broker.state default_context = Context(None, None) # type: ignore diff --git a/taskiq/decor.py b/taskiq/decor.py index a6f397b9..cb42f6ba 100644 --- a/taskiq/decor.py +++ b/taskiq/decor.py @@ -14,7 +14,7 @@ from taskiq.kicker import AsyncKicker from taskiq.task import AsyncTaskiqTask, SyncTaskiqTask -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from taskiq.abc.broker import AsyncBroker _T = TypeVar("_T") # noqa: WPS111 diff --git a/taskiq/events.py b/taskiq/events.py new file mode 100644 index 00000000..854cc99d --- /dev/null +++ b/taskiq/events.py @@ -0,0 +1,20 @@ +import enum + + +@enum.unique +class TaskiqEvents(enum.Enum): + """List of taskiq broker lifetime events.""" + + # Worker events. + + # Called on woker startup. + WORKER_STARTUP = "WORKER_STARTUP" + # Called o worker shutdown. + WORKER_SHUTDOWN = "WORKER_SHUTDOWN" + + # Client events. + + # Called when startup function is called from the client's code. + CLIENT_STARTUP = "CLIENT_STARTUP" + # Called if shutdown function was called from the client's code. + CLIENT_SHUTDOWN = "CLIENT_SHUTDOWN" diff --git a/taskiq/kicker.py b/taskiq/kicker.py index dd422c95..f6676e72 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -21,7 +21,7 @@ from taskiq.task import AsyncTaskiqTask, SyncTaskiqTask from taskiq.utils import maybe_awaitable, run_sync -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from taskiq.abc.broker import AsyncBroker _T = TypeVar("_T") # noqa: WPS111 diff --git a/taskiq/scheduler/merge_functions.py b/taskiq/scheduler/merge_functions.py index ac3b7a1e..331857c7 100644 --- a/taskiq/scheduler/merge_functions.py +++ b/taskiq/scheduler/merge_functions.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, List -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from taskiq.scheduler.scheduler import ScheduledTask diff --git a/taskiq/scheduler/scheduler.py b/taskiq/scheduler/scheduler.py index d94b1d86..095f08e3 100644 --- a/taskiq/scheduler/scheduler.py +++ b/taskiq/scheduler/scheduler.py @@ -4,7 +4,7 @@ from taskiq.abc.broker import AsyncBroker from taskiq.scheduler.merge_functions import preserve_all -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from taskiq.abc.schedule_source import ScheduleSource diff --git a/taskiq/state.py b/taskiq/state.py new file mode 100644 index 00000000..2947b237 --- /dev/null +++ b/taskiq/state.py @@ -0,0 +1,39 @@ +from collections import UserDict +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: # pragma: no cover + _Base = UserDict[str, Any] +else: + _Base = UserDict + + +class TaskiqState(_Base): + """ + State class. + + This class is used to store useful variables + for later use. + """ + + def __init__(self) -> None: + self.__dict__["data"] = {} + + def __getattr__(self, name: str) -> Any: + try: + return self.__dict__["data"][name] + except KeyError: + cls_name = self.__class__.__name__ + raise AttributeError(f"'{cls_name}' object has no attribute '{name}'") + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: # noqa: WPS603 + try: + del self[name] # noqa: WPS420 + except KeyError: + cls_name = self.__class__.__name__ + raise AttributeError(f"'{cls_name}' object has no attribute '{name}'") + + def __str__(self) -> str: + return "TaskiqState(%s)" % super().__str__() diff --git a/taskiq/task.py b/taskiq/task.py index 54439b58..29d01dea 100644 --- a/taskiq/task.py +++ b/taskiq/task.py @@ -10,7 +10,7 @@ ) from taskiq.utils import run_sync -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from taskiq.abc.result_backend import AsyncResultBackend from taskiq.result import TaskiqResult diff --git a/taskiq/tests/test_state.py b/taskiq/tests/test_state.py new file mode 100644 index 00000000..07c64e92 --- /dev/null +++ b/taskiq/tests/test_state.py @@ -0,0 +1,58 @@ +from taskiq.state import TaskiqState + + +def test_state_set() -> None: + """Tests that you can sel values as dict items.""" + state = TaskiqState() + state["a"] = 1 + + assert state["a"] == 1 + + +def test_state_get() -> None: + """Tests that you can get values as dict items.""" + state = TaskiqState() + + state["a"] = 1 + + assert state["a"] == 1 + + +def test_state_del() -> None: + """Tests that you can del values as dict items.""" + state = TaskiqState() + + state["a"] = 1 + + del state["a"] # noqa: WPS420 + + assert state.get("a") is None + + +def test_state_set_attr() -> None: + """Tests that you can set values by attribute.""" + state = TaskiqState() + + state.a = 1 + + assert state["a"] == 1 + + +def test_state_get_attr() -> None: + """Tests that you can get values by attribute.""" + state = TaskiqState() + + state["a"] = 1 + + assert state.a == 1 + + +def test_state_del_attr() -> None: + """Tests that you can delete values by attribute.""" + state = TaskiqState() + + state["a"] = 1 + + del state.a # noqa: WPS420 + + assert state.get("a") is None