# Source code for dramatiq.brokers.stub

# This file is a part of Dramatiq.
#
# Copyright (C) 2017,2018,2019,2020 CLEARTYPE SRL <[email protected]>
#
# Dramatiq is free software; you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at
# your option) any later version.
#
# Dramatiq is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
import time
from collections import defaultdict
from itertools import chain
from queue import Empty, Queue

from ..broker import Broker, Consumer, MessageProxy
from ..common import current_millis, dq_name, iter_queue, join_queue
from ..errors import QueueNotFound
from ..message import Message


class StubBroker(Broker):
    """A broker that can be used within unit tests.

    Queues live in memory as :class:`queue.Queue` instances.  Messages
    that consumers nack are collected per queue so that tests can inspect
    them through :attr:`dead_letters` / ``dead_letters_by_queue``.

    Parameters:
      middleware: Passed straight through to ``Broker.__init__``.
    """

    def __init__(self, middleware=None):
        super().__init__(middleware)
        # queue name -> list of message proxies nacked by that queue's
        # consumers.  A defaultdict so ``consume`` can always hand a
        # consumer a list to append to, even before any dead letters exist.
        self.dead_letters_by_queue = defaultdict(list)

    @property
    def dead_letters(self):
        """The dead-lettered messages for all defined queues.

        Returns:
          list: Every dead-lettered message, flattened across queues.
        """
        return list(chain.from_iterable(self.dead_letters_by_queue.values()))

    def consume(self, queue_name, prefetch=1, timeout=100):
        """Create a new consumer for a queue.

        Parameters:
          queue_name(str): The queue to consume.
          prefetch(int): The number of messages to prefetch.  Ignored by
            this broker; accepted for interface compatibility.
          timeout(int): The idle timeout in milliseconds.

        Raises:
          QueueNotFound: If the queue hasn't been declared.

        Returns:
          Consumer: A consumer that retrieves messages from the in-memory
          queue and records nacked messages as dead letters.
        """
        try:
            queue = self.queues[queue_name]
        except KeyError:
            raise QueueNotFound(queue_name) from None

        return _StubConsumer(queue, self.dead_letters_by_queue[queue_name], timeout)

    def declare_queue(self, queue_name):
        """Declare a queue and its companion delay queue.

        Has no effect if a queue with the given name has already been
        declared.

        Parameters:
          queue_name(str): The name of the new queue.
        """
        if queue_name in self.queues:
            return

        self.emit_before("declare_queue", queue_name)
        self.queues[queue_name] = Queue()
        self.emit_after("declare_queue", queue_name)

        # The delay queue is stored in ``self.queues`` as well as being
        # tracked in ``self.delay_queues``.
        delayed_name = dq_name(queue_name)
        self.queues[delayed_name] = Queue()
        self.delay_queues.add(delayed_name)
        self.emit_after("declare_delay_queue", delayed_name)

    def enqueue(self, message, *, delay=None):
        """Enqueue a message.

        Parameters:
          message(Message): The message to enqueue.
          delay(int): The minimum amount of time, in milliseconds, to
            delay the message by.  When given, a copy of the message is
            routed to the delay queue with an ``eta`` option set.

        Raises:
          QueueNotFound: If the queue the message is being enqueued on
            doesn't exist.

        Returns:
          Message: The message that was actually enqueued (the delayed
          copy when ``delay`` is not None).
        """
        queue_name = message.queue_name
        if delay is not None:
            queue_name = dq_name(queue_name)
            message_eta = current_millis() + delay
            message = message.copy(
                queue_name=queue_name,
                options={
                    "eta": message_eta,
                },
            )

        if queue_name not in self.queues:
            raise QueueNotFound(queue_name)

        self.emit_before("enqueue", message, delay)
        self.queues[queue_name].put(message.encode())
        self.emit_after("enqueue", message, delay)
        return message

    def flush(self, queue_name):
        """Drop all the messages from a queue.

        NOTE(review): unlike ``consume``/``join``, an undeclared queue name
        is not translated into ``QueueNotFound`` here.

        Parameters:
          queue_name(str): The queue to flush.
        """
        queue = self.queues[queue_name]
        for _ in iter_queue(queue):
            # Every drained item must be marked done, otherwise the
            # queue's unfinished-task count never reaches zero and
            # ``join`` would keep waiting on it.
            queue.task_done()

    def flush_all(self):
        """Drop all messages from all declared queues and all dead letters."""
        # NOTE(review): delay queues are also keys of ``self.queues`` (see
        # declare_queue), so they are visited twice; the second flush
        # presumably finds them already empty.
        for queue_name in chain(self.queues, self.delay_queues):
            self.flush(queue_name)
        self.dead_letters_by_queue.clear()

    # TODO: Make fail_fast default to True.
    def join(self, queue_name, *, fail_fast=False, timeout=None):
        """Wait for all the messages on the given queue to be processed.

        This method is only meant to be used in tests to wait for all the
        messages in a queue to be processed.  The queue and its delay
        queue are joined repeatedly until neither has unfinished tasks.

        Raises:
          QueueJoinTimeout: When the timeout elapses.
          QueueNotFound: If the given queue was never declared.
          Exception: When ``fail_fast`` is True and the queue has dead
            letters, the exception recorded on the first one is re-raised
            (or a generic ``Exception`` if none was recorded).

        Parameters:
          queue_name(str): The queue to wait on.
          fail_fast(bool): When this is True and any message has been
            dead-lettered on this queue, then an exception will be raised.
            This also covers messages dead-lettered before the join
            started, until ``flush_all`` clears them.  This will be True
            by default starting with version 2.0.
          timeout(Optional[int]): The max amount of time, in milliseconds,
            to wait on this queue.
        """
        try:
            queues = [
                self.queues[queue_name],
                self.queues[dq_name(queue_name)],
            ]
        except KeyError:
            raise QueueNotFound(queue_name) from None

        # ``timeout`` of None means wait forever; 0 means don't block.
        deadline = timeout and time.monotonic() + timeout / 1000
        while True:
            for queue in queues:
                # Share one overall budget across every join_queue call and
                # never hand it a negative value once the deadline passed.
                remaining = deadline and max(0.0, deadline - time.monotonic())
                join_queue(queue, timeout=remaining)

            # We cycle through $queue then $queue.DQ then $queue again in
            # case the messages that were on the DQ got moved back on $queue.
            if any(queue.unfinished_tasks for queue in queues):
                continue

            if fail_fast:
                for message in self.dead_letters_by_queue[queue_name]:
                    exception = message._exception
                    # NOTE(review): assumes ``_exception`` is None when no
                    # exception was recorded on the proxy; ``raise None``
                    # would otherwise surface as an unrelated TypeError.
                    if exception is None:
                        exception = Exception("Message failed with unknown error")
                    raise exception from None

            return
