# Source code for dramatiq.brokers.redis

# This file is a part of Dramatiq.
# Copyright (C) 2017,2018 CLEARTYPE SRL <[email protected]>
# Dramatiq is free software; you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at
# your option) any later version.
# Dramatiq is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
# License for more details.
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import glob
import random
import time
import warnings
from os import path
from uuid import uuid4

import redis

from ..broker import Broker, Consumer, MessageProxy
from ..common import compute_backoff, current_millis, dq_name
from ..errors import ConnectionClosed, QueueJoinTimeout
from ..logging import get_logger
from ..message import Message


# NOTE(review): the constant assignments below were missing from this
# copy of the module even though RedisBroker references them (as
# __init__ defaults and in _should_do_maintenance).  The values are
# restored from upstream Dramatiq -- confirm against the release in use.

#: How many commands out of a million should trigger queue
#: maintenance.
DEFAULT_MAINTENANCE_CHANCE = 1000

#: The amount of time in milliseconds that dead-lettered messages are
#: kept in Redis for.
DEFAULT_DEAD_MESSAGE_TTL = 86400000 * 7

#: The amount of time in milliseconds that has to pass without a
#: heartbeat for a worker to be considered offline.
DEFAULT_HEARTBEAT_TIMEOUT = 60000

#: Denominator for ``maintenance_chance``: a command triggers
#: maintenance with probability ``maintenance_chance / MAINTENANCE_SCALE``.
MAINTENANCE_SCALE = 1000000

#: Commands that never trigger queue maintenance.
MAINTENANCE_COMMAND_BLACKLIST = {"ack", "nack"}


