#!/usr/bin/env python3
import asyncio
import enum
import logging
import types
from collections import deque
from functools import wraps
from inspect import getcallargs, getfullargspec, signature
from .cursors import Cursor, DictCursor, SetCursor, ZCursor
from .encoders import BaseEncoder, UTF8Encoder
from .exceptions import (
ConnectionLostError,
Error,
ErrorReply,
NoRunningScriptError,
NotConnectedError,
ScriptKilledError,
TimeoutError,
TransactionError,
)
from .log import logger
from .replies import (
BlockingPopReply,
ClientListReply,
ConfigPairReply,
DictReply,
EvalScriptReply,
InfoReply,
ListReply,
PubSubReply,
SetReply,
StatusReply,
ZRangeReply,
)
try:
import hiredis
except ImportError:
hiredis = None
NoneType = type(None)
class _NoTransactionType(object):
"""
An instance of this class is passed to a @_command when it's not part of
a transaction. We need a singleton that is distinct from None, because
None could be a valid input for a @_command, so otherwise there would be
no way to tell whether an extra 'transaction' value was given.
"""
_NoTransaction = _NoTransactionType()
class ZScoreBoundary:
"""Score boundary for a sorted set,
for queries like zrangebyscore and similar.
:param value: Value for the boundary.
:type value: float
:param exclude_boundary: Exclude the boundary.
:type exclude_boundary: bool
"""
def __init__(self, value, exclude_boundary=False):
assert isinstance(value, float) or value in ("+inf", "-inf")
self.value = value
self.exclude_boundary = exclude_boundary
def __repr__(self):
return "ZScoreBoundary(value=%r, exclude_boundary=%r)" % (
self.value,
self.exclude_boundary,
)
ZScoreBoundary.MIN_VALUE = ZScoreBoundary("-inf")
ZScoreBoundary.MAX_VALUE = ZScoreBoundary("+inf")
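# A minimal usage sketch (key and score values are illustrative; assumes a
# connected `protocol` instance as created further below in this module):
#
#     await protocol.zrangebyscore(
#         'myzset',
#         min=ZScoreBoundary(2.0, exclude_boundary=True),  # score > 2.0
#         max=ZScoreBoundary.MAX_VALUE,                    # up to +inf
#     )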
class ZAggregate(enum.Enum):
"""Aggregation method for zinterstore and zunionstore
"""
SUM = "SUM"
MIN = "MIN"
MAX = "MAX"
class PipelinedCall:
"""Track record for call that is being executed in a protocol
"""
__slots__ = ("cmd", "is_blocking")
def __init__(self, cmd, is_blocking):
self.cmd = cmd
self.is_blocking = is_blocking
class MultiBulkReply:
"""Container for a multi bulk reply.
"""
def __init__(self, protocol, count):
#: Buffer of incoming, undelivered data, received from the parser.
self._data_queue = []
#: Incoming read queries.
#: Contains (read_count, Future, decode_flag, one_only_flag) tuples.
self._f_queue = deque()
self.protocol = protocol
self.count = int(count)
def _feed_received(self, item):
"""Feed entry for the parser
"""
# Push received items on the queue
self._data_queue.append(item)
self._flush()
def _flush(self):
"""Answer read queries when we have enough data in our multibulk reply
"""
# As long as we have more data in our queue than the next read query
# requires, answer queries.
while self._f_queue and self._f_queue[0][0] <= len(self._data_queue):
# Pop query.
count, f, decode, one_only = self._f_queue.popleft()
# Slice data buffer.
data, self._data_queue = self._data_queue[:count], self._data_queue[count:]
# When the decode flag is given, decode bytes to native types.
if decode:
data = [self._decode(d) for d in data]
# When one_only flag has been given, don't return an array.
if one_only:
assert len(data) == 1
f.set_result(data[0])
else:
f.set_result(data)
def _decode(self, result):
"""Decode bytes to native Python types
"""
if isinstance(result, (StatusReply, int, float, MultiBulkReply)):
# Note that MultiBulkReplies can be nested. e.g. in the 'scan' operation.
return result
elif isinstance(result, bytes):
return self.protocol.decode_to_native(result)
elif result is None:
return result
else:
raise AssertionError("Invalid type: %r" % type(result))
def _read(self, decode=True, count=1, _one=False):
"""Do read operation on the queue. Return future.
"""
f = asyncio.Future()
self._f_queue.append((count, f, decode, _one))
# If there is enough data on the queue, answer future immediately.
self._flush()
return f
def iter_raw(self):
"""Iterate over all multi bulk packets.
This yields futures that won't decode bytes yet.
"""
for i in range(self.count):
yield self._read(decode=False, _one=True)
def __iter__(self):
"""Iterate over the reply. This yields coroutines of the decoded packets.
It decodes bytes automatically using protocol.decode_to_native.
"""
for i in range(self.count):
yield self._read(_one=True)
def __repr__(self):
return "MultiBulkReply(protocol=%r, count=%r)" % (self.protocol, self.count)
class _ScanPart:
"""Internal: result chunk of a scan operation
"""
def __init__(self, new_cursor_pos, items):
self.new_cursor_pos = new_cursor_pos
self.items = items
class PostProcessors:
"""At the protocol level, we only know about a few basic classes; they
include: bool, int, StatusReply, MultiBulkReply and bytes.
This will return a postprocessor function that turns these into more
meaningful objects.
For some methods, we have several post processors. E.g. a list can be
returned either as a ListReply (which has some special streaming
functionality) or as a plain Python list.
"""
@classmethod
def get_all(cls, return_type):
"""Return list of (suffix, return_type, post_processor)
"""
default = cls.get_default(return_type)
alternate = cls.get_alternate_post_processor(return_type)
result = [("", return_type, default)]
if alternate:
result.append(alternate)
return result
@classmethod
def get_default(cls, return_type):
"""Give post processor function for return type
"""
return {
ListReply: cls.multibulk_as_list,
SetReply: cls.multibulk_as_set,
DictReply: cls.multibulk_as_dict,
float: cls.bytes_to_float,
(float, NoneType): cls.bytes_to_float_or_none,
NativeType: cls.bytes_to_native,
(NativeType, NoneType): cls.bytes_to_native_or_none,
InfoReply: cls.bytes_to_info,
ClientListReply: cls.bytes_to_clientlist,
str: cls.bytes_to_str,
bool: cls.int_to_bool,
BlockingPopReply: cls.multibulk_as_blocking_pop_reply,
ZRangeReply: cls.multibulk_as_zrangereply,
StatusReply: cls.bytes_to_status_reply,
(StatusReply, NoneType): cls.bytes_to_status_reply_or_none,
int: None,
(int, NoneType): None,
ConfigPairReply: cls.multibulk_as_configpair,
ListOf(bool): cls.multibulk_as_boolean_list,
_ScanPart: cls.multibulk_as_scanpart,
EvalScriptReply: cls.any_to_evalscript,
NoneType: None,
}[return_type]
@classmethod
def get_alternate_post_processor(cls, return_type):
"""For list/set/dict. Create additional post processors that return
python classes rather than ListReply/SetReply/DictReply.
"""
original_post_processor = cls.get_default(return_type)
if return_type == ListReply:
async def as_list(protocol, result):
result = await original_post_processor(protocol, result)
return await result.aslist()
return "_aslist", list, as_list
elif return_type == SetReply:
async def as_set(protocol, result):
result = await original_post_processor(protocol, result)
return await result.asset()
return "_asset", set, as_set
elif return_type in (DictReply, ZRangeReply):
async def as_dict(protocol, result):
result = await original_post_processor(protocol, result)
return await result.asdict()
return "_asdict", dict, as_dict
# === Post processor handlers below. ===
async def multibulk_as_list(protocol, result):
assert isinstance(result, MultiBulkReply)
return ListReply(result)
async def multibulk_as_boolean_list(protocol, result):
# Turn the array of integers into booleans.
assert isinstance(result, MultiBulkReply)
values = await ListReply(result).aslist()
return [bool(v) for v in values]
async def multibulk_as_set(protocol, result):
assert isinstance(result, MultiBulkReply)
return SetReply(result)
async def multibulk_as_dict(protocol, result):
assert isinstance(result, MultiBulkReply)
return DictReply(result)
async def multibulk_as_zrangereply(protocol, result):
assert isinstance(result, MultiBulkReply)
return ZRangeReply(result)
async def multibulk_as_blocking_pop_reply(protocol, result):
if result is None:
raise TimeoutError("Timeout in blocking pop")
else:
assert isinstance(result, MultiBulkReply)
list_name, value = await ListReply(result).aslist()
return BlockingPopReply(list_name, value)
async def multibulk_as_configpair(protocol, result):
assert isinstance(result, MultiBulkReply)
parameter, value = await ListReply(result).aslist()
return ConfigPairReply(parameter, value)
async def multibulk_as_scanpart(protocol, result):
"""Process scanpart result.
This is a multibulk reply of length two, where the first item is the
new cursor position and the second item is a nested multi bulk reply
containing all the elements.
"""
# Get outer multi bulk reply.
assert isinstance(result, MultiBulkReply)
new_cursor_pos, items_bulk = await ListReply(result).aslist()
assert isinstance(items_bulk, MultiBulkReply)
# Read all items for scan chunk in memory. This is fine, because it's
# transmitted in chunks of about 10.
items = await ListReply(items_bulk).aslist()
return _ScanPart(int(new_cursor_pos), items)
async def bytes_to_info(protocol, result):
assert isinstance(result, bytes)
return InfoReply(result)
async def bytes_to_status_reply(protocol, result):
assert isinstance(result, bytes)
return StatusReply(result.decode("utf-8"))
async def bytes_to_status_reply_or_none(protocol, result):
assert isinstance(result, (bytes, NoneType))
if result:
return StatusReply(result.decode("utf-8"))
async def bytes_to_clientlist(protocol, result):
assert isinstance(result, bytes)
return ClientListReply(result)
async def int_to_bool(protocol, result):
assert isinstance(result, int)
return bool(result) # Convert int to bool
async def bytes_to_native(protocol, result):
assert isinstance(result, bytes)
return protocol.decode_to_native(result)
async def bytes_to_str(protocol, result):
assert isinstance(result, bytes)
return result.decode("ascii")
async def bytes_to_native_or_none(protocol, result):
if result is None:
return result
else:
assert isinstance(result, bytes)
return protocol.decode_to_native(result)
async def bytes_to_float_or_none(protocol, result):
if result is None:
return result
assert isinstance(result, bytes)
return float(result)
async def bytes_to_float(protocol, result):
assert isinstance(result, bytes)
return float(result)
async def any_to_evalscript(protocol, result):
# Result can be native, int, MultiBulkReply or even a nested structure
assert isinstance(result, (int, bytes, MultiBulkReply, NoneType))
return EvalScriptReply(protocol, result)
class ListOf:
"""Annotation helper for protocol methods
"""
def __init__(self, type_):
self.type = type_
def __repr__(self):
return "ListOf(%r)" % self.type
def __eq__(self, other):
return isinstance(other, ListOf) and other.type == self.type
def __hash__(self):
return hash((ListOf, self.type))
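# Example annotation using ListOf (taken from the command definitions below):
#
#     def mget(self, tr, keys: ListOf(NativeType)) -> ListReply: ...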
class NativeType:
"""Constant which represents the native Python type that's used
"""
def __new__(cls):
raise Exception("NativeType is not meant to be initialized.")
class CommandCreator:
"""Utility for creating a wrapper around the Redis protocol methods.
This will also do type checking.
This wrapper handles (optionally) post processing of the returned data and
implements some logic where commands behave different in case of a
transaction or pubsub.
Warning: We use the annotations of `method` extensively for type checking
and determining which post processor to choose.
"""
def __init__(self, method):
self.method = method
@property
def specs(self):
"""Argspecs
"""
return getfullargspec(self.method)
@property
def return_type(self):
"""Return type as defined in the method's annotation
"""
return self.specs.annotations.get("return", None)
@property
def params(self):
return {k: v for k, v in self.specs.annotations.items() if k != "return"}
@classmethod
def get_real_type(cls, protocol, type_):
"""Given a protocol instance, and type annotation, return something that
we can pass to isinstance for the typechecking.
"""
# If NativeType was given, replace it with the type of the protocol
# itself.
if isinstance(type_, tuple):
return tuple(cls.get_real_type(protocol, t) for t in type_)
if type_ == NativeType:
return protocol.native_type
elif isinstance(type_, ListOf):
return (
list,
types.GeneratorType,
) # We don't check the content of the list.
else:
return type_
def _create_input_typechecker(self):
"""Return function that does typechecking on input data
"""
params = self.params
if params:
def typecheck_input(protocol, *a, **kw):
"""Given a protocol instance and *a/**kw of this method, raise TypeError
when the signature doesn't match.
"""
if protocol.enable_typechecking:
# All @_command/@_query_command methods can take
# *optionally* a Transaction instance as first argument.
if a and isinstance(a[0], (Transaction, _NoTransactionType)):
a = a[1:]
for name, value in getcallargs(
self.method, None, _NoTransaction, *a, **kw
).items():
if name in params:
real_type = self.get_real_type(protocol, params[name])
if not isinstance(value, real_type):
raise TypeError(
"RedisProtocol.%s received %r, expected %r"
% (
self.method.__name__,
type(value).__name__,
real_type,
)
)
else:
def typecheck_input(protocol, *a, **kw):
pass
return typecheck_input
def _create_return_typechecker(self, return_type):
"""Return function that does typechecking on output data
"""
if return_type and not isinstance(
return_type, str
): # Exclude 'Transaction'/'Subscription' which are 'str'
def typecheck_return(protocol, result):
"""Given protocol and result value. Raise TypeError if the result is of
the wrong type.
"""
if protocol.enable_typechecking:
expected_type = self.get_real_type(protocol, return_type)
if not isinstance(result, expected_type):
raise TypeError(
f"Got unexpected return type {type(result).__name__!r} in "
f"RedisProtocol.{self.method.__name__}, "
f"expected {expected_type!r}"
)
else:
def typecheck_return(protocol, result):
pass
return typecheck_return
def _get_docstring(self, suffix, return_type):
# Append the real signature as the first line in the docstring.
# (This will make the sphinx docs show the real signature instead of
# (*a, **kw) of the wrapper.)
# (But don't put the annotations inside the copied signature, that's rather
# ugly in the docs.)
parameters = signature(self.method).parameters
# The below differs from tuple(parameters.keys()) as it preserves the
# * and ** prefixes of variadic arguments
argnames = tuple(str(p).split(":")[0] for p in parameters.values())
# Use function annotations to generate param documentation.
def get_name(type_):
"""Turn type annotation into doc string
"""
try:
return {
BlockingPopReply: ":class:`BlockingPopReply <asyncio_redis.replies.BlockingPopReply>`", # noqa: E501
ConfigPairReply: ":class:`ConfigPairReply <asyncio_redis.replies.ConfigPairReply>`", # noqa: E501
DictReply: ":class:`DictReply <asyncio_redis.replies.DictReply>`",
InfoReply: ":class:`InfoReply <asyncio_redis.replies.InfoReply>`",
ClientListReply: ":class:`ClientListReply <asyncio_redis.replies.ClientListReply>`", # noqa: E501
ListReply: ":class:`ListReply <asyncio_redis.replies.ListReply>`",
MultiBulkReply: ":class:`MultiBulkReply <asyncio_redis.replies.MultiBulkReply>`", # noqa: E501
NativeType: "Native Python type, as defined by :attr:`~asyncio_redis.encoders.BaseEncoder.native_type`", # noqa: E501
NoneType: "None",
SetReply: ":class:`SetReply <asyncio_redis.replies.SetReply>`",
StatusReply: ":class:`StatusReply <asyncio_redis.replies.StatusReply>`", # noqa: E501
ZRangeReply: ":class:`ZRangeReply <asyncio_redis.replies.ZRangeReply>`", # noqa: E501
ZScoreBoundary: ":class:`ZScoreBoundary <asyncio_redis.ZScoreBoundary>`", # noqa: E501
EvalScriptReply: ":class:`EvalScriptReply <asyncio_redis.replies.EvalScriptReply>`", # noqa: E501
Cursor: ":class:`Cursor <asyncio_redis.cursors.Cursor>`",
SetCursor: ":class:`SetCursor <asyncio_redis.cursors.SetCursor>`",
DictCursor: ":class:`DictCursor <asyncio_redis.cursors.DictCursor>`", # noqa: E501
ZCursor: ":class:`ZCursor <asyncio_redis.cursors.ZCursor>`",
_ScanPart: ":class:`_ScanPart`",
int: "int",
bool: "bool",
dict: "dict",
float: "float",
str: "str",
bytes: "bytes",
list: "list",
set: "set",
# Because of circular references, we cannot use the real types here.
"Transaction": ":class:`asyncio_redis.Transaction`",
"Subscription": ":class:`asyncio_redis.Subscription`",
"Script": ":class:`~asyncio_redis.Script`",
}[type_]
except KeyError:
if isinstance(type_, ListOf):
return "List or iterable of %s" % get_name(type_.type)
if isinstance(type_, tuple):
return " or ".join(get_name(t) for t in type_)
raise TypeError("Unknown annotation %r" % type_)
def get_param(k, v):
return ":param %s: %s\n" % (k, get_name(v))
params_str = [get_param(k, v) for k, v in self.params.items()]
returns = (
":returns: (Future of) %s\n" % get_name(return_type) if return_type else ""
)
return "%s(%s)\n%s\n\n%s%s" % (
self.method.__name__ + suffix,
", ".join(argnames),
self.method.__doc__,
"".join(params_str),
returns,
)
def get_methods(self):
"""Return all the methods to be used in the RedisProtocol class.
"""
return [("", self._get_wrapped_method(None, "", self.return_type))]
def _get_wrapped_method(self, post_process, suffix, return_type):
"""Return the wrapped method for use in the `RedisProtocol` class.
"""
typecheck_input = self._create_input_typechecker()
typecheck_return = self._create_return_typechecker(return_type)
method = self.method
# Wrap it into a check which allows this command to be run either
# directly on the protocol, outside of transactions or from the
# transaction object.
@wraps(method)
async def wrapper(protocol_self, *a, **kw):
if a and isinstance(a[0], (Transaction, _NoTransactionType)):
transaction = a[0]
a = a[1:]
else:
transaction = _NoTransaction
# When calling from a transaction
if transaction != _NoTransaction:
# In case of a transaction, we receive a Future from the command.
typecheck_input(protocol_self, *a, **kw)
future = await method(protocol_self, transaction, *a, **kw)
future2 = asyncio.Future()
# Typecheck the future when the result is available.
async def done(result):
if post_process:
result = await post_process(protocol_self, result)
typecheck_return(protocol_self, result)
future2.set_result(result)
loop = asyncio.get_event_loop()
future.add_done_callback(lambda f: loop.create_task(done(f.result())))
return future2
# When calling from a pubsub context
elif protocol_self.in_pubsub:
if not a or a[0] != protocol_self._subscription:
raise Error("Cannot run command inside pubsub subscription.")
else:
typecheck_input(protocol_self, *a[1:], **kw)
result = await method(protocol_self, _NoTransaction, *a[1:], **kw)
if post_process:
result = await post_process(protocol_self, result)
typecheck_return(protocol_self, result)
return result
else:
typecheck_input(protocol_self, *a, **kw)
result = await method(protocol_self, _NoTransaction, *a, **kw)
if post_process:
result = await post_process(protocol_self, result)
typecheck_return(protocol_self, result)
return result
wrapper.__doc__ = self._get_docstring(suffix, return_type)
return wrapper
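# Behaviour sketch for the transaction branch above: inside a transaction each
# command returns a Future that is only resolved after the transaction is
# executed (the `transaction` object and its `exec()` are defined elsewhere in
# this package; names here are illustrative):
#
#     f = await transaction.set('key', 'value')  # Future, result not yet known
#     await transaction.exec()
#     result = await f                           # now resolved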
class QueryCommandCreator(CommandCreator):
"""Like `CommandCreator`, but for methods registered with `_query_command`.
These are the methods that cause commands to be sent to the server.
Most of the commands get a reply from the server that needs to be post
processed to get the right Python type. Here we inspect the
'returns'-annotation to determine the correct post processor.
"""
def get_methods(self):
# (Some commands, e.g. those that return a ListReply can generate
# multiple protocol methods. One that does return the ListReply, but
# also one with the 'aslist' suffix that returns a Python list.)
all_post_processors = PostProcessors.get_all(self.return_type)
result = []
for suffix, return_type, post_processor in all_post_processors:
result.append(
(suffix, self._get_wrapped_method(post_processor, suffix, return_type))
)
return result
_SMALL_INTS = list(str(i).encode("ascii") for i in range(1000))
# List of all command methods.
_all_commands = []
class _command:
"""Mark method as command (to be passed through CommandCreator for the
creation of a protocol method)
"""
creator = CommandCreator
def __init__(self, method):
self.method = method
class _query_command(_command):
"""Mark method as query command: This will pass through QueryCommandCreator.
.. note::
Be sure to choose the correct 'returns'-annotation. This will automatically
determine the correct post processor function in :class:`PostProcessors`.
"""
creator = QueryCommandCreator
def __init__(self, method):
super().__init__(method)
class _RedisProtocolMeta(type):
"""Metaclass for `RedisProtocol` which applies the _command decorator.
"""
def __new__(cls, name, bases, attrs):
for attr_name, value in dict(attrs).items():
if isinstance(value, _command):
creator = value.creator(value.method)
for suffix, method in creator.get_methods():
attrs[attr_name + suffix] = method
# Register command.
_all_commands.append(attr_name + suffix)
return type.__new__(cls, name, bases, attrs)
class RedisProtocol(asyncio.Protocol, metaclass=_RedisProtocolMeta):
"""The Redis Protocol implementation.
::
loop = asyncio.get_event_loop()
transport, protocol = await loop.create_connection(
RedisProtocol, 'localhost', 6379
)
:param password:
Redis database password
:type password:
Native Python type as defined by the ``encoder`` parameter
:param encoder:
Encoder to use for encoding to or decoding from redis bytes to a native type.
(Defaults to :class:`~asyncio_redis.encoders.UTF8Encoder`)
:type encoder:
:class:`~asyncio_redis.encoders.BaseEncoder` instance.
:param int db:
Redis database
:param bool enable_typechecking:
When True, check argument types for all redis commands.
Normally you want to have this enabled.
"""
def __init__(
self,
*,
password=None,
db=0,
encoder=None,
connection_lost_callback=None,
enable_typechecking=True,
):
if encoder is None:
encoder = UTF8Encoder()
assert isinstance(db, int)
assert isinstance(encoder, BaseEncoder)
assert encoder.native_type, "Encoder.native_type not defined"
assert not password or isinstance(password, encoder.native_type)
self.password = password
self.db = db
self._connection_lost_callback = connection_lost_callback
# Take encode / decode settings from encoder
self.encode_from_native = encoder.encode_from_native
self.decode_to_native = encoder.decode_to_native
self.native_type = encoder.native_type
self.enable_typechecking = enable_typechecking
self.transport = None
self._queue = deque() # Input parser queues
self._messages_queue = None # Pubsub queue
self._is_connected = (
False # True as long as the underlying transport is connected.
)
# Pubsub state
self._in_pubsub = False
self._subscription = None
self._pubsub_channels = set() # Set of channels
self._pubsub_patterns = set() # Set of patterns
# Transaction related stuff.
self._transaction_lock = asyncio.Lock()
self._transaction = None
self._transaction_response_queue = None # Transaction answer queue
self._line_received_handlers = {
b"+": self._handle_status_reply,
b"-": self._handle_error_reply,
b"$": self._handle_bulk_reply,
b"*": self._handle_multi_bulk_reply,
b":": self._handle_int_reply,
}
def connection_made(self, transport):
self.transport = transport
self._is_connected = True
logger.log(logging.INFO, "Redis connection made")
# Pipelined calls
self._pipelined_calls = set() # Set of all the pipelined calls.
# Start parsing reader stream.
self._reader = asyncio.StreamReader()
self._reader.set_transport(transport)
self._reader_f = asyncio.get_event_loop().create_task(self._reader_coroutine())
async def initialize():
# If a password or database has been given, authenticate / select it first.
if self.password:
await self.auth(self.password)
if self.db:
await self.select(self.db)
# If we are in pubsub mode, send channel subscriptions again.
if self._in_pubsub:
if self._pubsub_channels:
await self._subscribe(
self._subscription, list(self._pubsub_channels)
) # TODO: unittest this
if self._pubsub_patterns:
await self._psubscribe(
self._subscription, list(self._pubsub_patterns)
)
loop = asyncio.get_event_loop()
loop.create_task(initialize())
def data_received(self, data):
"""Process data received from Redis server
"""
self._reader.feed_data(data)
def _encode_int(self, value: int) -> bytes:
"""Encodes an integer to bytes. (always ascii)
"""
if 0 < value < 1000: # For small values, take pre-encoded string.
return _SMALL_INTS[value]
else:
return str(value).encode("ascii")
def _encode_float(self, value: float) -> bytes:
"""Encodes a float to bytes. (always ascii)
"""
return str(value).encode("ascii")
def _encode_zscore_boundary(self, value: ZScoreBoundary) -> bytes:
"""Encodes a zscore boundary. (always ascii)
"""
if isinstance(value.value, str):
return str(value.value).encode("ascii") # +inf and -inf
elif value.exclude_boundary:
return str("(%f" % value.value).encode("ascii")
else:
return str("%f" % value.value).encode("ascii")
def eof_received(self):
logger.log(logging.INFO, "EOF received in RedisProtocol")
self._reader.feed_eof()
def connection_lost(self, exc):
if exc is None:
self._reader.feed_eof()
else:
logger.info("Connection lost with exec: %s" % exc)
self._reader.set_exception(exc)
if self._reader_f:
self._reader_f.cancel()
self._is_connected = False
self.transport = None
self._reader = None
self._reader_f = None
# Raise exception on all waiting futures.
while self._queue:
f = self._queue.popleft()
if not f.cancelled():
f.set_exception(ConnectionLostError(exc))
logger.log(logging.INFO, "Redis connection lost")
# Call connection_lost callback
if self._connection_lost_callback:
self._connection_lost_callback()
# Request state
@property
def in_blocking_call(self):
"""True when waiting for answer to blocking command
"""
return any(c.is_blocking for c in self._pipelined_calls)
@property
def in_pubsub(self):
"""True when the protocol is in pubsub mode
"""
return self._in_pubsub
@property
def in_transaction(self):
"""True when we're inside a transaction
"""
return bool(self._transaction)
@property
def in_use(self):
"""True when this protocol is in use
"""
return self.in_blocking_call or self.in_pubsub or self.in_transaction
@property
def is_connected(self):
"""True when the underlying transport is connected
"""
return self._is_connected
# Handle replies
async def _reader_coroutine(self):
"""Coroutine which reads input from the stream reader and processes it.
"""
while True:
try:
await self._handle_item(self._push_answer)
except ConnectionLostError:
return
except asyncio.IncompleteReadError:
return
async def _handle_item(self, cb):
c = await self._reader.readexactly(1)
if c:
await self._line_received_handlers[c](cb)
else:
raise ConnectionLostError(None)
async def _handle_status_reply(self, cb):
line = (await self._reader.readline()).rstrip(b"\r\n")
cb(line)
async def _handle_int_reply(self, cb):
line = (await self._reader.readline()).rstrip(b"\r\n")
cb(int(line))
async def _handle_error_reply(self, cb):
line = (await self._reader.readline()).rstrip(b"\r\n")
cb(ErrorReply(line.decode("ascii")))
async def _handle_bulk_reply(self, cb):
length = int((await self._reader.readline()).rstrip(b"\r\n"))
if length == -1:
# None bulk reply
cb(None)
else:
# Read data
data = await self._reader.readexactly(length)
cb(data)
# Ignore trailing newline.
remaining = await self._reader.readline()
assert remaining.rstrip(b"\r\n") == b""
async def _handle_multi_bulk_reply(self, cb):
# NOTE: the reason for passing the callback `cb` in here is
# mainly because we want to return the result object
# especially in this case before the input is read
# completely. This allows a streaming API.
count = int((await self._reader.readline()).rstrip(b"\r\n"))
# Handle multi-bulk none.
# (Used when a transaction exec fails.)
if count == -1:
cb(None)
return
reply = MultiBulkReply(self, count)
# Return the (still empty) reply object immediately as an answer.
if self._in_pubsub:
loop = asyncio.get_event_loop()
loop.create_task(self._handle_pubsub_multibulk_reply(reply))
else:
cb(reply)
# Wait for all multi bulk reply content.
for i in range(count):
await self._handle_item(reply._feed_received)
async def _handle_pubsub_multibulk_reply(self, multibulk_reply):
# Read first item of the multi bulk reply raw.
type = await multibulk_reply._read(decode=False, _one=True)
assert type in (
b"message",
b"subscribe",
b"unsubscribe",
b"pmessage",
b"psubscribe",
b"punsubscribe",
)
if type == b"message":
channel, value = await multibulk_reply._read(count=2)
await self._subscription._messages_queue.put(PubSubReply(channel, value))
elif type == b"pmessage":
pattern, channel, value = await multibulk_reply._read(count=3)
await self._subscription._messages_queue.put(
PubSubReply(channel, value, pattern=pattern)
)
# We can safely ignore 'subscribe'/'unsubscribe' replies at this point,
# they don't contain anything really useful.
# Redis operations.
def _send_command(self, args):
"""Send Redis request command.
`args` should be a list of bytes to be written to the transport.
"""
# Create write buffer.
data = []
# NOTE: First, I tried to optimize by also flushing this buffer in
# between the looping through the args. However, I removed that as the
# advantage was really small. Even when some commands like `hmset`
# could accept a generator instead of a list/dict, we would need to
# read out the whole generator in memory in order to write the number
# of arguments first.
# Serialize and write header (number of arguments.)
data += [b"*", self._encode_int(len(args)), b"\r\n"]
# Write arguments.
for arg in args:
data += [b"$", self._encode_int(len(arg)), b"\r\n", arg, b"\r\n"]
# Flush the last part
self.transport.write(b"".join(data))
async def _get_answer(
self, transaction, answer_f, _bypass=False, call=None
): # XXX: rename _bypass to not_queued
"""Return an answer to the pipelined query.
(Or when we are in a transaction, return a future for the answer.)
"""
# Wait for the answer to come in
result = await answer_f
if transaction != _NoTransaction and not _bypass:
# When the connection is inside a transaction, the query will be queued.
if result != b"QUEUED":
raise Error(
"Expected to receive QUEUED for query in transaction, received %r."
% result
)
# Return a future which will contain the result when it arrives.
f = asyncio.Future()
self._transaction_response_queue.append((f, call))
return f
else:
if call:
self._pipelined_calls.remove(call)
return result
def _push_answer(self, answer):
"""Answer future at the queue.
"""
f = self._queue.popleft()
if isinstance(answer, Exception):
f.set_exception(answer)
elif f.cancelled():
# Received an answer from Redis, for a query which `Future` got
# already cancelled. Don't call set_result, that would raise an
# `InvalidStateError` otherwise.
pass
else:
f.set_result(answer)
async def _query(self, transaction, *args, _bypass=False, set_blocking=False):
"""Wrapper around both _send_command and _get_answer.
Coroutine that sends the query to the server, and returns the reply.
(Where the reply is a simple Redis type: these are `int`,
`StatusReply`, `bytes` or `MultiBulkReply`) When we are in a transaction,
this coroutine will return a `Future` of the actual result.
"""
assert transaction == _NoTransaction or isinstance(transaction, Transaction)
if not self._is_connected:
raise NotConnectedError
# Get lock.
if transaction == _NoTransaction:
await self._transaction_lock.acquire()
else:
assert transaction == self._transaction
try:
call = PipelinedCall(args[0], set_blocking)
self._pipelined_calls.add(call)
# Add a new future to our answer queue.
answer_f = asyncio.Future()
self._queue.append(answer_f)
# Send command
self._send_command(args)
finally:
# Release lock.
if transaction == _NoTransaction:
self._transaction_lock.release()
# TODO: when set_blocking=True, only release lock after reading the answer.
# (it doesn't make sense to free the input and pipeline commands in
# that case.)
# Receive answer.
result = await self._get_answer(
transaction, answer_f, _bypass=_bypass, call=call
)
return result
# Internal
@_query_command
def auth(self, tr, password: NativeType) -> StatusReply:
"""Authenticate to the server
"""
self.password = password
return self._query(tr, b"auth", self.encode_from_native(password))
@_query_command
def select(self, tr, db: int) -> StatusReply:
"""Change the selected database for the current connection
"""
self.db = db
return self._query(tr, b"select", self._encode_int(db))
# Strings
@_query_command
def set(
self,
tr,
key: NativeType,
value: NativeType,
expire: (int, NoneType) = None,
pexpire: (int, NoneType) = None,
only_if_not_exists: bool = False,
only_if_exists: bool = False,
) -> (StatusReply, NoneType):
"""Set the string value of a key
::
await protocol.set('key', 'value')
result = await protocol.get('key')
assert result == 'value'
To set a value and its expiration, only if the key does not yet exist, do:
::
await protocol.set('key', 'value', expire=1, only_if_not_exists=True)
This will send ``SET key value EX 1 NX`` over the network.
To set a value and its expiration in milliseconds, but only if the key already exists:
::
await protocol.set('key', 'value', pexpire=1000, only_if_exists=True)
"""
params = [b"set", self.encode_from_native(key), self.encode_from_native(value)]
if expire is not None:
params.extend((b"ex", self._encode_int(expire)))
if pexpire is not None:
params.extend((b"px", self._encode_int(pexpire)))
if only_if_not_exists and only_if_exists:
raise ValueError(
"only_if_not_exists and only_if_exists cannot be true simultaniously"
)
if only_if_not_exists:
params.append(b"nx")
if only_if_exists:
params.append(b"xx")
return self._query(tr, *params)
@_query_command
def setex(
self, tr, key: NativeType, seconds: int, value: NativeType
) -> StatusReply:
"""Set the string value of a key with expire
"""
return self._query(
tr,
b"setex",
self.encode_from_native(key),
self._encode_int(seconds),
self.encode_from_native(value),
)
@_query_command
def setnx(self, tr, key: NativeType, value: NativeType) -> bool:
"""Set the string value of a key if it does not exist.
Returns True if value is successfully set.
"""
return self._query(
tr, b"setnx", self.encode_from_native(key), self.encode_from_native(value)
)
@_query_command
def get(self, tr, key: NativeType) -> (NativeType, NoneType):
"""Get the value of a key
"""
return self._query(tr, b"get", self.encode_from_native(key))
@_query_command
def mget(self, tr, keys: ListOf(NativeType)) -> ListReply:
"""Return the values of all specified keys
"""
return self._query(tr, b"mget", *map(self.encode_from_native, keys))
@_query_command
def strlen(self, tr, key: NativeType) -> int:
"""Return the length of the string value stored at key. An error is
returned when key holds a non-string value.
"""
return self._query(tr, b"strlen", self.encode_from_native(key))
@_query_command
def append(self, tr, key: NativeType, value: NativeType) -> int:
"""Append a value to a key
"""
return self._query(
tr, b"append", self.encode_from_native(key), self.encode_from_native(value)
)
@_query_command
def getset(self, tr, key: NativeType, value: NativeType) -> (NativeType, NoneType):
"""Set the string value of a key and return its old value
"""
return self._query(
tr, b"getset", self.encode_from_native(key), self.encode_from_native(value)
)
@_query_command
def incr(self, tr, key: NativeType) -> int:
"""Increment the integer value of a key by one
"""
return self._query(tr, b"incr", self.encode_from_native(key))
@_query_command
def incrby(self, tr, key: NativeType, increment: int) -> int:
"""Increment the integer value of a key by the given amount
"""
return self._query(
tr, b"incrby", self.encode_from_native(key), self._encode_int(increment)
)
@_query_command
def decr(self, tr, key: NativeType) -> int:
"""Decrement the integer value of a key by one
"""
return self._query(tr, b"decr", self.encode_from_native(key))
@_query_command
def decrby(self, tr, key: NativeType, increment: int) -> int:
"""Decrement the integer value of a key by the given number
"""
return self._query(
tr, b"decrby", self.encode_from_native(key), self._encode_int(increment)
)
@_query_command
def randomkey(self, tr) -> NativeType:
"""Return a random key from the keyspace
"""
return self._query(tr, b"randomkey")
@_query_command
def exists(self, tr, key: NativeType) -> bool:
"""Determine if a key exists
"""
return self._query(tr, b"exists", self.encode_from_native(key))
@_query_command
def delete(self, tr, keys: ListOf(NativeType)) -> int:
"""Delete a key
"""
return self._query(tr, b"del", *map(self.encode_from_native, keys))
@_query_command
def move(self, tr, key: NativeType, database: int) -> int:
"""Move a key to another database
"""
return self._query(
tr, b"move", self.encode_from_native(key), self._encode_int(database)
) # TODO: unittest
@_query_command
def rename(self, tr, key: NativeType, newkey: NativeType) -> StatusReply:
"""Rename a key
"""
return self._query(
tr, b"rename", self.encode_from_native(key), self.encode_from_native(newkey)
)
@_query_command
def renamenx(self, tr, key: NativeType, newkey: NativeType) -> int:
"""Rename a key, only if the new key does not exist
Returns 1 if the key was successfully renamed.
"""
return self._query(
tr,
b"renamenx",
self.encode_from_native(key),
self.encode_from_native(newkey),
)
@_query_command
def bitop_and(self, tr, destkey: NativeType, srckeys: ListOf(NativeType)) -> int:
"""Perform a bitwise AND operation between multiple keys
"""
return self._bitop(tr, b"and", destkey, srckeys)
@_query_command
def bitop_or(self, tr, destkey: NativeType, srckeys: ListOf(NativeType)) -> int:
"""Perform a bitwise OR operation between multiple keys
"""
return self._bitop(tr, b"or", destkey, srckeys)
@_query_command
def bitop_xor(self, tr, destkey: NativeType, srckeys: ListOf(NativeType)) -> int:
"""Perform a bitwise XOR operation between multiple keys
"""
return self._bitop(tr, b"xor", destkey, srckeys)
def _bitop(self, tr, op, destkey, srckeys):
return self._query(
tr,
b"bitop",
op,
self.encode_from_native(destkey),
*map(self.encode_from_native, srckeys),
)
@_query_command
def bitop_not(self, tr, destkey: NativeType, key: NativeType) -> int:
"""Perform a bitwise NOT operation between multiple keys
"""
return self._query(
tr,
b"bitop",
b"not",
self.encode_from_native(destkey),
self.encode_from_native(key),
)
@_query_command
def bitcount(self, tr, key: NativeType, start: int = 0, end: int = -1) -> int:
"""Count the number of set bits (population counting) in a string
"""
return self._query(
tr,
b"bitcount",
self.encode_from_native(key),
self._encode_int(start),
self._encode_int(end),
)
@_query_command
def getbit(self, tr, key: NativeType, offset: int) -> bool:
"""Returns the bit value at offset in the string value stored at key
"""
return self._query(
tr, b"getbit", self.encode_from_native(key), self._encode_int(offset)
)
@_query_command
def setbit(self, tr, key: NativeType, offset: int, value: bool) -> bool:
"""Sets or clears the bit at offset in the string value stored at key
"""
return self._query(
tr,
b"setbit",
self.encode_from_native(key),
self._encode_int(offset),
self._encode_int(int(value)),
)
# Keys
@_query_command
def keys(self, tr, pattern: NativeType) -> ListReply:
"""Find all keys matching the given pattern.
.. note:: Also take a look at :func:`~asyncio_redis.RedisProtocol.scan`.
"""
return self._query(tr, b"keys", self.encode_from_native(pattern))
@_query_command
def expire(self, tr, key: NativeType, seconds: int) -> int:
"""Set a key's time to live in seconds
"""
return self._query(
tr, b"expire", self.encode_from_native(key), self._encode_int(seconds)
)
@_query_command
def pexpire(self, tr, key: NativeType, milliseconds: int) -> int:
"""Set a key's time to live in milliseconds
"""
return self._query(
tr, b"pexpire", self.encode_from_native(key), self._encode_int(milliseconds)
)
@_query_command
def expireat(self, tr, key: NativeType, timestamp: int) -> int:
"""Set the expiration for a key as a UNIX timestamp
"""
return self._query(
tr, b"expireat", self.encode_from_native(key), self._encode_int(timestamp)
)
@_query_command
def pexpireat(self, tr, key: NativeType, milliseconds_timestamp: int) -> int:
"""Set the expiration for a key as a UNIX timestamp specified in milliseconds
"""
return self._query(
tr,
b"pexpireat",
self.encode_from_native(key),
self._encode_int(milliseconds_timestamp),
)
@_query_command
def persist(self, tr, key: NativeType) -> int:
"""Remove the expiration from a key
"""
return self._query(tr, b"persist", self.encode_from_native(key))
@_query_command
def ttl(self, tr, key: NativeType) -> int:
"""Get the time to live for a key
"""
return self._query(tr, b"ttl", self.encode_from_native(key))
@_query_command
def pttl(self, tr, key: NativeType) -> int:
"""Get the time to live for a key in milliseconds
"""
return self._query(tr, b"pttl", self.encode_from_native(key))
# Set operations
@_query_command
def sadd(self, tr, key: NativeType, members: ListOf(NativeType)) -> int:
"""Add one or more members to a set
"""
return self._query(
tr,
b"sadd",
self.encode_from_native(key),
*map(self.encode_from_native, members),
)
@_query_command
def srem(self, tr, key: NativeType, members: ListOf(NativeType)) -> int:
"""Remove one or more members from a set
"""
return self._query(
tr,
b"srem",
self.encode_from_native(key),
*map(self.encode_from_native, members),
)
@_query_command
def spop(self, tr, key: NativeType) -> (NativeType, NoneType):
"""Remove and return a random element from the set value stored at key
"""
return self._query(tr, b"spop", self.encode_from_native(key))
@_query_command
def srandmember(self, tr, key: NativeType, count: int = 1) -> SetReply:
"""Get one or multiple random members from a set.
Return a list of members, even when count==1.
"""
return self._query(
tr, b"srandmember", self.encode_from_native(key), self._encode_int(count)
)
@_query_command
def sismember(self, tr, key: NativeType, value: NativeType) -> bool:
"""Determine if a given value is a member of a set
"""
return self._query(
tr,
b"sismember",
self.encode_from_native(key),
self.encode_from_native(value),
)
@_query_command
def scard(self, tr, key: NativeType) -> int:
"""Get the number of members in a set
"""
return self._query(tr, b"scard", self.encode_from_native(key))
@_query_command
def smembers(self, tr, key: NativeType) -> SetReply:
"""Get all the members in a set
"""
return self._query(tr, b"smembers", self.encode_from_native(key))
@_query_command
def sinter(self, tr, keys: ListOf(NativeType)) -> SetReply:
"""Intersect multiple sets
"""
return self._query(tr, b"sinter", *map(self.encode_from_native, keys))
@_query_command
def sinterstore(self, tr, destination: NativeType, keys: ListOf(NativeType)) -> int:
"""Intersect multiple sets and store the resulting set in a key
"""
return self._query(
tr,
b"sinterstore",
self.encode_from_native(destination),
*map(self.encode_from_native, keys),
)
@_query_command
def sdiff(self, tr, keys: ListOf(NativeType)) -> SetReply:
"""Subtract multiple sets
"""
return self._query(tr, b"sdiff", *map(self.encode_from_native, keys))
@_query_command
def sdiffstore(self, tr, destination: NativeType, keys: ListOf(NativeType)) -> int:
"""Subtract multiple sets and store the resulting set in a key
"""
return self._query(
tr,
b"sdiffstore",
self.encode_from_native(destination),
*map(self.encode_from_native, keys),
)
@_query_command
def sunion(self, tr, keys: ListOf(NativeType)) -> SetReply:
"""Add multiple sets
"""
return self._query(tr, b"sunion", *map(self.encode_from_native, keys))
@_query_command
def sunionstore(self, tr, destination: NativeType, keys: ListOf(NativeType)) -> int:
"""Add multiple sets and store the resulting set in a key
"""
return self._query(
tr,
b"sunionstore",
self.encode_from_native(destination),
*map(self.encode_from_native, keys),
)
@_query_command
def smove(
self, tr, source: NativeType, destination: NativeType, value: NativeType
) -> int:
"""Move a member from one set to another
"""
return self._query(
tr,
b"smove",
self.encode_from_native(source),
self.encode_from_native(destination),
self.encode_from_native(value),
)
# List operations
@_query_command
def lpush(self, tr, key: NativeType, values: ListOf(NativeType)) -> int:
"""Prepend one or multiple values to a list
"""
return self._query(
tr,
b"lpush",
self.encode_from_native(key),
*map(self.encode_from_native, values),
)
@_query_command
def lpushx(self, tr, key: NativeType, value: NativeType) -> int:
"""Prepend a value to a list, only if the list exists
"""
return self._query(
tr, b"lpushx", self.encode_from_native(key), self.encode_from_native(value)
)
@_query_command
def rpush(self, tr, key: NativeType, values: ListOf(NativeType)) -> int:
"""Append one or multiple values to a list
"""
return self._query(
tr,
b"rpush",
self.encode_from_native(key),
*map(self.encode_from_native, values),
)
@_query_command
def rpushx(self, tr, key: NativeType, value: NativeType) -> int:
"""Append a value to a list, only if the list exists
"""
return self._query(
tr, b"rpushx", self.encode_from_native(key), self.encode_from_native(value)
)
@_query_command
def llen(self, tr, key: NativeType) -> int:
"""Return the length of the list stored at key
"""
return self._query(tr, b"llen", self.encode_from_native(key))
@_query_command
def lrem(self, tr, key: NativeType, count: int = 0, value="") -> int:
"""Remove elements from a list
"""
return self._query(
tr,
b"lrem",
self.encode_from_native(key),
self._encode_int(count),
self.encode_from_native(value),
)
@_query_command
def lrange(self, tr, key, start: int = 0, stop: int = -1) -> ListReply:
"""Get a range of elements from a list
"""
return self._query(
tr,
b"lrange",
self.encode_from_native(key),
self._encode_int(start),
self._encode_int(stop),
)
@_query_command
def ltrim(self, tr, key: NativeType, start: int = 0, stop: int = -1) -> StatusReply:
"""Trim a list to the specified range
"""
return self._query(
tr,
b"ltrim",
self.encode_from_native(key),
self._encode_int(start),
self._encode_int(stop),
)
@_query_command
def lpop(self, tr, key: NativeType) -> (NativeType, NoneType):
"""Remove and get the first element in a list
"""
return self._query(tr, b"lpop", self.encode_from_native(key))
@_query_command
def rpop(self, tr, key: NativeType) -> (NativeType, NoneType):
"""Remove and get the last element in a list
"""
return self._query(tr, b"rpop", self.encode_from_native(key))
@_query_command
def rpoplpush(
self, tr, source: NativeType, destination: NativeType
) -> (NativeType, NoneType):
"""Remove the last element in a list, append it to another list and return it
"""
return self._query(
tr,
b"rpoplpush",
self.encode_from_native(source),
self.encode_from_native(destination),
)
@_query_command
def lindex(self, tr, key: NativeType, index: int) -> (NativeType, NoneType):
"""Get an element from a list by its index
"""
return self._query(
tr, b"lindex", self.encode_from_native(key), self._encode_int(index)
)
@_query_command
def blpop(self, tr, keys: ListOf(NativeType), timeout: int = 0) -> BlockingPopReply:
"""Remove and get the first element in a list, or block until one is available.
This will raise :class:`~asyncio_redis.TimeoutError` when
the timeout was exceeded and Redis returns `None`.
"""
return self._blocking_pop(tr, b"blpop", keys, timeout=timeout)
@_query_command
def brpop(self, tr, keys: ListOf(NativeType), timeout: int = 0) -> BlockingPopReply:
"""Remove and get the last element in a list, or block until one is available.
This will raise :class:`~asyncio_redis.TimeoutError` when
the timeout was exceeded and Redis returns `None`.
"""
return self._blocking_pop(tr, b"brpop", keys, timeout=timeout)
def _blocking_pop(self, tr, command, keys, timeout: int = 0):
return self._query(
tr,
command,
*([self.encode_from_native(k) for k in keys] + [self._encode_int(timeout)]),
set_blocking=True,
)
@_command
async def brpoplpush(
self, tr, source: NativeType, destination: NativeType, timeout: int = 0
) -> NativeType:
"""Pop a value from a list, push it to another list and return it, or block
until one is available
"""
result = await self._query(
tr,
b"brpoplpush",
self.encode_from_native(source),
self.encode_from_native(destination),
self._encode_int(timeout),
set_blocking=True,
)
if result is None:
raise TimeoutError("Timeout in brpoplpush")
else:
assert isinstance(result, bytes)
return self.decode_to_native(result)
@_query_command
def lset(self, tr, key: NativeType, index: int, value: NativeType) -> StatusReply:
"""Set the value of an element in a list by its index
"""
return self._query(
tr,
b"lset",
self.encode_from_native(key),
self._encode_int(index),
self.encode_from_native(value),
)
@_query_command
def linsert(
self, tr, key: NativeType, pivot: NativeType, value: NativeType, before=False
) -> int:
"""Insert an element before or after another element in a list
"""
return self._query(
tr,
b"linsert",
self.encode_from_native(key),
(b"BEFORE" if before else b"AFTER"),
self.encode_from_native(pivot),
self.encode_from_native(value),
)
# Sorted Sets
@_query_command
def zadd(
self,
tr,
key: NativeType,
values: dict,
only_if_not_exists=False,
only_if_exists=False,
return_num_changed=False,
) -> int:
"""Add one or more members to a sorted set, or update its score if it already
exists
::
await protocol.zadd('myzset', { 'key': 4, 'key2': 5 })
"""
options = []
assert not (only_if_not_exists and only_if_exists)
if only_if_not_exists:
options.append(b"NX")
elif only_if_exists:
options.append(b"XX")
if return_num_changed:
options.append(b"CH")
data = []
for k, score in values.items():
assert isinstance(k, self.native_type)
assert isinstance(score, (int, float))
data.append(self._encode_float(score))
data.append(self.encode_from_native(k))
return self._query(tr, b"zadd", self.encode_from_native(key), *(options + data))
@_query_command
def zpopmin(self, tr, key: NativeType, count: int = 1) -> ZRangeReply:
"""Return the specified numbers of first elements from sorted set with a minimum
score.
You can do the following to recieve the slice of the sorted set as a
python dict (mapping the keys to their scores):
::
result = yield protocol.zpopmin('myzset', count=10)
my_dict = yield result.asdict()
"""
return self._query(
tr, b"zpopmin", self.encode_from_native(key), self._encode_int(count)
)
@_query_command
def zrange(
self, tr, key: NativeType, start: int = 0, stop: int = -1
) -> ZRangeReply:
"""Return a range of members in a sorted set, by index.
You can do the following to receive the slice of the sorted set as a
python dict (mapping the keys to their scores):
::
result = await protocol.zrange('myzset', start=10, stop=20)
my_dict = await result.asdict()
or the following to retrieve it as a list of keys:
::
result = await protocol.zrange('myzset', start=10, stop=20)
my_list = await result.aslist()
"""
return self._query(
tr,
b"zrange",
self.encode_from_native(key),
self._encode_int(start),
self._encode_int(stop),
b"withscores",
)
@_query_command
def zrangebylex(self, tr, key: NativeType, start: str, stop: str) -> SetReply:
"""Return a range of members in a sorted set, by index.
You can do the following to receive the slice of the sorted set as a
python dict (mapping the keys to their scores):
::
result = yield protocol.zrangebykex('myzset', start='-', stop='[c')
my_dict = yield result.asdict()
or the following to retrieve it as a list of keys:
::
result = yield protocol.zrangebylex('myzset', start='-', stop='[c')
my_dict = yield result.aslist()
"""
return self._query(
tr,
b"zrangebylex",
self.encode_from_native(key),
self.encode_from_native(start),
self.encode_from_native(stop),
)
@_query_command
def zrevrange(
self, tr, key: NativeType, start: int = 0, stop: int = -1
) -> ZRangeReply:
"""Return a range of members in a reversed sorted set, by index.
You can do the following to receive the slice of the sorted set as a
python dict (mapping the keys to their scores):
::
my_dict = await protocol.zrevrange_asdict('myzset', start=10, stop=20)
or the following to retrieve it as a list of keys:
::
zrange_reply = await protocol.zrevrange('myzset', start=10, stop=20)
my_list = await zrange_reply.aslist()
"""
return self._query(
tr,
b"zrevrange",
self.encode_from_native(key),
self._encode_int(start),
self._encode_int(stop),
b"withscores",
)
@_query_command
def zrangebyscore(
self,
tr,
key: NativeType,
min: ZScoreBoundary = ZScoreBoundary.MIN_VALUE,
max: ZScoreBoundary = ZScoreBoundary.MAX_VALUE,
offset: int = 0,
limit: int = -1,
) -> ZRangeReply:
"""Return a range of members in a sorted set, by score
"""
return self._query(
tr,
b"zrangebyscore",
self.encode_from_native(key),
self._encode_zscore_boundary(min),
self._encode_zscore_boundary(max),
b"limit",
self._encode_int(offset),
self._encode_int(limit),
b"withscores",
)
@_query_command
def zrevrangebyscore(
self,
tr,
key: NativeType,
max: ZScoreBoundary = ZScoreBoundary.MAX_VALUE,
min: ZScoreBoundary = ZScoreBoundary.MIN_VALUE,
offset: int = 0,
limit: int = -1,
) -> ZRangeReply:
"""Return a range of members in a sorted set, by score, with scores ordered from
high to low
"""
return self._query(
tr,
b"zrevrangebyscore",
self.encode_from_native(key),
self._encode_zscore_boundary(max),
self._encode_zscore_boundary(min),
b"limit",
self._encode_int(offset),
self._encode_int(limit),
b"withscores",
)
@_query_command
def zremrangebyscore(
self,
tr,
key: NativeType,
min: ZScoreBoundary = ZScoreBoundary.MIN_VALUE,
max: ZScoreBoundary = ZScoreBoundary.MAX_VALUE,
) -> int:
"""Remove all members in a sorted set within the given scores
"""
return self._query(
tr,
b"zremrangebyscore",
self.encode_from_native(key),
self._encode_zscore_boundary(min),
self._encode_zscore_boundary(max),
)
@_query_command
def zremrangebyrank(self, tr, key: NativeType, min: int = 0, max: int = -1) -> int:
"""Remove all members in a sorted set within the given indexes
"""
return self._query(
tr,
b"zremrangebyrank",
self.encode_from_native(key),
self._encode_int(min),
self._encode_int(max),
)
@_query_command
def zcount(
self, tr, key: NativeType, min: ZScoreBoundary, max: ZScoreBoundary
) -> int:
"""Count the members in a sorted set with scores within the given values
"""
return self._query(
tr,
b"zcount",
self.encode_from_native(key),
self._encode_zscore_boundary(min),
self._encode_zscore_boundary(max),
)
@_query_command
def zscore(self, tr, key: NativeType, member: NativeType) -> (float, NoneType):
"""Get the score associated with the given member in a sorted set
"""
return self._query(
tr, b"zscore", self.encode_from_native(key), self.encode_from_native(member)
)
@_query_command
def zunionstore(
self,
tr,
destination: NativeType,
keys: ListOf(NativeType),
weights: (NoneType, ListOf(float)) = None,
aggregate=ZAggregate.SUM,
) -> int:
"""Add multiple sorted sets and store the resulting sorted set in a new key
"""
return self._zstore(tr, b"zunionstore", destination, keys, weights, aggregate)
@_query_command
def zinterstore(
self,
tr,
destination: NativeType,
keys: ListOf(NativeType),
weights: (NoneType, ListOf(float)) = None,
aggregate=ZAggregate.SUM,
) -> int:
"""Intersect multiple sorted sets and store the resulting sorted set in a new
key
"""
return self._zstore(tr, b"zinterstore", destination, keys, weights, aggregate)
def _zstore(self, tr, command, destination, keys, weights, aggregate):
"""Common part for zunionstore and zinterstore
"""
numkeys = len(keys)
if weights is None:
weights = [1] * numkeys
return self._query(
tr,
*[command, self.encode_from_native(destination), self._encode_int(numkeys)]
+ list(map(self.encode_from_native, keys))
+ [b"weights"]
+ list(map(self._encode_float, weights))
+ [b"aggregate"]
+ [
{
ZAggregate.SUM: b"SUM",
ZAggregate.MIN: b"MIN",
ZAggregate.MAX: b"MAX",
}[aggregate]
],
)
@_query_command
def zcard(self, tr, key: NativeType) -> int:
"""Get the number of members in a sorted set
"""
return self._query(tr, b"zcard", self.encode_from_native(key))
@_query_command
def zrank(self, tr, key: NativeType, member: NativeType) -> (int, NoneType):
"""Determine the index of a member in a sorted set
"""
return self._query(
tr, b"zrank", self.encode_from_native(key), self.encode_from_native(member)
)
@_query_command
def zrevrank(self, tr, key: NativeType, member: NativeType) -> (int, NoneType):
"""Determine the index of a member in a sorted set, with scores ordered from
high to low
"""
return self._query(
tr,
b"zrevrank",
self.encode_from_native(key),
self.encode_from_native(member),
)
@_query_command
def zincrby(
self,
tr,
key: NativeType,
increment: float,
member: NativeType,
only_if_exists=False,
) -> (float, NoneType):
"""Increment the score of a member in a sorted set
"""
if only_if_exists:
return self._query(
tr,
b"zadd",
self.encode_from_native(key),
b"xx",
b"incr",
self._encode_float(increment),
self.encode_from_native(member),
)
else:
return self._query(
tr,
b"zincrby",
self.encode_from_native(key),
self._encode_float(increment),
self.encode_from_native(member),
)
@_query_command
def zrem(self, tr, key: NativeType, members: ListOf(NativeType)) -> int:
"""Remove one or more members from a sorted set
"""
return self._query(
tr,
b"zrem",
self.encode_from_native(key),
*map(self.encode_from_native, members),
)
# Hashes
@_query_command
def hset(self, tr, key: NativeType, field: NativeType, value: NativeType) -> int:
"""Set the string value of a hash field
"""
return self._query(
tr,
b"hset",
self.encode_from_native(key),
self.encode_from_native(field),
self.encode_from_native(value),
)
@_query_command
def hmset(self, tr, key: NativeType, values: dict) -> StatusReply:
"""Set multiple hash fields to multiple values
"""
data = []
for k, v in values.items():
assert isinstance(k, self.native_type)
assert isinstance(v, self.native_type)
data.append(self.encode_from_native(k))
data.append(self.encode_from_native(v))
return self._query(tr, b"hmset", self.encode_from_native(key), *data)
@_query_command
def hsetnx(self, tr, key: NativeType, field: NativeType, value: NativeType) -> int:
"""Set the value of a hash field, only if the field does not exist
"""
return self._query(
tr,
b"hsetnx",
self.encode_from_native(key),
self.encode_from_native(field),
self.encode_from_native(value),
)
@_query_command
def hdel(self, tr, key: NativeType, fields: ListOf(NativeType)) -> int:
"""Delete one or more hash fields
"""
return self._query(
tr,
b"hdel",
self.encode_from_native(key),
*map(self.encode_from_native, fields),
)
@_query_command
def hget(self, tr, key: NativeType, field: NativeType) -> (NativeType, NoneType):
"""Get the value of a hash field
"""
return self._query(
tr, b"hget", self.encode_from_native(key), self.encode_from_native(field)
)
@_query_command
def hexists(self, tr, key: NativeType, field: NativeType) -> bool:
"""Return if field is an existing field in the hash stored at key
"""
return self._query(
tr, b"hexists", self.encode_from_native(key), self.encode_from_native(field)
)
@_query_command
def hkeys(self, tr, key: NativeType) -> SetReply:
"""Get all the keys in a hash. Returns a set.
"""
return self._query(tr, b"hkeys", self.encode_from_native(key))
@_query_command
def hvals(self, tr, key: NativeType) -> ListReply:
"""Get all the values in a hash. Returns a list.
"""
return self._query(tr, b"hvals", self.encode_from_native(key))
@_query_command
def hlen(self, tr, key: NativeType) -> int:
"""Return the number of fields contained in the hash stored at key
"""
return self._query(tr, b"hlen", self.encode_from_native(key))
@_query_command
def hgetall(self, tr, key: NativeType) -> DictReply:
"""Get the value of a hash field
"""
return self._query(tr, b"hgetall", self.encode_from_native(key))
@_query_command
def hmget(self, tr, key: NativeType, fields: ListOf(NativeType)) -> ListReply:
"""Get the values of all the given hash fields
"""
return self._query(
tr,
b"hmget",
self.encode_from_native(key),
*map(self.encode_from_native, fields),
)
@_query_command
def hincrby(self, tr, key: NativeType, field: NativeType, increment) -> int:
"""Increment the integer value of a hash field by the given number.
Return the value at field after the increment operation.
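A minimal sketch, assuming a connected ``protocol`` and hypothetical names
(the field is created, starting from 0, when it does not exist yet):
::
visits = await protocol.hincrby('page:home', 'visits', 1)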
"""
assert isinstance(increment, int)
return self._query(
tr,
b"hincrby",
self.encode_from_native(key),
self.encode_from_native(field),
self._encode_int(increment),
)
@_query_command
def hincrbyfloat(
self, tr, key: NativeType, field: NativeType, increment: (int, float)
) -> float:
"""Increment the float value of a hash field by the given amount.
Return the value at field after the increment operation.
"""
return self._query(
tr,
b"hincrbyfloat",
self.encode_from_native(key),
self.encode_from_native(field),
self._encode_float(increment),
)
# Pubsub
# (subscribe, unsubscribe, etc... should be called through the Subscription class.)
@_command
async def start_subscribe(self, tr, *a) -> "Subscription":
"""Start a pubsub listener.
::
# Create subscription
subscription = await protocol.start_subscribe()
await subscription.subscribe(['key'])
await subscription.psubscribe(['pattern*'])
while True:
result = await subscription.next_published()
print(result)
:returns: :class:`~asyncio_redis.Subscription`
"""
if self.in_use:
raise Error("Cannot start pubsub listener when a protocol is in use.")
subscription = Subscription(self)
self._in_pubsub = True
self._subscription = subscription
return subscription
@_command
def _subscribe(self, tr, channels: ListOf(NativeType)) -> NoneType:
"""Listen for messages published to the given channels
"""
self._pubsub_channels |= set(channels)
return self._pubsub_method("subscribe", channels)
@_command
def _unsubscribe(self, tr, channels: ListOf(NativeType)) -> NoneType:
"""Stop listening for messages posted to the given channels
"""
self._pubsub_channels -= set(channels)
return self._pubsub_method("unsubscribe", channels)
@_command
def _psubscribe(self, tr, patterns: ListOf(NativeType)) -> NoneType:
"""Listen for messages published to channels matching the given patterns
"""
self._pubsub_patterns |= set(patterns)
return self._pubsub_method("psubscribe", patterns)
@_command
def _punsubscribe(
self, tr, patterns: ListOf(NativeType)
) -> NoneType: # XXX: unittest
"""Stop listening for messages posted to channels matching the given patterns
"""
self._pubsub_patterns -= set(patterns)
return self._pubsub_method("punsubscribe", patterns)
async def _pubsub_method(self, method, params):
if not self._in_pubsub:
raise Error("Cannot call pubsub methods without calling start_subscribe")
# Send
self._send_command(
[method.encode("ascii")] + list(map(self.encode_from_native, params))
)
# Note that we can't use `self._query` here. The reason is that a
# subscribe/unsubscribe command returns a separate answer for every
# parameter. That doesn't fit the model of the other queries, where one
# query puts a Future on the queue that is resolved by the incoming
# answer.
# Redis returns something like ['subscribe', 'channel_name', 1] for
# each parameter, but we can safely ignore those replies.
@_query_command
def publish(self, tr, channel: NativeType, message: NativeType) -> int:
"""Post a message to a channel.
Return the number of clients that received this message.
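A minimal sketch, assuming a connected ``protocol``; subscribers are set up
through :func:`~asyncio_redis.RedisProtocol.start_subscribe`:
::
receivers = await protocol.publish('greetings', 'hello world')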
"""
return self._query(
tr,
b"publish",
self.encode_from_native(channel),
self.encode_from_native(message),
)
@_query_command
def pubsub_channels(self, tr, pattern: (NativeType, NoneType) = None) -> ListReply:
"""List the currently active channels. An active channel is a Pub/Sub channel
with one or more subscribers (not including clients subscribed to patterns)
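A minimal sketch, assuming a connected ``protocol``; omit the pattern to list
every active channel:
::
reply = await protocol.pubsub_channels('news.*')
channels = await reply.aslist()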
"""
return self._query(
tr,
b"pubsub",
b"channels",
(self.encode_from_native(pattern) if pattern else b"*"),
)
@_query_command
def pubsub_numsub(self, tr, channels: ListOf(NativeType)) -> DictReply:
"""Return the number of subscribers (not counting clients subscribed to
patterns) for the specified channels
"""
return self._query(
tr, b"pubsub", b"numsub", *[self.encode_from_native(c) for c in channels]
)
@_query_command
def pubsub_numpat(self, tr) -> int:
"""Return the number of subscriptions to patterns (that are performed using the
PSUBSCRIBE command). Note that this is not just the count of clients subscribed
to patterns but the total number of patterns all the clients are subscribed to.
"""
return self._query(tr, b"pubsub", b"numpat")
# Server
@_query_command
def ping(self, tr) -> StatusReply:
"""Ping the server (Returns PONG)
"""
return self._query(tr, b"ping")
@_query_command
def echo(self, tr, string: NativeType) -> NativeType:
"""Echo the given string
"""
return self._query(tr, b"echo", self.encode_from_native(string))
@_query_command
def save(self, tr) -> StatusReply:
"""Synchronously save the dataset to disk
"""
return self._query(tr, b"save")
@_query_command
def bgsave(self, tr) -> StatusReply:
"""Asynchronously save the dataset to disk
"""
return self._query(tr, b"bgsave")
@_query_command
def bgrewriteaof(self, tr) -> StatusReply:
"""Asynchronously rewrite the append-only file
"""
return self._query(tr, b"bgrewriteaof")
@_query_command
def lastsave(self, tr) -> int:
"""Get the UNIX time stamp of the last successful save to disk
"""
return self._query(tr, b"lastsave")
@_query_command
def dbsize(self, tr) -> int:
"""Return the number of keys in the currently-selected database
"""
return self._query(tr, b"dbsize")
@_query_command
def flushall(self, tr) -> StatusReply:
"""Remove all keys from all databases
"""
return self._query(tr, b"flushall")
@_query_command
def flushdb(self, tr) -> StatusReply:
"""Delete all the keys of the currently selected DB. This command never fails.
"""
return self._query(tr, b"flushdb")
@_query_command
def type(self, tr, key: NativeType) -> StatusReply:
"""Determine the type stored at key
"""
return self._query(tr, b"type", self.encode_from_native(key))
@_query_command
def config_set(self, tr, parameter: str, value: str) -> StatusReply:
"""Set a configuration parameter to the given value
"""
return self._query(
tr,
b"config",
b"set",
self.encode_from_native(parameter),
self.encode_from_native(value),
)
@_query_command
def config_get(self, tr, parameter: str) -> ConfigPairReply:
"""Get the value of a configuration parameter
"""
return self._query(tr, b"config", b"get", self.encode_from_native(parameter))
@_query_command
def config_rewrite(self, tr) -> StatusReply:
"""Rewrite the configuration file with the in memory configuration
"""
return self._query(tr, b"config", b"rewrite")
@_query_command
def config_resetstat(self, tr) -> StatusReply:
"""Reset the stats returned by INFO
"""
return self._query(tr, b"config", b"resetstat")
@_query_command
def info(self, tr, section: (NativeType, NoneType) = None) -> InfoReply:
"""Get information and statistics about the server
"""
if section is None:
return self._query(tr, b"info")
else:
return self._query(tr, b"info", self.encode_from_native(section))
@_query_command
def shutdown(self, tr, save=False) -> StatusReply:
"""Synchronously save the dataset to disk and then shut down the server
"""
return self._query(tr, b"shutdown", (b"save" if save else b"nosave"))
@_query_command
def client_getname(self, tr) -> NativeType:
"""Get the current connection name
"""
return self._query(tr, b"client", b"getname")
@_query_command
def client_setname(self, tr, name) -> StatusReply:
"""Set the current connection name
"""
return self._query(tr, b"client", b"setname", self.encode_from_native(name))
@_query_command
def client_list(self, tr) -> ClientListReply:
"""Get the list of client connections
"""
return self._query(tr, b"client", b"list")
@_query_command
def client_kill(self, tr, address: str) -> StatusReply:
"""Kill the connection of a client.
`address` should be an "ip:port" string.
"""
return self._query(tr, b"client", b"kill", address.encode("utf-8"))
# LUA scripting
@_command
async def register_script(self, tr, script: str) -> "Script":
"""Register a LUA script.
::
script = await protocol.register_script(lua_code)
result = await script.run(keys=[...], args=[...])
"""
# The register_script API was made compatible with the redis.py library:
# https://github.com/andymccurdy/redis-py
sha = await self.script_load(tr, script)
return Script(sha, script, lambda: self.evalsha)
@_query_command
def script_exists(self, tr, shas: ListOf(str)) -> ListOf(bool):
"""Check existence of scripts in the script cache
"""
return self._query(
tr, b"script", b"exists", *[sha.encode("ascii") for sha in shas]
)
@_query_command
def script_flush(self, tr) -> StatusReply:
"""Remove all the scripts from the script cache
"""
return self._query(tr, b"script", b"flush")
@_query_command
async def script_kill(self, tr) -> StatusReply:
"""Kill the script currently in execution. This raises
:class:`~asyncio_redis.NoRunningScriptError` when there are no scripts running.
"""
try:
return await self._query(tr, b"script", b"kill")
except ErrorReply as e:
if "NOTBUSY" in e.args[0]:
raise NoRunningScriptError
else:
raise
@_query_command
async def evalsha(
self,
tr,
sha: str,
keys: (ListOf(NativeType), NoneType) = None,
args: (ListOf(NativeType), NoneType) = None,
) -> EvalScriptReply:
"""Evaluate a script cached on the server side by its SHA1 digest.
Scripts are cached on the server side using the SCRIPT LOAD command.
The return type/value depends on the script.
This will raise a :class:`~asyncio_redis.ScriptKilledError`
exception if the script was killed.
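A minimal sketch, assuming a connected ``protocol``; the script body and key
names are hypothetical:
::
sha = await protocol.script_load("return {KEYS[1], ARGV[1]}")
reply = await protocol.evalsha(sha, keys=['mykey'], args=['myvalue'])
result = await reply.return_value()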
"""
if not keys:
keys = []
if not args:
args = []
try:
result = await self._query(
tr,
b"evalsha",
sha.encode("ascii"),
self._encode_int(len(keys)),
*map(self.encode_from_native, keys + args),
)
return result
except ErrorReply:
raise ScriptKilledError
@_query_command
def script_load(self, tr, script: str) -> str:
"""Load script and return its sha1
"""
return self._query(tr, b"script", b"load", script.encode("utf-8"))
# Scanning
@_command
async def scan(self, tr, match: (NativeType, NoneType) = None) -> Cursor:
"""Walk through the keys space. You can either fetch the items one by one
or in bulk.
::
cursor = await protocol.scan(match='*')
while True:
item = await cursor.fetchone()
if item is None:
break
else:
print(item)
::
cursor = await protocol.scan(match='*')
items = await cursor.fetchall()
It's possible to alter the COUNT-parameter, by assigning a value to
``cursor.count``, before calling ``fetchone`` or ``fetchall``. For
instance:
::
cursor.count = 100
Also see: :func:`~asyncio_redis.RedisProtocol.sscan`,
:func:`~asyncio_redis.RedisProtocol.hscan` and
:func:`~asyncio_redis.RedisProtocol.zscan`
Redis reference: http://redis.io/commands/scan
"""
def scanfunc(cursor, count):
return self._scan(tr, cursor, match, count)
return Cursor(name="scan(match=%r)" % match, scanfunc=scanfunc)
@_query_command
def _scan(
self, tr, cursor: int, match: (NativeType, NoneType), count: int
) -> _ScanPart:
match = b"*" if match is None else self.encode_from_native(match)
return self._query(
tr,
b"scan",
self._encode_int(cursor),
b"match",
match,
b"count",
self._encode_int(count),
)
@_command
async def sscan(
self, tr, key: NativeType, match: (NativeType, NoneType) = None
) -> SetCursor:
"""Incrementally iterate set elements
Also see: :func:`~asyncio_redis.RedisProtocol.scan`
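A minimal sketch, assuming a connected ``protocol`` and a hypothetical set key;
``fetchall`` collects all remaining items at once:
::
cursor = await protocol.sscan('myset', match='user:*')
members = await cursor.fetchall()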
"""
name = "sscan(key=%r match=%r)" % (key, match)
def scan(cursor, count):
return self._do_scan(tr, b"sscan", key, cursor, match, count)
return SetCursor(name=name, scanfunc=scan)
@_command
async def hscan(
self, tr, key: NativeType, match: (NativeType, NoneType) = None
) -> DictCursor:
"""Incrementally iterate hash fields and associated values
Also see: :func:`~asyncio_redis.RedisProtocol.scan`
"""
name = "hscan(key=%r match=%r)" % (key, match)
def scan(cursor, count):
return self._do_scan(tr, b"hscan", key, cursor, match, count)
return DictCursor(name=name, scanfunc=scan)
@_command
async def zscan(
self, tr, key: NativeType, match: (NativeType, NoneType) = None
) -> DictCursor:
"""Incrementally iterate sorted sets elements and associated scores
Also see: :func:`~asyncio_redis.RedisProtocol.scan`
"""
name = "zscan(key=%r match=%r)" % (key, match)
def scan(cursor, count):
return self._do_scan(b"zscan", key, cursor, match, count)
return ZCursor(name=name, scanfunc=scan)
@_query_command
def _do_scan(
self,
tr,
verb: bytes,
key: NativeType,
cursor: int,
match: (NativeType, NoneType),
count: int,
) -> _ScanPart:
match = b"*" if match is None else self.encode_from_native(match)
return self._query(
tr,
verb,
self.encode_from_native(key),
self._encode_int(cursor),
b"match",
match,
b"count",
self._encode_int(count),
)
# Transaction
@_command
async def watch(self, tr, keys: ListOf(NativeType)) -> NoneType:
"""Watch keys.
::
# Watch keys for concurrent updates
await protocol.watch(['key', 'other_key'])
value = await protocol.get('key')
another_value = await protocol.get('another_key')
transaction = await protocol.multi()
f1 = await transaction.set('key', another_value)
f2 = await transaction.set('another_key', value)
# Commit transaction
await transaction.exec()
# Retrieve results
await f1
await f2
"""
return await self._watch(tr, keys)
async def _watch(self, tr, keys: ListOf(NativeType)) -> NoneType:
result = await self._query(
tr, b"watch", *map(self.encode_from_native, keys), _bypass=True
)
assert result == b"OK"
@_command
async def multi(
self, tr, watch: (ListOf(NativeType), NoneType) = None
) -> "Transaction":
"""Start of transaction.
::
transaction = await protocol.multi()
# Run commands in transaction
f1 = await transaction.set('key', 'value')
f2 = await transaction.set('another_key', 'another_value')
# Commit transaction
await transaction.exec()
# Retrieve results (you can also use asyncio.tasks.gather)
result1 = await f1
result2 = await f2
:returns: A :class:`asyncio_redis.Transaction` instance.
"""
# Create transaction object.
if tr != _NoTransaction:
raise Error("Multi calls can not be nested.")
else:
await self._transaction_lock.acquire()
tr = Transaction(self)
self._transaction = tr
# Call watch
if watch is not None:
await self._watch(tr, watch)
# Call multi
result = await self._query(tr, b"multi", _bypass=True)
assert result == b"OK"
self._transaction_response_queue = deque()
return tr
async def _exec(self, tr):
"""Execute all commands issued after MULTI
"""
if not self._transaction or self._transaction != tr:
raise Error("Not in transaction")
try:
futures_and_postprocessors = self._transaction_response_queue
self._transaction_response_queue = None
# Get transaction answers.
multi_bulk_reply = await self._query(tr, b"exec", _bypass=True)
if multi_bulk_reply is None:
# We get None when a transaction failed.
raise TransactionError("Transaction failed.")
else:
assert isinstance(multi_bulk_reply, MultiBulkReply)
for f in multi_bulk_reply.iter_raw():
answer = await f
f2, call = futures_and_postprocessors.popleft()
if isinstance(answer, Exception):
f2.set_exception(answer)
else:
if call:
self._pipelined_calls.remove(call)
f2.set_result(answer)
finally:
self._transaction_response_queue = deque()
self._transaction = None
self._transaction_lock.release()
async def _discard(self, tr):
"""Discard all commands issued after MULTI
"""
if not self._transaction or self._transaction != tr:
raise Error("Not in transaction")
try:
result = await self._query(tr, b"discard", _bypass=True)
assert result == b"OK"
finally:
self._transaction_response_queue = deque()
self._transaction = None
self._transaction_lock.release()
async def _unwatch(self, tr):
"""Forget about all watched keys
"""
if not self._transaction or self._transaction != tr:
raise Error("Not in transaction")
result = await self._query(tr, b"unwatch") # XXX: should be _bypass???
assert result == b"OK"
class Script:
"""Lua script
"""
def __init__(self, sha, code, get_evalsha_func):
self.sha = sha
self.code = code
self.get_evalsha_func = get_evalsha_func
def run(self, keys=[], args=[]):
"""Return a coroutine that executes the script.
::
script_reply = await script.run(keys=[], args=[])
# If the LUA script returns something, retrieve the return value
result = await script_reply.return_value()
This will raise a :class:`~asyncio_redis.ScriptKilledError`
exception if the script was killed.
"""
return self.get_evalsha_func()(self.sha, keys, args)
class Transaction:
"""
Transaction context. This is a proxy to a :class:`.RedisProtocol` instance.
Every redis command called on this object will run inside the transaction.
The transaction can be finished by calling either ``discard`` or ``exec``.
More info: http://redis.io/topics/transactions
"""
def __init__(self, protocol):
self._protocol = protocol
def __getattr__(self, name):
"""Proxy to a protocol
"""
# Only proxy commands.
if name not in _all_commands:
raise AttributeError(name)
method = getattr(self._protocol, name)
# Wrap the method into something that passes the transaction object as
# first argument.
@wraps(method)
def wrapper(*a, **kw):
if self._protocol._transaction != self:
raise Error("Transaction already finished or invalid.")
return method(self, *a, **kw)
return wrapper
def discard(self):
"""Discard all commands issued after MULTI
"""
return self._protocol._discard(self)
def exec(self):
"""Execute transaction.
This can raise a :class:`~asyncio_redis.TransactionError`
when the transaction fails.
"""
return self._protocol._exec(self)
def unwatch(self): # XXX: test
"""Forget about all watched keys
"""
return self._protocol._unwatch(self)
class Subscription:
"""Pubsub subscription
"""
def __init__(self, protocol):
self.protocol = protocol
self._messages_queue = asyncio.Queue() # Pubsub queue
@wraps(RedisProtocol._subscribe)
def subscribe(self, channels):
return self.protocol._subscribe(self, channels)
@wraps(RedisProtocol._unsubscribe)
def unsubscribe(self, channels):
return self.protocol._unsubscribe(self, channels)
@wraps(RedisProtocol._psubscribe)
def psubscribe(self, patterns):
return self.protocol._psubscribe(self, patterns)
@wraps(RedisProtocol._punsubscribe)
def punsubscribe(self, patterns):
return self.protocol._punsubscribe(self, patterns)
async def next_published(self):
"""Coroutine which waits for next pubsub message to be received and returns it.
:returns: instance of :class:`PubSubReply <asyncio_redis.replies.PubSubReply>`
"""
return await self._messages_queue.get()
class HiRedisProtocol(RedisProtocol, metaclass=_RedisProtocolMeta):
"""Protocol implementation that uses the `hiredis` library for parsing the
incoming data. This will be faster in many cases, but not necessarily
always.
It does not (yet) support streaming of multibulk replies, which means that
you won't see the first item of a multi bulk reply before the whole
response has been parsed.
"""
def __init__(
self,
*,
password=None,
db=0,
encoder=None,
connection_lost_callback=None,
enable_typechecking=True,
):
super().__init__(
password=password,
db=db,
encoder=encoder,
connection_lost_callback=connection_lost_callback,
enable_typechecking=enable_typechecking,
)
self._hiredis = None
assert (
hiredis
), "`hiredis` libary not available. Please don't use HiRedisProtocol."
def connection_made(self, transport):
super().connection_made(transport)
self._hiredis = hiredis.Reader()
def data_received(self, data):
# Move received data to hiredis parser
self._hiredis.feed(data)
while True:
item = self._hiredis.gets()
if item is not False:
self._process_hiredis_item(item, self._push_answer)
else:
break
def _process_hiredis_item(self, item, cb):
if isinstance(item, (bytes, int)):
cb(item)
elif isinstance(item, list):
reply = MultiBulkReply(self, len(item))
for i in item:
self._process_hiredis_item(i, reply._feed_received)
cb(reply)
elif isinstance(item, hiredis.ReplyError):
cb(ErrorReply(item.args[0]))
elif isinstance(item, NoneType):
cb(item)
async def _reader_coroutine(self):
# We don't need this one.
return