Source code for kiwipy.rmq.tasks

# -*- coding: utf-8 -*-
import asyncio
import collections
from contextlib import asynccontextmanager
import logging
from typing import Generator, Optional
import uuid
import weakref

import aio_pika
import shortuuid

import kiwipy

from . import defaults, messages, utils

_LOGGER = logging.getLogger(__name__)

__all__ = 'RmqTaskSubscriber', 'RmqTaskPublisher', 'RmqTaskQueue', 'RmqIncomingTask'

TaskInfo = collections.namedtuple('TaskBody', ('task', 'no_reply'))


class RmqTaskSubscriber(messages.BaseConnectionWithExchange):
    """
    Listens for tasks coming in on the RMQ task queue
    """
    TASK_QUEUE_ARGUMENTS = {'x-message-ttl': defaults.TASK_MESSAGE_TTL}

    def __init__(
        self,
        connection: aio_pika.Connection,
        exchange_name: str = defaults.MESSAGE_EXCHANGE,
        queue_name: str = defaults.TASK_QUEUE,
        testing_mode=False,
        decoder=defaults.DECODER,
        encoder=defaults.ENCODER,
        exchange_params=None,
        prefetch_size=defaults.TASK_PREFETCH_SIZE,
        prefetch_count=defaults.TASK_PREFETCH_COUNT
    ):  # pylint: disable=too-many-arguments
        """
        :param connection: An RMQ connection
        :param exchange_name: the name of the exchange to use
        :param queue_name: the name of the task queue to use
        :param decoder: A message decoder
        :param encoder: A response encoder
        """
        super().__init__(
            connection, exchange_name=exchange_name, exchange_params=exchange_params, testing_mode=testing_mode
        )

        self._task_queue_name = queue_name
        self._testing_mode = testing_mode
        self._decode = decoder
        self._encode = encoder
        self._prefetch_size = prefetch_size
        self._prefetch_count = prefetch_count
        self._consumer_tag = None

        self._task_queue = None  # type: Optional[aio_pika.Queue]
        self._subscribers = {}
        self._pending_tasks = []

    async def add_task_subscriber(self, subscriber, identifier=None):
        identifier = identifier or shortuuid.uuid()
        if identifier in self._subscribers:
            raise kiwipy.DuplicateSubscriberIdentifier(f"Task identifier '{identifier}'")

        self._subscribers[identifier] = subscriber
        if self._consumer_tag is None:
            self._consumer_tag = await self._task_queue.consume(self._on_task)

        return identifier

    async def remove_task_subscriber(self, identifier):
        try:
            self._subscribers.pop(identifier)
        except KeyError as exception:
            raise ValueError(f"Unknown task subscriber '{identifier}'") from exception

        if not self._subscribers:
            await self._task_queue.cancel(self._consumer_tag)
            self._consumer_tag = None

    async def connect(self):
        if self.channel():
            # Already connected
            return

        await super().connect()
        await self.channel().set_qos(prefetch_count=self._prefetch_count, prefetch_size=self._prefetch_size)

        await self._create_task_queue()

    async def __aiter__(self):
        tasks = []
        try:
            while True:
                task = RmqIncomingTask(self, await self._task_queue.get(timeout=1.))
                tasks.append(task)
                yield task
        except aio_pika.exceptions.QueueEmpty:
            return
        finally:
            # Put back any tasks that are still pending (i.e. not processed or to be processed)
            for task in tasks:
                if task.state == TASK_PENDING:
                    await task.requeue()

    @asynccontextmanager
    async def next_task(self, no_ack=False, fail=True, timeout=defaults.TASK_FETCH_TIMEOUT) -> Generator['RmqIncomingTask', None, None]:
        """
        Get the next task from the queue.

        :raises kiwipy.exceptions.QueueEmpty: When the queue has no tasks within the timeout
        """
        try:
            message = await self._task_queue.get(no_ack=no_ack, fail=fail, timeout=timeout)
        except aio_pika.exceptions.QueueEmpty as exc:
            raise kiwipy.exceptions.QueueEmpty(str(exc))
        else:
            task = RmqIncomingTask(self, message)
            try:
                yield task
            finally:
                if task.state == TASK_PENDING:
                    await task.requeue()

    async def _create_task_queue(self):
        """Create and bind the task queue"""
        arguments = dict(self.TASK_QUEUE_ARGUMENTS)
        if self._testing_mode:
            arguments['x-expires'] = defaults.TEST_QUEUE_EXPIRES

        # x-expires: how long the queue stays alive after the last client disconnects
        # x-message-ttl: the default TTL for a message arriving in the queue
        self._task_queue = await self._channel.declare_queue(
            name=self._task_queue_name, durable=not self._testing_mode, arguments=arguments
        )
        await self._task_queue.bind(self._exchange, routing_key=self._task_queue.name)

    async def _on_task(self, message: aio_pika.IncomingMessage):
        """
        :param message: The aio_pika RMQ message
        """
        # Decode the message tuple into a task body for easier use
        rmq_task = RmqIncomingTask(self, message)

        async with rmq_task.processing() as outcome:
            for subscriber in self._subscribers.values():
                try:
                    subscriber = utils.ensure_coroutine(subscriber)
                    result = await subscriber(self, rmq_task.body)

                    # If a task returns a future it is not considered done until the chain of
                    # futures (i.e. if the first future resolves to a future and so on) finishes
                    # and produces a concrete result
                    while asyncio.isfuture(result):
                        if not rmq_task.no_reply:
                            await self._send_response(utils.pending_response(), message)
                        result = await result
                except kiwipy.TaskRejected:
                    # Task was rejected by this subscriber, keep trying
                    continue
                except kiwipy.CancelledError:
                    # The subscriber has cancelled their processing of the task
                    outcome.cancel()
                except Exception as exc:  # pylint: disable=broad-except
                    # There was an exception during the processing of this task
                    outcome.set_exception(exc)
                    _LOGGER.exception('Exception occurred while processing task.')
                else:
                    # All good
                    outcome.set_result(result)

                break  # Got handled

    def _build_response_message(self, body, incoming_message):
        """
        Create an aio_pika Message as a response to a task being dealt with.

        :param body: The message body dictionary
        :type body: dict
        :param incoming_message: The original message we are responding to
        :type incoming_message: :class:`aio_pika.IncomingMessage`
        :return: The response message
        :rtype: :class:`aio_pika.Message`
        """
        # Add host info
        body[utils.HOST_KEY] = utils.get_host_info()
        message = aio_pika.Message(body=self._encode(body), correlation_id=incoming_message.correlation_id)

        return message

    async def _send_response(self, msg_body, incoming_message):
        msg = self._build_response_message(msg_body, incoming_message)
        await self._exchange.publish(msg, routing_key=incoming_message.reply_to)
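

# Illustrative sketch (not part of the original kiwipy module): registering a task handler on
# an RmqTaskSubscriber.  The handler body and task format are hypothetical; handlers receive
# the communicator and the decoded task body, and may raise kiwipy.TaskRejected to let another
# subscriber try.
async def _example_subscribe(connection: aio_pika.Connection) -> str:
    subscriber = RmqTaskSubscriber(connection)
    await subscriber.connect()

    async def handler(_communicator, task):
        # Reject anything we don't understand so another subscriber can pick it up
        if not isinstance(task, dict) or 'value' not in task:
            raise kiwipy.TaskRejected('Unknown task format')
        return task['value'] * 2

    # Returns the identifier that can later be passed to remove_task_subscriber()
    return await subscriber.add_task_subscriber(handler)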


TASK_PENDING = 'pending'
TASK_FINISHED = 'finished'
TASK_PROCESSING = 'processing'
TASK_REQUEUED = 'requeued'


class RmqIncomingTask:

    def __init__(self, subscriber: RmqTaskSubscriber, message: aio_pika.IncomingMessage):
        self._subscriber = subscriber
        self._message = message
        self._task_info = TaskInfo(*subscriber._decode(message.body))
        self._state = TASK_PENDING
        self._outcome_ref = None  # type: Optional[weakref.ReferenceType]
        self._loop = self._subscriber.loop()

    @property
    def body(self) -> str:
        return self._task_info.task

    @property
    def no_reply(self) -> bool:
        return self._task_info.no_reply

    @property
    def state(self) -> str:
        return self._state

    def process(self) -> asyncio.Future:
        if self._state != TASK_PENDING:
            raise asyncio.InvalidStateError(f'The task is {self._state}')

        self._state = TASK_PROCESSING
        outcome = self._loop.create_future()
        # Rely on the done callback to signal the end of processing
        outcome.add_done_callback(self._on_task_done)
        # Or the user lets the future get destroyed
        self._outcome_ref = weakref.ref(outcome, self._outcome_destroyed)

        return outcome

    async def requeue(self):
        if self._state not in [TASK_PENDING, TASK_PROCESSING]:
            raise asyncio.InvalidStateError(f'The task is {self._state}')

        self._state = TASK_REQUEUED
        await self._message.nack(requeue=True)
        self._finalise()

    @asynccontextmanager
    async def processing(self) -> Generator[asyncio.Future, None, None]:
        """Processing context.  The task should be done by the end; otherwise it is assumed that
        the caller does not want to process it and it is sent back to the queue."""
        if self._state != TASK_PENDING:
            raise asyncio.InvalidStateError(f'The task is {self._state}')

        self._state = TASK_PROCESSING
        outcome = self._loop.create_future()
        try:
            yield outcome
        except KeyboardInterrupt:  # pylint: disable=try-except-raise
            raise
        except Exception as exc:
            # Set the exception on the task and re-raise so the client also sees it
            outcome.set_exception(exc)
            raise
        finally:
            if outcome.done():
                await self._task_done(outcome)
            else:
                await self.requeue()

    def _on_task_done(self, outcome):
        """Schedule a task to call ``_task_done`` when the outcome is done."""
        self._loop.create_task(self._task_done(outcome))

    async def _task_done(self, outcome: asyncio.Future):
        assert outcome.done()
        self._outcome_ref = None

        if outcome.cancelled():
            # Whoever took the task decided not to process it
            self._state = TASK_PENDING
        else:
            # Task is done or excepted; permanently store the outcome
            self._state = TASK_FINISHED
            await self._message.ack()

            # We have to get the result from the future here (even if not replying), otherwise
            # python complains that it was never retrieved in case of exception
            try:
                reply_body = utils.result_response(outcome.result())
            except Exception as exc:  # pylint: disable=broad-except
                reply_body = utils.exception_response(exc)

            if not self.no_reply:
                # Send the appropriate response
                # pylint: disable=protected-access
                await self._subscriber._send_response(reply_body, self._message)

            # Clean up
            self._finalise()

    def _outcome_destroyed(self, outcome_ref):
        # This only happens if someone called self.process() and then let the future
        # get destroyed without setting an outcome
        assert outcome_ref is self._outcome_ref

        # This task will not be processed
        self._outcome_ref = None
        asyncio.run_coroutine_threadsafe(self.requeue(), loop=self._loop)

    def _finalise(self):
        self._outcome_ref = None
        self._subscriber = None
        self._message = None
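

# Illustrative sketch (not part of the original kiwipy module): resolving an RmqIncomingTask by
# hand with process().  The outcome future must be resolved (or cancelled) by the caller;
# letting it be garbage collected requeues the task.
async def _example_process(task: RmqIncomingTask):
    if task.state != TASK_PENDING:
        return
    outcome = task.process()  # marks the task as processing and returns its outcome future
    # Resolving the future acks the message and, unless no_reply was set, sends the result
    # back to the publisher.  Real handlers would do actual work on task.body here.
    outcome.set_result(task.body)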


class RmqTaskPublisher(messages.BasePublisherWithReplyQueue):
    """
    Publishes messages to the RMQ task queue and gets the response
    """

    def __init__(
        self,
        connection,
        queue_name=defaults.TASK_QUEUE,
        exchange_name=defaults.MESSAGE_EXCHANGE,
        exchange_params=None,
        encoder=defaults.ENCODER,
        decoder=defaults.DECODER,
        confirm_deliveries=True,
        testing_mode=False
    ):  # pylint: disable=too-many-arguments
        super().__init__(
            connection,
            exchange_name=exchange_name,
            exchange_params=exchange_params,
            encoder=encoder,
            decoder=decoder,
            confirm_deliveries=confirm_deliveries,
            testing_mode=testing_mode
        )
        self._task_queue_name = queue_name

    async def task_send(self, task, no_reply: bool = False) -> asyncio.Future:
        """Send a task for processing by a task subscriber.

        All task messages will be set to be persistent by setting `delivery_mode=2`.

        :param task: The task payload
        :param no_reply: Don't send a reply containing the result of the task
        :return: A future representing the result of the task
        """
        _LOGGER.debug(
            'Sending task with routing key %r to RMQ queue %r (reply=%r): %r',
            self._task_queue_name,
            self._reply_queue.name,
            not no_reply,
            task,
        )
        # Build the full message body and encode as a tuple
        body = self._encode((task, no_reply))
        # Now build up the full aio_pika message
        task_msg = aio_pika.Message(
            body=body,
            correlation_id=str(uuid.uuid4()),
            reply_to=self._reply_queue.name,
            delivery_mode=aio_pika.DeliveryMode.PERSISTENT  # Task messages need to be persistent
        )

        result_future = None
        if no_reply:
            published = await self.publish(task_msg, routing_key=self._task_queue_name, mandatory=True)
        else:
            published, result_future = await self.publish_expect_response(
                task_msg, routing_key=self._task_queue_name, mandatory=True
            )

        assert published, 'The task was not published to the exchange'
        return result_future
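

# Illustrative sketch (not part of the original kiwipy module): sending tasks with an
# RmqTaskPublisher.  The payloads are hypothetical examples; anything the encoder can handle works.
async def _example_publish(connection: aio_pika.Connection):
    publisher = RmqTaskPublisher(connection)
    await publisher.connect()

    # With a reply: task_send resolves to a future for the eventual result
    result_future = await publisher.task_send({'value': 5})
    result = await result_future

    # Fire-and-forget: no reply is expected, so task_send returns None
    await publisher.task_send({'log': 'done'}, no_reply=True)
    return result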


class RmqTaskQueue:
    """Combines a task publisher and subscriber to create a work queue where you can do both"""

    def __init__(
        self,
        connection,
        exchange_name=defaults.MESSAGE_EXCHANGE,
        queue_name=defaults.TASK_QUEUE,
        decoder=defaults.DECODER,
        encoder=defaults.ENCODER,
        exchange_params=None,
        prefetch_size=defaults.TASK_PREFETCH_SIZE,
        prefetch_count=defaults.TASK_PREFETCH_COUNT,
        testing_mode=False
    ):  # pylint: disable=too-many-arguments
        self._publisher = RmqTaskPublisher(
            connection,
            exchange_name=exchange_name,
            exchange_params=exchange_params,
            queue_name=queue_name,
            decoder=decoder,
            encoder=encoder,
            testing_mode=testing_mode
        )
        self._subscriber = RmqTaskSubscriber(
            connection,
            exchange_name=exchange_name,
            exchange_params=exchange_params,
            queue_name=queue_name,
            decoder=decoder,
            encoder=encoder,
            prefetch_size=prefetch_size,
            prefetch_count=prefetch_count,
            testing_mode=testing_mode
        )

    async def __aiter__(self):
        # Have to do it this way rather than the more convenient 'yield from' style because
        # python doesn't support it for coroutines.  See:
        # https://stackoverflow.com/questions/47376408/why-cant-i-yield-from-inside-an-async-function
        async for task in self._subscriber:
            yield task

    async def task_send(self, task, no_reply: bool = False):
        """Send a task to the queue"""
        return await self._publisher.task_send(task, no_reply)

    async def add_task_subscriber(self, subscriber, identifier=None):
        return await self._subscriber.add_task_subscriber(subscriber, identifier)

    async def remove_task_subscriber(self, identifier):
        return await self._subscriber.remove_task_subscriber(identifier)

    @asynccontextmanager
    async def next_task(self, no_ack=False, fail=True, timeout=defaults.TASK_FETCH_TIMEOUT):
        # pylint: disable=not-async-context-manager
        async with self._subscriber.next_task(no_ack=no_ack, fail=fail, timeout=timeout) as task:
            yield task

    async def connect(self):
        await self._subscriber.connect()
        await self._publisher.connect()

    async def disconnect(self):
        await self._subscriber.disconnect()
        await self._publisher.disconnect()
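

# Illustrative sketch (not part of the original kiwipy module): a round trip through
# RmqTaskQueue, pulling a task with next_task() and resolving it via the processing() context.
# The payload and the doubling of its value are hypothetical placeholders.
async def _example_work_queue(connection: aio_pika.Connection):
    queue = RmqTaskQueue(connection)
    await queue.connect()

    # Producer side: returns a future for the eventual result
    future = await queue.task_send({'value': 5})

    # Consumer side: take one task off the queue and resolve it
    try:
        async with queue.next_task(timeout=5.) as task:
            async with task.processing() as outcome:
                outcome.set_result(task.body['value'] * 2)
    except kiwipy.exceptions.QueueEmpty:
        pass  # nothing was waiting within the timeout

    return await future  # resolves to 10 once the reply arrives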