class RedisBroker(Broker):
    """A broker that can be used with Redis.

    Examples:

      If you want to specify connection parameters individually:

      >>> RedisBroker(host="127.0.0.1", port=6379, db=0, password="hunter2")

      Alternatively, if you want to use a connection URL:

      >>> RedisBroker(url="redis://127.0.0.1:6379/0")

    See also:
      Redis_ for a list of all the available connection parameters.

    Parameters:
      url(str): An optional connection URL.  If both a URL and
        connection parameters are provided, the URL is used.
      middleware(list[Middleware])
      maintenance_chance(int): How many commands out of a million
        should trigger queue maintenance.
      namespace(str): The str with which to prefix all Redis keys.
      heartbeat_timeout(int): The amount of time (in ms) that has to
        pass without a heartbeat for a broker process to be
        considered offline.
      dead_message_ttl(int): The amount of time (in ms) that
        dead-lettered messages are kept in Redis for.
      requeue_deadline(int): Deprecated.  Does nothing.
      requeue_interval(int): Deprecated.  Does nothing.
      **parameters(dict): Connection parameters are passed directly
        to :class:`redis.Redis`.

    .. _Redis: https://redis-py.readthedocs.io/en/latest/
    """

    def __init__(
            self, *, url=None, middleware=None, namespace="dramatiq",
            maintenance_chance=DEFAULT_MAINTENANCE_CHANCE,
            heartbeat_timeout=DEFAULT_HEARTBEAT_TIMEOUT,
            dead_message_ttl=DEFAULT_DEAD_MESSAGE_TTL,
            requeue_deadline=None, requeue_interval=None,
            **parameters
    ):
        super().__init__(middleware=middleware)

        if url:
            # A connection pool built from the URL is handed to the
            # client, so the URL's settings take effect.
            parameters["connection_pool"] = redis.ConnectionPool.from_url(url)

        # Compare against None so an explicitly passed 0 still warns.
        if requeue_deadline is not None or requeue_interval is not None:
            message = "requeue_{deadline,interval} have been deprecated and no longer do anything"
            warnings.warn(message, DeprecationWarning, stacklevel=2)

        # Unique per broker instance; passed to every dispatch call.
        self.broker_id = str(uuid4())
        self.namespace = namespace
        self.maintenance_chance = maintenance_chance
        self.heartbeat_timeout = heartbeat_timeout
        self.dead_message_ttl = dead_message_ttl
        self.queues = set()
        # TODO: Replace usages of StrictRedis (redis-py 2.x) with Redis in Dramatiq 2.0.
        self.client = client = redis.StrictRedis(**parameters)
        # Register every Lua script loaded at module import time with
        # this client, keyed by script name.
        self.scripts = {name: client.register_script(script) for name, script in _scripts.items()}

    def consume(self, queue_name, prefetch=1, timeout=5000):
        """Create a new consumer for a queue.

        Parameters:
          queue_name(str): The queue to consume.
          prefetch(int): The number of messages to prefetch.
          timeout(int): The idle timeout in milliseconds.

        Returns:
          Consumer: A consumer that retrieves messages from Redis.
        """
        return _RedisConsumer(self, queue_name, prefetch, timeout)

    def declare_queue(self, queue_name):
        """Declare a queue.  Has no effect if a queue with the given
        name has already been declared.

        Declaring a queue also declares its delay queue (``dq_name``)
        and emits the ``declare_queue`` and ``declare_delay_queue``
        middleware events.

        Parameters:
          queue_name(str): The name of the new queue.
        """
        if queue_name not in self.queues:
            self.emit_before("declare_queue", queue_name)
            self.queues.add(queue_name)
            self.emit_after("declare_queue", queue_name)

            delayed_name = dq_name(queue_name)
            # NOTE(review): delay_queues is not initialized in __init__;
            # presumably provided by the Broker base class -- confirm.
            self.delay_queues.add(delayed_name)
            self.emit_after("declare_delay_queue", delayed_name)

    def enqueue(self, message, *, delay=None):
        """Enqueue a message.

        Parameters:
          message(Message): The message to enqueue.
          delay(int): The minimum amount of time, in milliseconds, to
            delay the message by.  When given, the message is put on
            the queue's delay queue with an ``eta`` option of
            now + ``delay``.

        Returns:
          Message: The copy of the message that was actually enqueued.
        """
        queue_name = message.queue_name

        # Each enqueued message must have a unique id in Redis so
        # using the Message's id isn't safe because messages may be
        # retried.
        message = message.copy(options={
            "redis_message_id": str(uuid4()),
        })

        # NOTE(review): earlier docs said delays must be under 7 days and
        # that longer ones raise ValueError, but no such check exists here.
        if delay is not None:
            queue_name = dq_name(queue_name)
            message_eta = current_millis() + delay
            message = message.copy(
                queue_name=queue_name,
                options={
                    "eta": message_eta,
                },
            )

        self.logger.debug("Enqueueing message %r on queue %r.", message.message_id, queue_name)
        self.emit_before("enqueue", message, delay)
        self.do_enqueue(queue_name, message.options["redis_message_id"], message.encode())
        self.emit_after("enqueue", message, delay)
        return message

    def get_declared_queues(self):
        """Get all declared queues.

        Returns:
          set[str]: The names of all the queues declared so far on
          this Broker.
        """
        return self.queues.copy()

    def flush(self, queue_name):
        """Drop all the messages from a queue and its delay queue.

        Parameters:
          queue_name(str): The queue to flush.
        """
        for name in (queue_name, dq_name(queue_name)):
            self.do_purge(name)

    def flush_all(self):
        """Drop all messages from all declared queues.
        """
        for queue_name in self.queues:
            self.flush(queue_name)

    def join(self, queue_name, *, interval=100, timeout=None):
        """Wait for all the messages on the given queue to be
        processed.  This method is only meant to be used in tests to
        wait for all the messages in a queue to be processed.

        Both the queue and its delay queue must be empty for the join
        to complete.

        Raises:
          QueueJoinTimeout: When the timeout elapses while messages
            are still pending.

        Parameters:
          queue_name(str): The queue to wait on.
          interval(Optional[int]): The interval, in milliseconds, at
            which to check the queues.
          timeout(Optional[int]): The max amount of time, in
            milliseconds, to wait on this queue.  ``None`` waits
            forever; ``0`` checks once without waiting.
        """
        # Use "is not None" so that timeout=0 is a real (zero) timeout
        # rather than being treated as "no timeout".
        deadline = None if timeout is None else time.monotonic() + timeout / 1000
        while True:
            size = 0
            for name in (queue_name, dq_name(queue_name)):
                size += self.do_qsize(name)

            # Check emptiness first so an already-drained queue never
            # raises QueueJoinTimeout.
            if size == 0:
                return

            if deadline is not None and time.monotonic() >= deadline:
                raise QueueJoinTimeout(queue_name)

            time.sleep(interval / 1000)

    def _should_do_maintenance(self, command):
        """Decide randomly whether ``command`` should trigger maintenance.

        Returns 1 with probability ``maintenance_chance / MAINTENANCE_SCALE``
        for commands not in ``MAINTENANCE_COMMAND_BLACKLIST``, else 0.
        """
        # Returned as an int because it is forwarded as an argument to
        # the dispatch script.
        return int(
            command not in MAINTENANCE_COMMAND_BLACKLIST and
            random.randint(1, MAINTENANCE_SCALE) <= self.maintenance_chance
        )

    def _dispatch(self, command):
        """Build a callable that runs ``command`` through the "dispatch" script.

        The returned function takes ``(queue_name, *args)`` and invokes the
        script with the namespace as its only key.
        """
        # Hoisted out of do_dispatch so they're built once per dispatcher
        # rather than per invocation.  NOTE(review): __getattr__ does not
        # cache dispatchers, so each ``do_*`` attribute access still
        # builds a new one.
        dispatch = self.scripts["dispatch"]
        keys = [self.namespace]

        def do_dispatch(queue_name, *args):
            timestamp = current_millis()
            args = [
                command,
                timestamp,
                queue_name,
                self.broker_id,
                self.heartbeat_timeout,
                self.dead_message_ttl,
                self._should_do_maintenance(command),
                *args,
            ]
            return dispatch(args=args, keys=keys)

        return do_dispatch

    def __getattr__(self, name):
        """Resolve ``do_<command>`` attributes to dispatchers for ``<command>``.

        Raises:
          AttributeError: For any other missing attribute.
        """
        if not name.startswith("do_"):
            raise AttributeError("attribute %s does not exist" % name)

        command = name[len("do_"):]
        return self._dispatch(command)