class _StubConsumer(Consumer):
    """Consumer that reads encoded messages from an in-memory queue.

    Parameters:
      queue(Queue): The queue holding encoded messages.
      dead_letters(list): Where nacked messages get appended.
      timeout(int): How long, in milliseconds, each fetch may block.
    """

    def __init__(self, queue, dead_letters, timeout):
        self.queue = queue
        self.dead_letters = dead_letters
        self.timeout = timeout

    def ack(self, message):
        # Balances the earlier get() so the queue can eventually be joined.
        self.queue.task_done()

    def nack(self, message):
        # Still counts as processed for the queue's bookkeeping, but is
        # also remembered as a dead letter.
        self.queue.task_done()
        self.dead_letters.append(message)

    def __next__(self):
        wait_seconds = self.timeout / 1000
        try:
            encoded = self.queue.get(timeout=wait_seconds)
            return _StubMessageProxy(Message.decode(encoded))
        except Empty:
            # Nothing arrived within the idle timeout.
            return None


class _StubMessageProxy(MessageProxy):
    """Message proxy whose ``clear_exception`` is deliberately a no-op."""

    def clear_exception(self):
        """Do nothing; let the GC handle the cycle once the message is unused.

        Keeping the exception around lets failing tests show full stack
        traces.  See the comment in ``Worker`` for details.
        """