# Source code for dramatiq.brokers.stub

# This file is a part of Dramatiq.
#
# Copyright (C) 2017,2018,2019,2020 CLEARTYPE SRL <[email protected]>
#
# Dramatiq is free software; you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at
# your option) any later version.
#
# Dramatiq is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

import time
from collections import defaultdict
from itertools import chain
from queue import Empty, Queue
from typing import Optional

from ..broker import Broker, Consumer, MessageProxy
from ..common import current_millis, dq_name, iter_queue, join_queue
from ..errors import QueueNotFound
from ..message import Message
from ..middleware import Middleware


class StubBroker(Broker):
    """A broker that can be used within unit tests.

    Messages are stored in in-process :class:`queue.Queue` objects: one
    per declared queue plus one per matching delay queue.

    Parameters:
      middleware: See :class:`Broker<dramatiq.Broker>`.
      fail_fast_default: Specifies the default value for the ``fail_fast``
        argument of :meth:`join<dramatiq.brokers.stub.StubBroker.join>`.
    """

    def __init__(
        self,
        middleware: Optional[list[Middleware]] = None,
        *,
        fail_fast_default: bool = True,
    ):
        super().__init__(middleware)

        # Nacked messages, grouped by the name of the queue they were
        # consumed from.  Consumers append to these lists directly.
        self.dead_letters_by_queue: defaultdict[str, list[MessageProxy]] = defaultdict(list)
        self.fail_fast_default: bool = fail_fast_default

    @property
    def dead_letters(self) -> list[MessageProxy]:
        """The dead-lettered messages for all defined queues."""
        return list(chain.from_iterable(self.dead_letters_by_queue.values()))

    def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 100) -> Consumer:
        """Create a new consumer for a queue.

        Parameters:
          queue_name(str): The queue to consume.
          prefetch(int): The number of messages to prefetch.  Accepted
            for interface compatibility; this broker does not use it.
          timeout(int): The idle timeout in milliseconds.

        Raises:
          QueueNotFound: If the queue hasn't been declared.

        Returns:
          Consumer: A consumer that retrieves messages from the
          in-memory queue.
        """
        try:
            queue = self.queues[queue_name]
        except KeyError:
            raise QueueNotFound(queue_name) from None

        # The consumer appends nacked messages straight to this list, so
        # it has to be the very object stored in the mapping.
        return _StubConsumer(queue, self.dead_letters_by_queue[queue_name], timeout)

    def declare_queue(self, queue_name: str) -> None:
        """Declare a queue.  Has no effect if a queue with the given
        name has already been declared.

        The matching delay queue is created alongside the queue itself.

        Parameters:
          queue_name(str): The name of the new queue.
        """
        if queue_name in self.queues:
            return

        self.emit_before("declare_queue", queue_name)
        self.queues[queue_name] = Queue()
        self.emit_after("declare_queue", queue_name)

        delayed_name = dq_name(queue_name)
        self.queues[delayed_name] = Queue()
        self.delay_queues.add(delayed_name)
        self.emit_after("declare_delay_queue", delayed_name)

    def enqueue(self, message: Message, *, delay: Optional[int] = None) -> Message:
        """Enqueue a message.

        Parameters:
          message(Message): The message to enqueue.
          delay(int): The minimum amount of time, in milliseconds, to
            delay the message by.

        Raises:
          QueueNotFound: If the queue the message is being enqueued on
            doesn't exist.

        Returns:
          Message: The message that was put on the queue.  When a delay
          is given this is a copy that targets the delay queue and
          carries an ``eta`` option.
        """
        queue_name = message.queue_name
        if delay is not None:
            # Delayed messages go to the delay queue together with the
            # time (in milliseconds) at which they become due.
            queue_name = dq_name(queue_name)
            message_eta = current_millis() + delay
            message = message.copy(
                queue_name=queue_name,
                options={
                    "eta": message_eta,
                },
            )

        if queue_name not in self.queues:
            raise QueueNotFound(queue_name)

        self.emit_before("enqueue", message, delay)
        self.queues[queue_name].put(message.encode())
        self.emit_after("enqueue", message, delay)
        return message

    def flush(self, queue_name: str) -> None:
        """Drop all the messages from a queue.

        Parameters:
          queue_name(str): The queue to flush.
        """
        # NOTE(review): an undeclared queue name surfaces here as whatever
        # the ``self.queues`` lookup raises (presumably KeyError), not as
        # QueueNotFound -- confirm before relying on it.
        queue = self.queues[queue_name]
        for _ in iter_queue(queue):
            # Every drained message is marked done so that the queue's
            # unfinished-task count stays in step with its contents.
            queue.task_done()

    def flush_all(self) -> None:
        """Drop all messages from all declared queues.

        The recorded dead letters are discarded as well.
        """
        for queue_name in chain(self.queues, self.delay_queues):
            self.flush(queue_name)

        # Empty the lists in place instead of clearing the mapping:
        # consumers created by ``consume`` hold references to these very
        # lists, and dropping them from the mapping would make every
        # later nack invisible to ``dead_letters`` and ``join``.
        for dead_letters in self.dead_letters_by_queue.values():
            dead_letters.clear()

    def join(self, queue_name: str, *, timeout: Optional[int] = None, fail_fast: Optional[bool] = None) -> None:
        """Wait for all the messages on the given queue to be
        processed.  This method is only meant to be used in tests to
        wait for all the messages in a queue to be processed.

        Raises:
          QueueJoinTimeout: When the timeout elapses.
          QueueNotFound: If the given queue was never declared.

        Parameters:
          queue_name(str): The queue to wait on.
          fail_fast(bool): When this is True and any message gets
            dead-lettered during the join, then an exception will be
            raised.  When False, no exception will be raised.
            Defaults to None, which means use the value of the
            ``fail_fast_default`` instance attribute (which defaults to
            True).
          timeout(Optional[int]): The max amount of time, in
            milliseconds, to wait on this queue.

        .. versionchanged:: 2.0.0
           The ``fail_fast`` parameter now defaults to
           ``self.fail_fast_default`` (which defaults to True).
        """
        try:
            queues = [
                self.queues[queue_name],
                self.queues[dq_name(queue_name)],
            ]
        except KeyError:
            raise QueueNotFound(queue_name) from None

        # ``timeout`` is in milliseconds; the monotonic clock is in seconds.
        deadline = None if timeout is None else time.monotonic() + timeout / 1000
        should_fail_fast = self.fail_fast_default if fail_fast is None else fail_fast
        while True:
            for queue in queues:
                join_timeout = None if deadline is None else deadline - time.monotonic()
                join_queue(queue, timeout=join_timeout)

            # We cycle through $queue then $queue.DQ then $queue
            # again in case the messages that were on the DQ got
            # moved back on $queue.
            if any(queue.unfinished_tasks for queue in queues):
                continue

            if should_fail_fast:
                dead_letters = self.dead_letters_by_queue[queue_name]
                if dead_letters:
                    # Surface the first recorded failure to the caller.
                    failure = dead_letters[0]._exception
                    raise (failure or Exception("Message failed with unknown error")) from None

            return
