# -*- coding: utf-8 -*-
import asyncio
import copy
from functools import partial
import logging
import typing
from typing import Dict, Optional, Union
import aio_pika
import shortuuid
import kiwipy
from . import defaults, messages, tasks, utils
__all__ = 'RmqCommunicator', 'async_connect'
_LOGGER = logging.getLogger(__name__)
# The exchange properties used by the publisher and subscriber. These have to match,
# which is why they're declared here.
EXCHANGE_PROPERTIES = {'type': aio_pika.ExchangeType.TOPIC}
class RmqPublisher(messages.BasePublisherWithReplyQueue):
    """Publisher that sends RPC and broadcast messages over RMQ.

    The exchange is declared with the module-level ``EXCHANGE_PROPERTIES`` so that it
    matches the exchange declared by the subscriber.
    """

    DEFAULT_EXCHANGE_PARAMS = EXCHANGE_PROPERTIES

    async def rpc_send(self, recipient_id, msg):
        """Publish an RPC message addressed to ``recipient_id``.

        The routing key is the RPC topic followed by the recipient identifier, and the
        message carries the name of this publisher's reply queue in ``reply_to``.

        :param recipient_id: the identifier of the RPC recipient
        :param msg: the message body; it is encoded with ``self._encode`` before publishing
        :return: the response future returned by ``publish_expect_response``
        """
        rpc_routing_key = f'{defaults.RPC_TOPIC}.{recipient_id}'
        reply_queue_name = self._reply_queue.name
        _LOGGER.debug(
            'Sending RPC with routing key %r to RMQ queue %r: %r',
            rpc_routing_key,
            reply_queue_name,
            msg,
        )

        encoded_body = self._encode(msg)
        rmq_message = aio_pika.Message(body=encoded_body, reply_to=self._reply_queue.name)
        # mandatory=True: the RPC is published as mandatory, unlike broadcasts below
        outcome = await self.publish_expect_response(rmq_message, routing_key=rpc_routing_key, mandatory=True)
        published, response_future = outcome
        assert published, 'The message was not published to the exchanges'
        return response_future

    async def broadcast_send(self, msg, sender=None, subject=None, correlation_id=None):
        """Publish a broadcast message on the broadcast topic.

        :param msg: the body of the broadcast
        :param sender: optional sender of the broadcast
        :param subject: optional subject of the broadcast
        :param correlation_id: optional correlation id of the broadcast
        :return: whatever ``self.publish`` returns for the message
        """
        broadcast_topic = defaults.BROADCAST_TOPIC
        payload = messages.BroadcastMessage.create(
            body=msg,
            sender=sender,
            subject=subject,
            correlation_id=correlation_id,
        )
        _LOGGER.debug(
            'Sending broadcast with routing key %r to RMQ via exchange %r: %r',
            broadcast_topic,
            self._exchange_name,
            payload,
        )

        rmq_message = aio_pika.Message(
            body=self._encode(payload),
            delivery_mode=aio_pika.DeliveryMode.NOT_PERSISTENT,
        )
        # Broadcasts are not mandatory: there may legitimately be nobody to route them to
        publish_result = await self.publish(rmq_message, routing_key=defaults.BROADCAST_TOPIC, mandatory=False)
        return publish_result
