import functools
import inspect
from typing import Any, Callable, Dict, Set, Tuple

from defence360agent.contracts.config import UserType
from defence360agent.utils import Scope

from .exceptions import RpcError

# Name of the attribute that ``bind`` sets on a handler function; its value
# is the command tuple the handler serves.
_RPC_MARK = "__rpc_command"


class DuplicateHandlerError(Exception):
    """Raised when a command already has a handler for a given user type."""


class NotCoroutineError(Exception):
    """Raised when a function marked with ``@bind`` is not a coroutine."""


class Endpoints:
    """Endpoints class implements registration and lookup for functions
    implementing RPC calls.

    Subclasses list the user types their handlers are available to in
    ``APPLICABLE_USER_TYPES``; handlers are public coroutine methods
    decorated with ``@bind(...)``.
    """

    SCOPE = Scope.AV_IM360
    # User types (keys of the command map) the handlers of this class are
    # registered for; subclasses override it.
    APPLICABLE_USER_TYPES = set()  # type: Set[str]
    # user type -> {command tuple: (endpoint class, handler method name)}.
    # Name-mangled to _Endpoints__COMMAND_MAP, so one map is shared by every
    # subclass.
    __COMMAND_MAP = {
        UserType.ROOT: {},
        UserType.NON_ROOT: {},
    }  # type: Dict[str, Dict]
    # Every subclass of Endpoints, appended in definition order by
    # __init_subclass__; the list object is shared by the whole hierarchy.
    _subclasses = []

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._subclasses.append(cls)

    @classmethod
    def get_active_endpoints(cls):
        """Return the known subclasses that have at least one RPC handler.

        A member counts as a handler when it carries a truthy RPC mark
        (i.e. it was decorated with ``@bind`` using at least one word).
        """

        def has_rpc_mark(item):
            return getattr(item, _RPC_MARK, None)

        return [
            subcls
            for subcls in cls._subclasses
            if inspect.getmembers(subcls, has_rpc_mark)
        ]

    def __init__(self, sink):
        self._sink = sink

    @classmethod
    async def route_to_endpoint(
        cls, request, sink, user=UserType.ROOT
    ) -> Any:
        """Find appropriate class and function within that class that
        implements processing for request based on supplied 'command' within.
        Call that (async) function and return its result.

        A fresh instance of the endpoint class is created with ``sink`` for
        every call, and ``request["params"]`` is passed as keyword arguments.

        :raises RpcError: no handler is registered for ``request["command"]``
            for this ``user``.
        """
        command = request["command"]
        target = cls.__COMMAND_MAP[user].get(tuple(command))
        if target is None:
            raise RpcError(
                'Endpoint not found for RPC method "%s"'
                % " ".join(request["command"])
            )
        cls_handler, handler_name = target
        handler = getattr(cls_handler(sink), handler_name)
        return await handler(**request["params"])

    @classmethod
    def register_rpc_handlers(cls) -> None:
        """Registers RPC handlers for all functions within a class.

        Functions should be decorated with @bind('command', ...). Every
        public attribute carrying the bind mark is registered for each user
        type in ``APPLICABLE_USER_TYPES``.

        Registration is all-or-nothing: handlers are validated and staged
        first, and merged into the shared command map only if no error was
        raised, so a failure never leaves a half-registered class behind.

        :raises NotCoroutineError: a marked attribute is not a coroutine
            function.
        :raises DuplicateHandlerError: a command already has a handler for
            one of the user types, either from an earlier registration or
            from another method of this class.
        """
        staged = {user_type: {} for user_type in cls.APPLICABLE_USER_TYPES}
        for name in dir(cls):
            if name.startswith("_"):
                continue
            attr = getattr(cls, name)
            command = getattr(attr, _RPC_MARK, None)
            if command is None:
                continue
            if not inspect.iscoroutinefunction(attr):
                raise NotCoroutineError("Must be a coroutine")
            for user_type, pending in staged.items():
                registered = cls.__COMMAND_MAP[user_type]
                if command in registered:
                    previous = registered[command]
                elif command in pending:
                    previous = pending[command]
                else:
                    pending[command] = (cls, name)
                    continue
                msg = (
                    "Duplicate handlers for command {} ({}): {} and {}"
                    .format(
                        command,
                        user_type,
                        previous,
                        attr,
                    )
                )
                raise DuplicateHandlerError(msg)
        # Validation passed for every handler: publish them all at once.
        for user_type, pending in staged.items():
            cls.__COMMAND_MAP[user_type].update(pending)

    @classmethod
    def reset_rpc_handlers(cls):
        """Clears all previously made registrations."""
        for user_type in {UserType.NON_ROOT, UserType.ROOT}:
            cls.__COMMAND_MAP[user_type] = {}


class CommonEndpoints(Endpoints):
    """Endpoints available both for root and non root users."""

    APPLICABLE_USER_TYPES = {UserType.NON_ROOT, UserType.ROOT}


class RootEndpoints(Endpoints):
    """Endpoints available only for root user."""

    APPLICABLE_USER_TYPES = {UserType.ROOT}


class UserOnlyEndpoints(Endpoints):
    """Endpoints available only for non root users."""

    APPLICABLE_USER_TYPES = {UserType.NON_ROOT}


# functools' default attributes plus the RPC mark, so a wrapper produced by
# ``wraps`` keeps the command set by ``@bind`` on the wrapped function.
LOOKUP_ASSIGNMENTS = functools.WRAPPER_ASSIGNMENTS + (_RPC_MARK,)


def wraps(
    wrapped, assigned=LOOKUP_ASSIGNMENTS, updated=functools.WRAPPER_UPDATES
):
    """Decorator replacing functools.wraps for rpc handlers.

    Behaves like ``functools.wraps`` but, by default, also copies the RPC
    mark, so a decorated handler is still found by
    ``Endpoints.register_rpc_handlers``.
    """
    return functools.partial(
        functools.update_wrapper,
        wrapped=wrapped,
        assigned=assigned,
        updated=updated,
    )


def bind(*command):
    """Mark a function as processing RPC calls for command.

    The positional words form the command tuple stored on the function
    under the RPC mark attribute; the function itself is returned unchanged.
    """

    def decorator(func):
        setattr(func, _RPC_MARK, command)
        return func

    return decorator