class _StubConsumer(Consumer):
    """Consumer that reads encoded messages from an in-memory queue.

    Parameters:
      queue: The ``queue.Queue`` to read from.
      dead_letters: The list that nacked messages are appended to.
      timeout: How long ``__next__`` waits for a message, in milliseconds.
    """

    def __init__(self, queue, dead_letters, timeout):
        self.queue = queue
        self.dead_letters = dead_letters
        self.timeout = timeout

    def ack(self, message):
        """Mark one message taken from the queue as done."""
        self.queue.task_done()

    def nack(self, message):
        """Record ``message`` as dead-lettered and mark it as done."""
        # Record the dead letter *before* signalling completion: a thread
        # blocked in StubBroker.join may wake up as soon as task_done()
        # runs and must already be able to see this message.
        self.dead_letters.append(message)
        self.queue.task_done()

    def __next__(self):
        """Return the next message, or None if none arrives in time."""
        try:
            data = self.queue.get(timeout=self.timeout / 1000)
        except Empty:
            return None

        return _StubMessageProxy(Message.decode(data))


class _StubMessageProxy(MessageProxy):
    def clear_exception(self) -> None:
        """Let the GC handle the cycle once the message is no longer in use.

        This override intentionally does nothing.  This lets us keep
        showing full stack traces in failing tests.  See comment in
        `Worker' for details.
        """