class RmqSubscriber:
    """
    Subscriber for receiving a range of messages over RMQ
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        connection,
        message_exchange=defaults.MESSAGE_EXCHANGE,
        queue_expires=defaults.QUEUE_EXPIRES,
        decoder=defaults.DECODER,
        encoder=defaults.ENCODER,
        testing_mode=False
    ):
        # pylint: disable=too-many-arguments
        """
        Subscribes and listens for process control messages and acts on them
        by calling the corresponding methods of the process manager.

        :param connection: The aio_pika connection
        :type connection: :class:`aio_pika.Connection`
        :param message_exchange: The name of the exchange to use
        :param queue_expires: the expiry time for standard queues in milliseconds. This is the time after which, if
            there are no subscribers, a queue will automatically be deleted by RabbitMQ.
        :type queue_expires: int
        :param encoder: The encoder to call for encoding a message
        :param decoder: The decoder to call for decoding a message
        :param testing_mode: Run in testing mode: all queues and exchanges will be temporary
        """
        super().__init__()

        self._connection = connection
        self._channel = None  # type: typing.Optional[aio_pika.Channel]
        self._exchange = None  # type: typing.Optional[aio_pika.Exchange]
        self._exchange_name = message_exchange
        self._decode = decoder
        self._testing_mode = testing_mode
        self._response_encode = encoder

        # Queue arguments: both queue kinds get a message TTL, only RPC queues get an expiry
        self._broadcast_queue_arguments = {'x-message-ttl': defaults.MESSAGE_TTL}
        self._rmq_queue_arguments = {'x-message-ttl': defaults.MESSAGE_TTL}
        if queue_expires:
            self._rmq_queue_arguments['x-expires'] = queue_expires

        # identifier -> RPC queue, and identifier -> broadcast subscriber callable
        self._rpc_subscribers = {}
        self._broadcast_subscribers = {}

        self._broadcast_queue = None  # type: typing.Optional[aio_pika.Queue]
        self._broadcast_consumer_tag = None

    async def add_rpc_subscriber(self, subscriber, identifier=None):
        """Register an RPC subscriber on a newly declared exclusive queue.

        :param subscriber: the function or coroutine called with ``(self, msg)`` for each RPC message
        :param identifier: optional identifier, used as the consumer tag; if omitted the value
            returned by ``consume`` is used
        :return: the identifier under which the subscriber was registered
        :raises kiwipy.DuplicateSubscriberIdentifier: if the consumer tag is already in use
        """
        # Create an RPC queue
        rpc_queue = await self._channel.declare_queue(exclusive=True, arguments=self._rmq_queue_arguments)
        try:
            identifier = await rpc_queue.consume(partial(self._on_rpc, subscriber), consumer_tag=identifier)
        except aio_pika.exceptions.DuplicateConsumerTag as exception:
            raise kiwipy.DuplicateSubscriberIdentifier(f"RPC identifier '{identifier}'") from exception
        else:
            await rpc_queue.bind(self._exchange, routing_key=f'{defaults.RPC_TOPIC}.{identifier}')
            # Save the queue so we can cancel and unbind later
            self._rpc_subscribers[identifier] = rpc_queue
            return identifier

    async def remove_rpc_subscriber(self, identifier):
        """Cancel the consumer for ``identifier`` and unbind its queue from the exchange.

        :param identifier: the identifier returned by :meth:`add_rpc_subscriber`
        :raises ValueError: if no RPC subscriber is registered under ``identifier``
        """
        try:
            rpc_queue = self._rpc_subscribers.pop(identifier)
        except KeyError as exception:
            raise ValueError(f"Unknown subscriber '{identifier}'") from exception
        else:
            await rpc_queue.cancel(identifier)
            await rpc_queue.unbind(self._exchange, routing_key=f'{defaults.RPC_TOPIC}.{identifier}')

    async def add_broadcast_subscriber(self, subscriber, identifier=None):
        """Register a broadcast subscriber.

        Consumption of the shared broadcast queue is started when there is no active consumer tag.

        :param subscriber: the function or coroutine called for each broadcast message
        :param identifier: optional identifier; a short uuid is generated if it is falsy
        :return: the identifier under which the subscriber was registered
        :raises kiwipy.DuplicateSubscriberIdentifier: if ``identifier`` is already registered
        """
        identifier = identifier or shortuuid.uuid()
        if identifier in self._broadcast_subscribers:
            raise kiwipy.DuplicateSubscriberIdentifier(f"Broadcast identifier '{identifier}'")

        self._broadcast_subscribers[identifier] = subscriber
        if self._broadcast_consumer_tag is None:
            # Consume on the broadcast queue
            self._broadcast_consumer_tag = await self._broadcast_queue.consume(self._on_broadcast)
        return identifier

    async def remove_broadcast_subscriber(self, identifier):
        """Unregister a broadcast subscriber.

        When the last subscriber is removed the consumer on the broadcast queue is cancelled.

        :param identifier: the identifier returned by :meth:`add_broadcast_subscriber`
        :raises ValueError: if no broadcast subscriber is registered under ``identifier``
        """
        try:
            del self._broadcast_subscribers[identifier]
        except KeyError as exception:
            raise ValueError(f"Broadcast subscriber '{identifier}' unknown") from exception

        if not self._broadcast_subscribers:
            await self._broadcast_queue.cancel(self._broadcast_consumer_tag)
            self._broadcast_consumer_tag = None

    def channel(self):
        """Return the channel in use, or ``None`` when not connected."""
        return self._channel

    async def connect(self):
        """Get a channel and set up all the exchanges/queues we need"""
        if self._channel:
            # Already connected
            return

        # Copy so that the testing-mode tweak does not leak into the shared module constant
        exchange_params = copy.copy(EXCHANGE_PROPERTIES)
        if self._testing_mode:
            exchange_params.setdefault('auto_delete', self._testing_mode)

        self._channel = await self._connection.channel()
        self._exchange = await self._channel.declare_exchange(name=self._exchange_name, **exchange_params)
        await self._create_broadcast_queue()

    async def _create_broadcast_queue(self):
        """
        Create and bind the broadcast queue

        One is used for all broadcasts on this exchange
        """
        # Create a new name such that we can see this is the broadcast queue
        name = f'broadcast-{shortuuid.uuid()}'
        self._broadcast_queue = await self._channel.declare_queue(
            name=name, exclusive=True, arguments=self._broadcast_queue_arguments
        )
        await self._broadcast_queue.bind(self._exchange, routing_key=defaults.BROADCAST_TOPIC)

    async def disconnect(self):
        """Close the channel and forget the exchange.

        This is a no-op when there is no channel, so it is safe to call on a subscriber that was
        never connected or that has already been disconnected.
        """
        if self._channel is None:
            return

        await self._channel.close()
        self._exchange = None
        self._channel = None

    async def _on_rpc(self, subscriber, message):
        """
        Handle an incoming RPC message: acknowledge it, call the subscriber and send back the outcome.

        :param subscriber: the subscriber function or coroutine that will get the RPC message
        :param message: the RMQ message
        :type message: :class:`aio_pika.IncomingMessage`
        """
        # ignore_processed=True because the message is acknowledged explicitly below
        async with message.process(ignore_processed=True):
            # Tell the sender that we've dealt with it
            await message.ack()
            msg = self._decode(message.body)

            try:
                receiver = utils.ensure_coroutine(subscriber)
                result = await receiver(self, msg)
            except Exception as exc:  # pylint: disable=broad-except
                # We had an exception in calling the receiver
                await self._send_response(message.reply_to, message.correlation_id, utils.exception_response(exc))
            else:
                if asyncio.isfuture(result):
                    await self._send_future_response(result, message.reply_to, message.correlation_id)
                else:
                    # All good, send the response out
                    await self._send_response(message.reply_to, message.correlation_id, utils.result_response(result))

    async def _on_broadcast(self, message):
        """
        Deliver an incoming broadcast message to every registered broadcast subscriber.

        An exception raised by one receiver is logged and does not prevent delivery to the others.

        :param message: the RMQ message
        """
        async with message.process():
            msg = self._decode(message.body)
            # Iterate over a snapshot: while a receiver is awaited, subscribers may be added or
            # removed, and mutating the dict mid-iteration would raise a RuntimeError from the
            # ``for`` statement itself (outside the try block), aborting delivery to the rest.
            for subscriber in list(self._broadcast_subscribers.values()):
                try:
                    receiver = utils.ensure_coroutine(subscriber)
                    await receiver(
                        self, msg[messages.BroadcastMessage.BODY], msg[messages.BroadcastMessage.SENDER],
                        msg[messages.BroadcastMessage.SUBJECT], msg[messages.BroadcastMessage.CORRELATION_ID]
                    )
                except Exception:  # pylint: disable=broad-except
                    _LOGGER.exception('Exception in broadcast receiver')

    async def _send_future_response(self, future, reply_to, correlation_id):
        """
        The RPC call returned a future which means we need to send a pending response
        and send a further message when the future resolves. If it resolves to another future
        we should send out a further pending response and so on.

        :param future: the future from the RPC call
        :type future: :class:`asyncio.Future`
        :param reply_to: the recipient
        :param correlation_id: the correlation id
        """
        try:
            # Keep looping in case we're in a situation where a future resolves to a future etc.
            while asyncio.isfuture(future):
                # Send out a message saying that we're waiting for a future to complete
                await self._send_response(reply_to, correlation_id, utils.pending_response())
                future = await future
        except kiwipy.CancelledError as exc:
            # Send out a cancelled response
            await self._send_response(reply_to, correlation_id, utils.cancelled_response(str(exc)))
        except Exception as exc:  # pylint: disable=broad-except
            # Send out an exception response
            await self._send_response(reply_to, correlation_id, utils.exception_response(exc))
        else:
            # We have a final result so send that as the response
            await self._send_response(reply_to, correlation_id, utils.result_response(future))

    async def _send_response(self, reply_to, correlation_id, response):
        """
        Encode ``response`` and publish it on the exchange using ``reply_to`` as the routing key.

        :param reply_to: the routing key identifying the recipient; must be truthy
        :param correlation_id: the correlation id to set on the outgoing message
        :param response: the response to encode with the configured encoder
        :return: the result of publishing the message on the exchange
        """
        assert reply_to, 'Must provide an identifier for the recipient'

        message = aio_pika.Message(body=self._response_encode(response), correlation_id=correlation_id)
        result = await self._exchange.publish(message, routing_key=reply_to)
        return result
class RmqCommunicator:
    """
    An asynchronous communicator that relies on aio_pika to make a connection to a RabbitMQ server
    and uses an asyncio event loop for scheduling coroutines and callbacks.
    """

    # pylint: disable=too-many-instance-attributes

    _connection = None
    _message_subscriber = None
    _message_publisher = None
    _default_task_queue = None  # type: Optional[tasks.RmqTaskQueue]

    def __init__(
        self,
        connection: aio_pika.Connection,
        # Messages
        message_exchange: str = defaults.MESSAGE_EXCHANGE,
        queue_expires: int = defaults.QUEUE_EXPIRES,
        # Tasks
        task_exchange: str = defaults.TASK_EXCHANGE,
        task_queue: str = defaults.TASK_QUEUE,
        task_prefetch_size=defaults.TASK_PREFETCH_SIZE,
        task_prefetch_count=defaults.TASK_PREFETCH_COUNT,
        encoder=defaults.ENCODER,
        decoder=defaults.DECODER,
        testing_mode=False
    ):
        # pylint: disable=too-many-arguments
        """Create a new asynchronous communicator.

        .. note: this communicator takes ownership of the connection and, therefore, it should not be shared as when
            this communicator disconnects it will also hang up the connection.

        :param connection: An aio_pika connection, doesn't need to be connected
        :param message_exchange: The name of the RMQ message exchange to use
        :param queue_expires: the expiry time for standard queues in milliseconds. This is the time after which, if
            there are no subscribers, a queue will automatically be deleted by RabbitMQ.
        :param task_exchange: The name of the RMQ task exchange to use
        :param task_queue: The name of the task queue to use
        :param task_prefetch_count: the number of tasks this communicator can fetch simultaneously
        :param task_prefetch_size: the total size of the messages that the default queue can fetch simultaneously
        :param encoder: The encoder to call for encoding a message
        :param decoder: The decoder to call for decoding a message
        :param testing_mode: Run in testing mode: all queues and exchanges will be temporary
        """
        super().__init__()

        self._connection = connection
        self._loop = connection.loop

        # Save some of these settings for later
        self._message_exchange = message_exchange
        self._queue_expires = queue_expires

        # Default tasks queue
        self._task_exchange = task_exchange
        self._task_queue = task_queue
        self._task_prefetch_size = task_prefetch_size
        self._task_prefetch_count = task_prefetch_count

        # Task queues created through task_queue()
        self._task_queues = []

        self._decoder = decoder
        self._encoder = encoder
        self._testing_mode = testing_mode

    async def __aenter__(self):
        """Connect on entering the async context and return this communicator."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Disconnect on leaving the async context."""
        await self.disconnect()

    def __str__(self):
        return f'RMQCommunicator({self._connection})'

    @property
    def server_properties(self) -> Dict:
        """
        A dictionary containing server properties as returned by the RMQ server at connection time.

        The details are defined by the RMQ standard and can be found here:

        https://www.rabbitmq.com/amqp-0-9-1-reference.html#connection.start.server-properties

        The protocol states that this dictionary SHOULD contain at least:
            'host'          - specifying the server host name or address
            'product'       - giving the name of the server product
            'version'       - giving the name of the server version
            'platform'      - giving the name of the operating system
            'copyright'     - if appropriate, and,
            'information'   - giving other general information

        .. note:: In testing it seems like 'host' is not always returned.  Host information may be found in
            'cluster_name' but clients shouldn't rely on this.

        :return: the server properties dictionary, or an empty dictionary if there is no connection
            or the connection has no transport yet
        """
        if self._connection is None:
            return {}

        transport = self._connection.transport
        if transport is None:
            # No transport has been set on the connection yet, so there is nothing to report
            return {}

        return transport.connection.server_properties

    @property
    def loop(self):
        """Get the event loop instance driving this communicator connection."""
        return self._connection.loop

    def add_close_callback(self, callback: aio_pika.abc.ConnectionCloseCallback, weak: bool = False) -> None:
        """Add a callable to be called each time (after) the connection is closed.

        :param callback: the callable to add to the connection's close callbacks
        :param weak: If True, the callback will be added to a `WeakSet`
        """
        self._connection.close_callbacks.add(callback, weak)

    async def get_default_task_queue(self) -> tasks.RmqTaskQueue:
        """Get a default task queue.

        If one doesn't exist it will be created as part of this call.

        :raises RuntimeError: if the communicator is not connected
        """
        self._ensure_connected()
        if self._default_task_queue is None:
            task_queue = tasks.RmqTaskQueue(
                self._connection,
                exchange_name=self._task_exchange,
                queue_name=self._task_queue,
                decoder=self._decoder,
                encoder=self._encoder,
                prefetch_size=self._task_prefetch_size,
                prefetch_count=self._task_prefetch_count,
                testing_mode=self._testing_mode
            )
            # Only store the queue once it has connected successfully
            await task_queue.connect()
            self._default_task_queue = task_queue

        return self._default_task_queue

    async def get_message_subscriber(self) -> RmqSubscriber:
        """Get the message subscriber.

        If one doesn't exist it will be created as part of this call.

        :raises RuntimeError: if the communicator is not connected
        """
        self._ensure_connected()
        if self._message_subscriber is None:
            subscriber = RmqSubscriber(
                self._connection,
                message_exchange=self._message_exchange,
                queue_expires=self._queue_expires,
                encoder=self._encoder,
                decoder=self._decoder,
                testing_mode=self._testing_mode
            )
            # Only store the subscriber once it has connected successfully
            await subscriber.connect()
            self._message_subscriber = subscriber

        return self._message_subscriber

    async def get_message_publisher(self) -> RmqPublisher:
        """Get a message publisher.

        If one doesn't exist it will be created as part of this call.

        :raises RuntimeError: if the communicator is not connected
        """
        self._ensure_connected()
        if self._message_publisher is None:
            publisher = RmqPublisher(
                self._connection,
                exchange_name=self._message_exchange,
                encoder=self._encoder,
                decoder=self._decoder,
                testing_mode=self._testing_mode
            )
            # Only store the publisher once it has connected successfully
            await publisher.connect()
            self._message_publisher = publisher

        return self._message_publisher

    def connected(self) -> bool:
        """Return True if there is a connection and it is not closed."""
        return self._connection is not None and not self._connection.is_closed

    async def connect(self):
        """Establish a connection if not already connected."""
        if not self.connected():
            await self._connection.connect()

    async def disconnect(self):
        """Disconnect from the connection if connected."""
        if not self.connected():
            return

        if self._message_publisher is not None:
            await self._message_publisher.disconnect()
            self._message_publisher = None
        if self._message_subscriber is not None:
            await self._message_subscriber.disconnect()
            self._message_subscriber = None
        if self._default_task_queue is not None:
            await self._default_task_queue.disconnect()
            self._default_task_queue = None

        await self._connection.close()

    async def add_rpc_subscriber(self, subscriber, identifier=None):
        """Add an RPC subscriber through the message subscriber.

        :param subscriber: the function or coroutine to call for each RPC message
        :param identifier: optional identifier for the subscriber
        :return: the identifier of the added subscriber
        """
        msg_subscriber = await self.get_message_subscriber()
        identifier = await msg_subscriber.add_rpc_subscriber(subscriber, identifier)
        return identifier

    async def remove_rpc_subscriber(self, identifier):
        """Remove the RPC subscriber registered under ``identifier``."""
        msg_subscriber = await self.get_message_subscriber()
        await msg_subscriber.remove_rpc_subscriber(identifier)

    async def add_task_subscriber(self, subscriber, identifier=None):
        """Add a task subscriber to the default task queue.

        :param subscriber: the task subscriber to add
        :param identifier: optional identifier for the subscriber
        :return: the value returned by the default task queue's ``add_task_subscriber``
        """
        default_task_queue = await self.get_default_task_queue()
        return await default_task_queue.add_task_subscriber(subscriber, identifier)

    async def remove_task_subscriber(self, identifier):
        """Remove the task subscriber registered under ``identifier`` from the default task queue."""
        default_task_queue = await self.get_default_task_queue()
        await default_task_queue.remove_task_subscriber(identifier)

    async def add_broadcast_subscriber(self, subscriber, identifier=None):
        """Add a broadcast subscriber through the message subscriber.

        :param subscriber: the function or coroutine to call for each broadcast
        :param identifier: optional identifier for the subscriber
        :return: the identifier of the added subscriber
        """
        msg_subscriber = await self.get_message_subscriber()
        identifier = await msg_subscriber.add_broadcast_subscriber(subscriber, identifier)
        return identifier

    async def remove_broadcast_subscriber(self, identifier):
        """Remove the broadcast subscriber registered under ``identifier``."""
        msg_subscriber = await self.get_message_subscriber()
        await msg_subscriber.remove_broadcast_subscriber(identifier)

    async def rpc_send(self, recipient_id, msg):
        """Initiate a remote procedure call on a recipient.

        :param recipient_id: The recipient identifier
        :param msg: The body of the message
        :return: A future corresponding to the outcome of the call
        :raises kiwipy.UnroutableError: if publishing raised an aio_pika ``DeliveryError``
        """
        try:
            publisher = await self.get_message_publisher()
            response_future = await publisher.rpc_send(recipient_id, msg)
            return response_future
        except aio_pika.exceptions.DeliveryError as exception:
            raise kiwipy.UnroutableError(str(exception)) from exception

    async def broadcast_send(self, body, sender=None, subject=None, correlation_id=None):
        """Send a broadcast message through the message publisher.

        :param body: the body of the broadcast
        :param sender: optional sender of the broadcast
        :param subject: optional subject of the broadcast
        :param correlation_id: optional correlation id of the broadcast
        :return: the result of the publisher's ``broadcast_send``
        """
        publisher = await self.get_message_publisher()
        result = await publisher.broadcast_send(body, sender, subject, correlation_id)
        return result

    async def task_send(self, task, no_reply=False):
        """Send a task to the default task queue.

        :param task: the task to send
        :param no_reply: passed through to the task queue's ``task_send``
        :return: the result of the task queue's ``task_send``
        :raises kiwipy.UnroutableError: if sending raised an aio_pika ``DeliveryError``
        :raises kiwipy.TaskRejected: if sending raised any other aio_pika ``AMQPError``
        """
        try:
            task_queue = await self.get_default_task_queue()
            result = await task_queue.task_send(task, no_reply)
            return result
        except aio_pika.exceptions.DeliveryError as exception:
            raise kiwipy.UnroutableError(str(exception)) from exception
        except aio_pika.exceptions.AMQPError as exception:
            # Find out what the exception is when a nack is generated!
            raise kiwipy.TaskRejected(str(exception)) from exception

    async def task_queue(
        self,
        queue_name: str,
        prefetch_size=defaults.TASK_PREFETCH_SIZE,
        prefetch_count=defaults.TASK_PREFETCH_COUNT
    ) -> tasks.RmqTaskQueue:
        """Create a new task queue.

        The queue uses this communicator's task exchange, encoder, decoder and testing mode.

        :param queue_name: the name of the task queue
        :param prefetch_size: the prefetch size passed to the task queue
        :param prefetch_count: the prefetch count passed to the task queue
        :return: the connected task queue
        """
        queue = tasks.RmqTaskQueue(
            self._connection,
            exchange_name=self._task_exchange,
            queue_name=queue_name,
            decoder=self._decoder,
            encoder=self._encoder,
            prefetch_size=prefetch_size,
            prefetch_count=prefetch_count,
            testing_mode=self._testing_mode
        )
        await queue.connect()
        self._task_queues.append(queue)
        return queue

    def _ensure_connected(self):
        """Raise a RuntimeError if the communicator is not connected."""
        if not self.connected():
            raise RuntimeError(
                'The communicator is not connected, call connect() or use in a context to establish a connection.'
            )
