import logging
import warnings
from functools import wraps

from defence360agent.contracts import eula
from defence360agent.contracts.config import Core, UserType
from defence360agent.contracts.license import LicenseCLN
from defence360agent.contracts.messages import MessageType

logger = logging.getLogger(__name__)


def _ensure_dict(result):
    """Assert that a decorated handler returned a dict, and return it.

    The value is wrapped in a 1-tuple for ``%`` formatting so that a tuple
    result yields the intended AssertionError message instead of a
    TypeError raised while building that message.
    """
    assert isinstance(result, dict), (
        "Result should be a dictionary %s" % (result,)
    )
    return result


def add_license(f):
    """Store ``LicenseCLN.license_info()`` under ``"license"`` in *f*'s result.

    *f* must be a coroutine function returning a dict.
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        result = _ensure_dict(await f(*args, **kwargs))
        result["license"] = LicenseCLN.license_info()
        return result

    return wrapper


def add_license_user(f):
    """Store a reduced license summary under ``"license"`` in *f*'s result.

    Only ``status`` (required) and ``license_type`` (``None`` if absent)
    are copied from ``LicenseCLN.license_info()``.
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        result = _ensure_dict(await f(*args, **kwargs))
        license_info = LicenseCLN.license_info()
        result["license"] = {
            "status": license_info["status"],
            "license_type": license_info.get("license_type"),
        }
        return result

    return wrapper


async def _pending_eula():
    """Return the EULA payload to show the user, or ``None`` if not needed."""
    # do not show eula if not registered or using free AV version
    if not LicenseCLN.is_valid() or LicenseCLN.is_free():
        return None
    if await eula.is_accepted():
        return None
    try:
        return {
            "message": eula.message(),
            "text": eula.text(),
            "updated": eula.updated(),
        }
    except OSError as e:
        # The EULA could not be read; report the failure in-band instead
        # of failing the whole request.
        return {
            "message": "Failed to read EULA",
            "text": "Failed to read EULA: {}".format(str(e)),
            "updated": "",
        }


def add_eula(f):
    """Store the not-yet-accepted EULA (or ``None``) under ``"eula"``.

    The EULA is only reported for a valid, non-free license that has not
    been accepted yet.
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        result = _ensure_dict(await f(*args, **kwargs))
        result["eula"] = await _pending_eula()
        return result

    return wrapper


def add_version(f):
    """Store ``Core.VERSION`` under ``"version"`` in *f*'s dict result."""

    @wraps(f)
    async def wrapper(*args, **kwargs):
        result = _ensure_dict(await f(*args, **kwargs))
        result["version"] = Core.VERSION
        return result

    return wrapper


def max_count(f):
    """Turn *f*'s ``(count, items)`` result into a ``max_count``/``items`` dict."""

    @wraps(f)
    async def wrapper(*args, **kwargs):
        total, items = await f(*args, **kwargs)
        return {"max_count": total, "items": items}

    return wrapper


def counts(f):
    """Turn *f*'s ``(max_count, counts, items)`` result into a dict."""

    @wraps(f)
    async def wrapper(*args, **kwargs):
        total, per_group, items = await f(*args, **kwargs)
        return {"max_count": total, "counts": per_group, "items": items}

    return wrapper


def collect_warnings(f):
    """Record warnings raised while *f* runs and store them under ``"warnings"``.

    Each warning is rendered as its message args joined by a space.
    DeprecationWarnings are always recorded, even if already seen.
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        with warnings.catch_warnings(record=True) as warns:
            # Set the filter inside the context so it is rolled back on exit
            # instead of permanently changing the process-wide filters.
            warnings.simplefilter("always", DeprecationWarning)
            # NOTE(review): catch_warnings swaps process-global state, so
            # warnings from other tasks running during this await may also
            # be captured.
            result = await f(*args, **kwargs)
        result["warnings"] = [
            " ".join(str(arg) for arg in w.message.args) for w in warns
        ]
        return result

    return wrapper


def default_to_items(f):
    """Wrap a non-dict result of *f* as ``{"items": result}``.

    Needed only for backward compatibility; dict results pass through.
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        result = await f(*args, **kwargs)
        if not isinstance(result, dict):
            result = {"items": result}
        return result

    return wrapper


def preserve_remote_addr(f):
    """Copy ``request["params"]["remote_addr"]`` to ``request["client_addr"]``.

    This is needed because the send_command_invoke_message middleware
    removes ``remote_addr`` from the request params. Use it on endpoints
    whose logic needs the caller's address.

    :param f: coroutine handler taking ``request`` as its first argument
    :return: the wrapped handler
    """

    @wraps(f)
    async def wrapper(request, *args, **kwargs):
        request["client_addr"] = request["params"].get("remote_addr")
        return await f(request, *args, **kwargs)

    return wrapper


def _invoke_params(params, args, kwargs):
    """Return a sanitised copy of request *params* for a CommandInvoke message.

    Adds ``user=True`` for non-root callers (unless ``user`` is already set)
    and masks a top-level ``password`` value.
    """
    params = dict(params)
    if "user" not in params:
        # find user type (root/non-root) to determine access rights
        user_type = args[1] if len(args) > 1 else kwargs.get("user")
        if user_type == UserType.NON_ROOT:
            params["user"] = True
    # don't send passwords
    if "password" in params:
        params["password"] = "***"
    return params


def send_command_invoke_message(coro):
    """Emit a ``CommandInvoke`` message to the sink before running *coro*.

    The sink is the first positional argument after *request*, or the
    ``sink`` keyword argument. If no sink is given, no message is sent.
    In all cases ``remote_addr`` is removed from ``request["params"]``
    before *coro* is called.
    """

    @wraps(coro)
    async def wrapper(request, *args, **kwargs):
        sink = args[0] if args else kwargs.get("sink")
        if sink is not None:
            params = _invoke_params(request["params"], args, kwargs)
            await sink.process_message(
                MessageType.CommandInvoke(
                    command=request["command"],
                    params=params,
                    calling_process=request.pop("calling_process", None),
                )
            )
        request["params"].pop("remote_addr", None)
        return await coro(request, *args, **kwargs)

    return wrapper