class _RedisConsumer(Consumer):
    """Consumer that pulls messages for a single queue from a RedisBroker.

    Messages are fetched in batches of up to ``prefetch`` minus the
    number already in flight, cached locally and handed out one at a
    time.  ``message_refc`` counts fetched messages that have not yet
    been acked or nacked.

    Parameters:
      broker(RedisBroker): The broker whose ``do_*`` commands are used.
      queue_name(str): The queue to consume.
      prefetch(int): The maximum number of in-flight messages.
      timeout(int): The idle timeout in milliseconds; caps the backoff
        between empty fetches.
    """

    def __init__(self, broker, queue_name, prefetch, timeout):
        self.logger = get_logger(__name__, type(self))
        self.broker = broker
        self.queue_name = queue_name
        self.prefetch = prefetch
        self.timeout = timeout

        self.message_cache = []
        self.message_refc = 0
        self.misses = 0

    def ack(self, message):
        """Acknowledge ``message`` on this consumer's queue.

        Raises:
          ConnectionClosed: If Redis cannot be reached.
        """
        try:
            # The current queue might be different from message.queue_name
            # if the message has been delayed so we want to ack on the
            # current queue.
            self.broker.do_ack(self.queue_name, message.options["redis_message_id"])
        except redis.ConnectionError as e:
            raise ConnectionClosed(e) from None
        finally:
            # The message is no longer in flight, even if the ack failed.
            self.message_refc -= 1

    def nack(self, message):
        """Reject ``message`` on this consumer's queue.

        Raises:
          ConnectionClosed: If Redis cannot be reached.
        """
        try:
            # Same deal as above.
            self.broker.do_nack(self.queue_name, message.options["redis_message_id"])
        except redis.ConnectionError as e:
            raise ConnectionClosed(e) from None
        finally:
            self.message_refc -= 1

    def requeue(self, messages):
        """Put ``messages`` back onto this consumer's queue.

        Does nothing when ``messages`` is empty.
        """
        message_ids = [message.options["redis_message_id"] for message in messages]
        if not message_ids:
            return

        self.logger.debug("Re-enqueueing %r on queue %r.", message_ids, self.queue_name)
        self.broker.do_requeue(self.queue_name, *message_ids)

    def __next__(self):
        """Return the next message, or ``None`` after an idle backoff.

        Returns:
          MessageProxy|None: The next message, or ``None`` if nothing
          could be fetched (after sleeping for the current backoff).

        Raises:
          ConnectionClosed: If Redis cannot be reached.
        """
        try:
            while True:
                try:
                    # This is a micro-optimization so we try the fast
                    # path first.  We assume there are messages in the
                    # cache and if there aren't, we go down the slow
                    # path of doing network IO.
                    data = self.message_cache.pop(0)
                    self.misses = 0

                    message = Message.decode(data)
                    return MessageProxy(message)
                except IndexError:
                    # If there are fewer messages currently being
                    # processed than we're allowed to prefetch,
                    # prefetch up to that number of messages.
                    messages = []
                    if self.message_refc < self.prefetch:
                        self.message_cache = messages = self.broker.do_fetch(
                            self.queue_name,
                            self.prefetch - self.message_refc,
                        )

                    # Because we didn't get any messages, we should
                    # progressively long poll up to the idle timeout.
                    if not messages:
                        self.misses, backoff_ms = compute_backoff(self.misses, max_backoff=self.timeout)
                        time.sleep(backoff_ms / 1000)
                        return None

                    # Since we received some number of messages, we
                    # have to keep track of them.
                    self.message_refc += len(messages)
        except redis.ConnectionError as e:
            raise ConnectionClosed(e) from None


#: Lua scripts found in the ``redis`` directory next to this module,
#: keyed by file name without the ``.lua`` extension.
_scripts = {}
_scripts_path = path.join(path.abspath(path.dirname(__file__)), "redis")
for filename in glob.glob(path.join(_scripts_path, "*.lua")):
    script_name, _ = path.splitext(path.basename(filename))
    with open(filename, "rb") as f:
        _scripts[script_name] = f.read().decode("utf-8")