import asyncio
import collections
import contextlib
import os
from logging import getLogger

from defence360agent.api import inactivity
from defence360agent.contracts.messages import MessageType, Splittable
from defence360agent.contracts.plugins import (
    MessageSink,
    MessageSource,
    expect,
)
from defence360agent.utils import recurring_check

logger = getLogger(__name__)


class Accumulate(MessageSink, MessageSource):
    """Buffer :class:`MessageType.Accumulatable` messages and periodically
    re-emit them, grouped into their ``LIST_CLASS`` container messages.

    Acts as both a sink (collects individual messages) and a source
    (emits the aggregated list messages into the next sink) so that
    downstream consumers receive batches instead of a message per event.
    """

    PROCESSING_ORDER = MessageSink.ProcessingOrder.POST_PROCESS_MESSAGE
    # Seconds between periodic flushes; 0 disables the recurring task.
    DEFAULT_AGGREGATE_TIMEOUT = int(
        os.environ.get("IMUNIFY360_AGGREGATE_MESSAGES_TIMEOUT", 60)
    )
    # Upper bound on how long shutdown() waits for the final flush.
    SHUTDOWN_SEND_TIMEOUT = int(
        os.environ.get("IMUNIFY360_AGGREGATE_SHUTDOWN_SEND_TIMEOUT", 50)
    )

    def __init__(
        self,
        period=DEFAULT_AGGREGATE_TIMEOUT,
        shutdown_timeout=SHUTDOWN_SEND_TIMEOUT,
        **kwargs,
    ):
        """Create the accumulator.

        :param period: flush interval in seconds; ``0`` means no
            recurring flush task is created.
        :param shutdown_timeout: max seconds to wait for the final
            flush during :meth:`shutdown`.
        """
        super().__init__(**kwargs)
        self._period = period
        self._shutdown_timeout = shutdown_timeout
        # message LIST_CLASS -> list of pending messages of that kind
        self._data = collections.defaultdict(list)
        # FIX: initialize here so stop()/shutdown() are safe even when
        # create_source() was never called (previously -> AttributeError).
        self._task = None

    async def create_source(self, loop, sink):
        """Start the recurring flush task (unless period == 0)."""
        self._loop = loop
        self._sink = sink
        self._task = (
            None
            if self._period == 0
            else loop.create_task(
                recurring_check(self._period)(self._flush)()
            )
        )

    async def create_sink(self, loop):
        self._loop = loop

    async def shutdown(self):
        """Flush pending messages, bounded by ``shutdown_timeout``.

        On timeout the recurring task is cancelled outright so shutdown
        cannot hang on a slow/unreachable server.
        """
        try:
            await asyncio.wait_for(self.stop(), self._shutdown_timeout)
        except asyncio.TimeoutError:
            # Used logger.error to notify sentry
            logger.error(
                "Timeout (%ss) sending messages to server on shutdown.",
                self._shutdown_timeout,
            )
            if self._task is not None:
                self._task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await self._task

    async def stop(self):
        """Cancel the recurring flush task and send whatever is pending."""
        logger.info("Accumulate.stop cancel _task")
        if self._task is not None:
            self._task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._task
        logger.info("Accumulate.stop wait lock")
        # send pending messages
        await self._flush()

    @expect(MessageType.Accumulatable)
    async def collect(self, message):
        """Buffer *message* under each of its ``LIST_CLASS`` keys.

        ``LIST_CLASS`` may be a single class or a tuple of classes; the
        message is appended to every matching bucket.
        """
        list_types = (
            message.LIST_CLASS
            if isinstance(message.LIST_CLASS, tuple)
            else (message.LIST_CLASS,)
        )
        if message.do_accumulate():
            with inactivity.track.task("accumulate"):
                for list_type in list_types:
                    self._data[list_type].append(message)

    async def _flush(self):
        """Emit all buffered messages as list-message batches.

        Swaps the buffer first so messages collected while sending are
        not lost and will go out on the next flush.
        """
        copy_data = self._data
        self._data = collections.defaultdict(list)
        for list_type, messages in copy_data.items():
            # Splittable list types control their own batch sizing;
            # everything else goes out as a single batch.
            batched = (
                list_type.batched(messages)
                if issubclass(list_type, Splittable)
                else (messages,)
            )
            for batch in batched:
                # FIX: lazy %-style args instead of an eager f-string
                # (same output, no formatting cost when INFO is off).
                logger.info(
                    "Prepare %s() for further processing",
                    list_type.__name__,
                )
                try:
                    # FIXME: remove this try..except block after
                    # we have forbidden to create Accumulatable class
                    # without LIST_CLASS.
                    await self._sink.process_message(list_type(items=batch))
                except TypeError:
                    logger.error("%s, %s", list_type, batch)
                    raise