async def async_connect(
    # Connection parameters
    connection_params: Optional[Union[str, dict]] = None,
    connection_factory=aio_pika.connect_robust,
    # Messages
    message_exchange: str = defaults.MESSAGE_EXCHANGE,
    queue_expires: int = defaults.QUEUE_EXPIRES,
    # Tasks
    task_exchange: str = defaults.TASK_EXCHANGE,
    task_queue: str = defaults.TASK_QUEUE,
    task_prefetch_size=defaults.TASK_PREFETCH_SIZE,
    task_prefetch_count=defaults.TASK_PREFETCH_COUNT,
    encoder=defaults.ENCODER,
    decoder=defaults.DECODER,
    testing_mode=False,
) -> RmqCommunicator:
    # pylint: disable=too-many-arguments
    """Convenience method that returns a connected communicator.

    :param connection_params: parameters that will be passed to the connection factory to create the connection.
        A dictionary is unpacked into keyword arguments, any other truthy value is passed as the single
        positional argument, and a falsy value (the default ``None``) results in a call without arguments.
    :param connection_factory: the factory method to open the aio_pika connection with
    :param message_exchange: The name of the RMQ message exchange to use
    :param queue_expires: the expiry time for standard queues in milliseconds. This is the time after which, if there
        are no subscribers, a queue will automatically be deleted by RabbitMQ.
    :param task_exchange: The name of the RMQ task exchange to use
    :param task_queue: The name of the task queue to use
    :param task_prefetch_count: the number of tasks this communicator can fetch simultaneously
    :param task_prefetch_size: the total size of the messages that the default queue can fetch simultaneously
    :param encoder: The encoder to call for encoding a message
    :param decoder: The decoder to call for decoding a message
    :param testing_mode: Run in testing mode: all queues and exchanges will be temporary
    :return: the communicator, after ``connect()`` has been awaited on it
    """
    connection_params = connection_params or {}

    if isinstance(connection_params, dict):
        connection = await connection_factory(**connection_params)
    else:
        connection = await connection_factory(connection_params)

    communicator = RmqCommunicator(
        connection=connection,
        # Messages
        message_exchange=message_exchange,
        queue_expires=queue_expires,
        # Tasks
        task_exchange=task_exchange,
        task_queue=task_queue,
        task_prefetch_size=task_prefetch_size,
        task_prefetch_count=task_prefetch_count,
        encoder=encoder,
        decoder=decoder,
        testing_mode=testing_mode
    )
    await communicator.connect()
    return communicator