"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Copyright © 2019 Cloud Linux Software Inc.
This software is also available under ImunifyAV commercial license,
see
"""
import shutil
import time
from logging import getLogger
from typing import Dict, Optional, Union
from defence360agent.contracts.hook_events import HookEvent
from defence360agent.contracts.messages import MessageType
from defence360agent.contracts.plugins import (
MessageSink,
MessageSource,
expect,
)
from defence360agent.utils import Scope
from imav.malwarelib.config import (
MalwareScanResourceType,
MalwareScanType,
)
from imav.malwarelib.model import MalwareScan as MalwareScanModel
from imav.malwarelib.scan import (
ScanAlreadyCompleteError,
ScanInfoError,
)
from imav.malwarelib.scan.ai_bolit.detached import (
AiBolitDetachedScan,
)
from imav.malwarelib.scan.mds.detached import MDSDetachedScan
from imav.malwarelib.scan.queue_supervisor_sync import QueueSupervisorSync
from imav.malwarelib.scan.scan_result import aggregate_result
from imav.malwarelib.utils.user_list import fill_results_owner
# Module-level logger, named after this module.
logger = getLogger(__name__)
class DetachedScanPlugin(MessageSink, MessageSource):
    """Finish malware scans and emit the ``MalwareScanningFinished`` hook.

    Handles ``MalwareScan`` messages (registered with the
    ``PRE_PROCESS_MESSAGE`` processing order) and ``MalwareScanComplete``
    notifications for detached scans.  A detached scan may deliver its
    results before its summary row exists in the DB; such results are
    parked in ``results_cache`` until the matching summary message arrives.
    """

    PROCESSING_ORDER = MessageSink.ProcessingOrder.PRE_PROCESS_MESSAGE
    SCOPE = Scope.AV
    loop, sink = None, None
    # scan id -> aggregated results that arrived before the scan summary.
    # NOTE(review): this is a class attribute, so the same dict is shared by
    # every instance and by subclasses.
    results_cache = {}  # type: Dict[str, dict]
    # Performance-related summary keys forwarded in the hook's ``stats``.
    _PERF_STAT_KEYS = (
        "scan_time",
        "scan_time_hs",
        "scan_time_preg",
        "smart_time_hs",
        "smart_time_preg",
        "finder_time",
        "cas_time",
        "deobfuscate_time",
        "mem_peak",
    )

    async def create_source(self, loop, sink):
        """Remember the event loop and the sink used to emit messages."""
        self.loop = loop
        self.sink = sink

    async def create_sink(self, loop):
        """No sink-side setup is required."""

    @expect(MessageType.MalwareScan, async_lock=True)
    async def complete_scan(self, message):
        """Route a ``MalwareScan`` message according to its kind.

        Non-detached scans only get ``summary["total_malicious"]`` filled
        in.  For detached scans, a summary-only message (``results`` is
        None) goes to ``_handle_summary``; anything else goes to
        ``_handle_results``.

        Returns the (possibly enriched) message, or None when
        ``_handle_results`` cached the results instead.
        """
        message_type = MalwareScanMessageInfo(message)
        if not message_type.is_detached:
            message["summary"]["total_malicious"] = (
                await self._count_total_malicious(message)
            )
            return message
        if message_type.is_summary:
            return await self._handle_summary(message)
        # message_type.is_result
        return await self._handle_results(message)

    async def _handle_summary(self, message):
        """Complete a detached scan whose results arrived earlier.

        If results for this scan id are cached, attach them, stamp the
        completion time, count malicious hits, drop the scan from the queue
        and fire the finished hook.  Otherwise return the message untouched.
        """
        summary = message["summary"]
        scan_id = summary["scanid"]
        # If summary arrives after results, results are read from cache
        if scan_id in self.results_cache:
            summary["completed"] = time.time()
            message["results"] = self.results_cache.pop(scan_id)
            summary["total_malicious"] = await self._count_total_malicious(
                message
            )
            # NOTE(review): when results were cached, _handle_results has
            # already removed the queued scan, so this lookup presumably
            # finds nothing and the hook receives empty scan params --
            # confirm whether that is intended.
            queued_scan = QueueSupervisorSync.queue.find(scanid=scan_id)
            if queued_scan:
                QueueSupervisorSync.queue.remove(queued_scan)
            await self._call_scan_finished_hook(
                summary, queued_scan.args if queued_scan else {}
            )
        return message

    async def _handle_results(self, message):
        """Finish a detached scan from its results message.

        Results are aggregated first.  If the scan's summary row is not in
        the DB yet:

        * a summary carrying ``path`` or ``error`` is treated as a failed
          scan: ``total_malicious`` is set to 0, the hook fires and the
          message is returned;
        * otherwise the results are cached for ``_handle_summary`` and None
          is returned.

        If the DB row exists, the summary is completed from it, malicious
        hits are counted, the queued scan is dequeued and the hook fires.
        """
        message = await self.aggregate_result(message)
        message_type = MalwareScanMessageInfo(message)
        summary = message["summary"]
        logger.info("Scan stopped")
        queued_scan = QueueSupervisorSync.queue.find(scanid=summary["scanid"])
        scan = message_type.summary_from_db

        if scan is None:
            self._apply_queued_scan(summary, queued_scan)
            if summary.get("path") or summary.get("error"):
                # Scan failed
                summary["total_malicious"] = 0
                # NOTE(review): unlike the other hook calls, the queued
                # scan's args are not forwarded here -- confirm intent.
                await self._call_scan_finished_hook(summary, scan_args={})
                return message
            # Summary is not in DB yet, save results to cache
            self.results_cache[summary["scanid"]] = message["results"]
            # Report an error to Sentry if cache grows
            cache_size = len(self.results_cache)
            if cache_size > 1:
                logger.error("MalwareScan cache size is %d", cache_size)
            # Results are held back until the summary message arrives.
            return None

        summary["scanid"] = scan.scanid
        summary["path"] = scan.path
        summary["started"] = scan.started
        summary["completed"] = time.time()
        if summary.get("total_files") is None:
            summary["total_files"] = scan.total_resources
        summary["type"] = scan.type
        summary.setdefault("error", None)
        summary["total_malicious"] = await self._count_total_malicious(message)
        self._apply_queued_scan(summary, queued_scan)
        await self._call_scan_finished_hook(
            summary, queued_scan.args if queued_scan else {}
        )
        return message

    @staticmethod
    def _apply_queued_scan(summary, queued_scan) -> None:
        """Copy the queued scan's path patterns into *summary* and dequeue it.

        Does nothing when *queued_scan* is falsy (no queue entry was found).
        """
        if not queued_scan:
            return
        summary["file_patterns"] = queued_scan.args["file_patterns"]
        summary["exclude_patterns"] = queued_scan.args["exclude_patterns"]
        QueueSupervisorSync.queue.remove(queued_scan)

    @staticmethod
    async def _count_total_malicious(message) -> int:
        """Count results whose first hit is not marked suspicious.

        Only a literal ``False`` counts (compared with ``is``), so values
        such as ``None`` or ``0`` are not counted.
        """
        return sum(
            1
            for result in message["results"].values()
            if result["hits"][0]["suspicious"] is False
        )

    async def _call_scan_finished_hook(self, summary, scan_args) -> None:
        """Emit ``MalwareScanningFinished`` for *summary*, then recheck the queue.

        *summary* must contain ``scanid``, ``type``, ``path``, ``started``,
        ``total_files`` and ``total_malicious``.  A truthy ``error`` marks
        the scan as failed.  ``stats`` carries the performance metrics found
        in *summary* plus ``total_files``.
        """
        scan_finished = HookEvent.MalwareScanningFinished(
            scan_id=summary["scanid"],
            scan_type=summary["type"],
            path=summary["path"],
            started=summary["started"],
            total_files=summary["total_files"],
            total_malicious=summary["total_malicious"],
            error=summary.get("error"),
            status="failed" if summary.get("error") else "ok",
            scan_params=scan_args,
            stats={
                **{
                    key: value
                    for key, value in summary.items()
                    if key in self._PERF_STAT_KEYS
                },
                "total_files": summary["total_files"],
            },
        )
        await self.sink.process_message(scan_finished)
        await self._recheck_scan_queue()

    @staticmethod
    def _get_detached_scan(
        resource_type: Optional[Union[str, MalwareScanResourceType]], scan_id
    ):
        """Return the detached-scan handler for *scan_id*.

        This plugin only creates AI-Bolit scans; *resource_type* is ignored.
        """
        return AiBolitDetachedScan(scan_id)

    @expect(MessageType.MalwareScanComplete)
    async def complete_detached_scan(self, message):
        """Complete a detached scan and forward its resulting message.

        The detached scan's directory is removed in every case.  If the scan
        is already complete or its info cannot be read, the error is logged
        and nothing is forwarded.
        """
        scan_id = message.get("scan_id")
        resource_type = message.get("resource_type")
        detached_scan = self._get_detached_scan(resource_type, scan_id)
        try:
            scan_message = await detached_scan.complete()
        except ScanAlreadyCompleteError as err:
            # This happens when AV is woken up by AiBolit. See DEF-11078.
            logger.warning(
                "Cannot complete scan %s, assuming it is already complete"
                ":\n%s",
                scan_id,
                err,
            )
            return
        except ScanInfoError as err:
            logger.error(
                "Cannot complete %s scan %s, assuming it was not started:\n%s",
                detached_scan.RESOURCE_TYPE.value,
                scan_id,
                err,
            )
            return
        finally:
            # Runs before the message is forwarded below.
            shutil.rmtree(str(detached_scan.detached_dir), ignore_errors=True)
        await self.sink.process_message(scan_message)

    @classmethod
    async def aggregate_result(cls, message):
        """Aggregate ``message["results"]``, fill in file owners, return *message*."""
        # Resolves to the module-level ``aggregate_result`` function, not
        # this classmethod (class scope is not visible inside methods).
        message["results"] = aggregate_result(message["results"])
        await fill_results_owner(message["results"])
        return message

    async def _recheck_scan_queue(self):
        """Ask the sink to recheck the malware scan queue."""
        await self.sink.process_message(MessageType.MalwareScanQueueRecheck())
class MalwareScanMessageInfo:
    """Inspect a ``MalwareScan`` message.

    Answers whether the scan is a detached one, whether the message carries
    only the summary, and which ``MalwareScan`` DB row matches its scan id.
    """

    def __init__(self, message):
        self.message = message
        # Matching DB row, loaded lazily by ``summary_from_db``.
        self._summary_from_db = None
        self.scan_id = message["summary"]["scanid"]

    @property
    def is_detached(self):
        """True when the summary's scan type is on-demand, background, user,
        or missing entirely."""
        scan_type = self.message["summary"].get("type")
        detached_types = (
            MalwareScanType.ON_DEMAND,
            MalwareScanType.BACKGROUND,
            MalwareScanType.USER,
            None,
        )
        return scan_type in detached_types

    @property
    def is_summary(self):
        """True when the message has no results, i.e. only a summary."""
        results = self.message["results"]
        return results is None

    @property
    def summary_from_db(self):
        """Return the ``MalwareScan`` row with this scan id, or None.

        A found row is cached on the instance.  A miss is not cached, so
        the next access queries the DB again.
        """
        if self._summary_from_db:
            return self._summary_from_db
        rows = (
            MalwareScanModel.select()
            .where(MalwareScanModel.scanid == self.scan_id)
            .limit(1)
        )
        if rows:
            self._summary_from_db = rows[0]
        return self._summary_from_db
class DetachedScanPluginIm360(DetachedScanPlugin):
    """``Scope.IM360`` variant of ``DetachedScanPlugin``.

    Also handles database (MDS) scans alongside AI-Bolit file scans.
    """

    SCOPE = Scope.IM360

    @staticmethod
    def _get_detached_scan(
        resource_type: Optional[Union[str, MalwareScanResourceType]], scan_id
    ):
        """Return an MDS handler for DB resources, AI-Bolit otherwise."""
        is_db_scan = (
            resource_type is not None
            and MalwareScanResourceType(resource_type)
            is MalwareScanResourceType.DB
        )
        scan_cls = MDSDetachedScan if is_db_scan else AiBolitDetachedScan
        return scan_cls(scan_id)

    @expect(MessageType.MalwareDatabaseScan)
    async def complete_scan_db(self, message):
        """Finish a database scan.

        Dequeues the scan if it is queued, emits ``MalwareScanningFinished``
        with its id, type and path, then rechecks the scan queue.
        """
        scan_id = message["scan_id"]
        queued = QueueSupervisorSync.queue.find(scanid=scan_id)
        if queued:
            QueueSupervisorSync.queue.remove(queued)
        await self.sink.process_message(
            HookEvent.MalwareScanningFinished(
                scan_id=scan_id,
                scan_type=message["type"],
                path=message["path"],
            )
        )
        await self._recheck_scan_queue()