"""
Simple unix socket RPC server implementation
"""
import asyncio
import functools
import io
import json
import os
import select
import socket
import sys
import time
from contextlib import suppress
from logging import getLogger
from typing import Sequence

from psutil import Process
import sentry_sdk

from defence360agent.api import inactivity
from defence360agent.application import app
from defence360agent.contracts.config import SimpleRpc as Config
from defence360agent.feature_management.exceptions import (
    FeatureManagementError,
)
from defence360agent.internals.auth_protocol import UnixSocketAuthProtocol
from defence360agent.model import tls_check
from defence360agent.model.simplification import run_in_executor
from defence360agent.utils import is_root_user
from defence360agent.utils.buffer import LineBuffer
from defence360agent.subsys.panels import hosting_panel
from defence360agent.subsys.panels.base import InvalidTokenException
from defence360agent.rpc_tools.exceptions import (
    ResponseError,
    ServiceStateError,
    SocketError,
)
from defence360agent.rpc_tools.lookup import Endpoints, UserType
from defence360agent.rpc_tools.utils import is_running, rpc_is_running
from defence360agent.rpc_tools.validate import ValidationError
from defence360agent.rpc_tools import ERROR, SUCCESS, WARNING

logger = getLogger(__name__)


class RpcServiceState:
    """Values accepted by ``RpcClient(require_svc_is_running=...)``."""

    # If need DB and agent should be running
    # e.g. on-demand scan
    RUNNING = "running"
    # Agent should be stopped
    STOPPED = "stopped"
    # It doesn't matter for operation running or stopping the agent
    # if agent is running - using socket, instead of direct communication
    ANY = "any"
    # No need DB and UI interaction
    # preferable for use direct instead any for execution external process
    # e.g. enable/disable plugins/features
    DIRECT = "direct"


async def _execute_request(coro, method):
    """Await *coro* and convert its outcome into an RPC response dict.

    :param coro: awaitable produced by the (middleware-wrapped) endpoint
    :param method: command name, used only for logging
    :return: dict with "result" (SUCCESS, WARNING or ERROR) and "messages";
        on success also "data" holding the awaited value
    """
    try:
        result = await coro
    except ValidationError as e:
        result = {
            "result": WARNING,
            "messages": e.errors,
        }
        result.update(e.extra_data)
        return result
    except (PermissionError, FeatureManagementError) as e:
        try:
            # the usual shape is ("template with %s", arg1, ...)
            msg, *args = e.args
            message = msg % tuple(args)
        except (TypeError, ValueError):
            # e.g. PermissionError(errno, strerror) or an exception raised
            # without arguments: there is no template to render, and failing
            # here would turn an expected error into an unhandled one
            message = str(e)
            logger.error("%s", message)
        else:
            logger.error(msg, *args)
        return {
            "result": ERROR,
            "messages": [message],
        }
    except Exception as e:
        sentry_sdk.capture_exception(e)
        logger.exception(
            "Something went wrong while processing %s (%s)", method, str(e)
        )
        return {"result": ERROR, "messages": str(e)}
    else:
        return {"result": SUCCESS, "messages": [], "data": result}


def _apply_middleware(method, user):
    """Return the endpoint router wrapped in the middleware for *method*.

    Middleware is looked up only when *method* is a list/tuple command;
    for any other *method* the bare router is returned.

    :param method: command, either a sequence of words or a plain name
    :param user: user type the request is executed for
    :return: callable that routes a request to its endpoint
    """
    cb = Endpoints.route_to_endpoint
    if isinstance(method, (list, tuple)):
        hashable = tuple(method)
        common = app.MIDDLEWARE.get(None, [])
        specific = app.MIDDLEWARE.get(hashable, [])
        excluded = app.MIDDLEWARE_EXCLUDE.get(hashable, [])
        # wrap in reverse so that the first registered middleware ends up
        # outermost
        for mw, users in reversed(common + specific):
            if (user in users) and (mw not in excluded):
                logger.debug("Applying middleware %s", mw.__name__)
                cb = mw(cb)
    return cb


def _find_uds_inodes(socket_path: str) -> Sequence[str]:
    """Find inodes corresponding to the unix domain socket path."""
    with open(
        "/proc/net/unix",
        encoding=sys.getfilesystemencoding(),
        errors=sys.getfilesystemencodeerrors(),
    ) as file:
        # the inode is the second to last column, the path is the last one
        return [line.split()[-2] for line in file if socket_path in line]


class _RpcServerProtocol(UnixSocketAuthProtocol):
    """Server side protocol: newline-delimited JSON requests/responses."""

    def __init__(self, loop, sink, user):
        self._loop = loop
        self._sink = sink
        self.user = user
        self._transport = None
        self._buf = LineBuffer()
        # strong references to in-flight dispatch tasks; the event loop
        # itself keeps only weak references to tasks
        self._tasks = set()

    def preprocess_data(self, data: str):
        """Decode one request, authenticate it and annotate it.

        :param data: one JSON encoded request
        :return: decoded request with "calling_process" added and, when the
            panel reports a user name, params["user"] set to it
        """
        decoded = json.loads(data)
        user_type, user_name = hosting_panel.HostingPanel().authenticate(
            self, decoded
        )
        self.user = user_type
        if user_name is not None:
            decoded["params"]["user"] = user_name
        # add calling process
        try:
            calling_process = Process(self._pid).cmdline()
        except Exception as e:
            # best effort: report the reason instead of the command line
            calling_process = [str(e)]
        decoded["calling_process"] = calling_process
        return decoded

    def data_received(self, data):
        """Buffer incoming bytes and schedule every complete request."""
        self._buf.append(data.decode())
        for msg in self._buf:
            try:
                result = self.preprocess_data(msg)
                method = result["command"]
                params = result["params"]
                logger.debug("Data received: %r", data)
                cb = _apply_middleware(method, self.user)
                # TODO: fix that there is no json flag in params
                task = self._loop.create_task(
                    self._dispatch(
                        method, params, cb(result, self._sink, self.user)
                    )
                )
                # keep the task alive until it is done, otherwise it may be
                # garbage collected before the response is written
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)
            except InvalidTokenException as e:
                # without events in Sentry
                logger.warning("Incorrect token provided")
                self._write_response({"result": ERROR, "messages": str(e)})
            except Exception as e:
                logger.exception(
                    "Something went wrong before processing %s", data.decode()
                )
                self._write_response({"result": ERROR, "messages": str(e)})

    async def _dispatch(self, method, params, coro):
        """Execute the request and write its response to the client."""
        with inactivity.track.task("rpc_{}".format(method)):
            # route and save result to 'result'
            response = await _execute_request(coro, method)
            logger.info(
                "Response: method - {}, data - {}".format(method, response)
            )
            self._write_response(response)

    def connection_lost(self, transport):
        """Forget the transport so that late responses are dropped.

        NOTE: asyncio passes the exception (or None) as the argument here;
        the parameter name is kept as is and the value is not used.
        """
        self._transport = None

    def _write_response(self, data):
        """Send *data* as one JSON line; only log if that is not possible."""
        if self._transport is None:
            logger.warning("Cannot send RPC response: connection lost.")
            return
        else:
            try:
                self._transport.write((json.dumps(data) + "\n").encode())
            except Exception as e:
                logger.exception(e)
                # TODO: need to own message error


def _check_socket_folder_permissions(socket_path):
    """Create the directory of *socket_path* if needed and chmod it 0o755."""
    dir_name = os.path.dirname(socket_path)
    os.makedirs(dir_name, exist_ok=True)
    os.chmod(dir_name, 0o755)


class RpcServer:
    """RPC server that creates and owns its listening unix socket."""

    SOCKET_PATH = Config.SOCKET_PATH
    USER = UserType.ROOT
    SOCKET_MODE = 0o700

    @classmethod
    async def create(cls, loop, sink):
        """Bind a fresh socket at SOCKET_PATH and start serving on it.

        :param loop: event loop to create the server on
        :param sink: passed through to every protocol instance
        :return: the server returned by ``loop.create_unix_server``
        """
        _check_socket_folder_permissions(cls.SOCKET_PATH)
        # remove a stale socket; EAFP avoids the exists()/unlink() race and
        # also removes a dangling symlink, which exists() reports as missing
        with suppress(FileNotFoundError):
            os.unlink(cls.SOCKET_PATH)
        server = await loop.create_unix_server(
            lambda: _RpcServerProtocol(loop, sink, cls.USER), cls.SOCKET_PATH
        )
        os.chmod(cls.SOCKET_PATH, cls.SOCKET_MODE)
        return server


class RpcServerAV:
    """RPC server that serves on a socket this process already has open."""

    USER = UserType.ROOT
    SOCKET_PATH = Config.SOCKET_PATH
    PROTOCOL_CLASS = _RpcServerProtocol

    @classmethod
    async def create(cls, loop, sink):
        """Looking for socket in /proc/net/unix and check which descriptor
        corresponded to it by comparing inode

        $ ls -l /proc/[pid]/fd
        lrwx------ 1 root root 64 Apr 11 07:20 4 -> socket:[2866765]

        $ cat /proc/net/unix
        Num       RefCount Protocol Flags    Type St Inode Path
        ffff880054c0a4c0: 00000002 00000000 00010000 0001 01 2866765 /var/run/defence360agent/simple_rpc.sock  # noqa

        :raises SocketError: if no descriptor of this process matches
        """

        def safe_readlink(*args, **kwargs):
            """Return empty path on error."""
            with suppress(OSError):
                return os.readlink(*args, **kwargs)
            return ""

        # find inodes for the SOCKET_PATH
        _socket_path = cls.SOCKET_PATH
        _check_socket_folder_permissions(_socket_path)
        if _socket_path.startswith("/var/run"):
            # remove /var prefix, see DEF-16201
            _socket_path = _socket_path[len("/var") :]
        inodes = _find_uds_inodes(_socket_path)

        # find socket fds corresponding to the inodes
        last_error = None
        for inode in inodes:
            try:
                with os.scandir("/proc/self/fd") as it:
                    for fd in it:
                        if safe_readlink(fd.path) == "socket:[{}]".format(
                            inode
                        ):
                            socket_fd = int(fd.name)
                            break  # found fd
                    else:  # no break, not found fd for given inode
                        continue  # try another inode
                    break  # found fd
            except OSError as e:
                last_error = e
        else:  # no break, not found
            raise SocketError(
                "[{}] Socket {!r} for {} not found.".format(
                    "inode" * (not inodes), cls.SOCKET_PATH, cls.USER
                )
            ) from last_error

        _socket = socket.fromfd(
            socket_fd,
            socket.AF_UNIX,
            socket.SOCK_STREAM | socket.SOCK_NONBLOCK,
        )
        server = await loop.create_unix_server(
            lambda: cls.PROTOCOL_CLASS(loop, sink, cls.USER), sock=_socket
        )
        return server


class NonRootRpcServerAV(RpcServerAV):
    """RpcServerAV on the non-root socket path."""

    USER = UserType.NON_ROOT
    SOCKET_PATH = Config.NON_ROOT_SOCKET_PATH


class NonRootRpcServer(RpcServer):
    """RpcServer on the non-root socket path, socket mode 0o777."""

    SOCKET_PATH = Config.NON_ROOT_SOCKET_PATH
    USER = UserType.NON_ROOT
    SOCKET_MODE = 0o777


class _RpcClientImpl:
    """Client that talks to the RPC server over its unix socket."""

    def __init__(self, socket_path):
        """Connect to *socket_path*.

        :raises ServiceStateError: if the connection is refused, the socket
            file does not exist or the connect would block
        """
        sock = socket.socket(
            socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK
        )
        try:
            sock.connect(socket_path)
        except (
            ConnectionRefusedError,
            FileNotFoundError,
            BlockingIOError,
        ) as e:
            # do not leak a descriptor per failed (re)connect attempt
            sock.close()
            raise ServiceStateError() from e
        except Exception:
            sock.close()
            raise
        self._sock = sock

    def dispatch(self, method, params):
        """Send one request and return the decoded response.

        :param method: command to execute
        :param params: dict of command parameters
        :return: response decoded from JSON
        :raises ResponseError: if the connection is reset or the response is
            not valid JSON
        :raises SocketError: on timeout or when the peer closes the socket
        """
        self._sock.sendall(
            (json.dumps({"command": method, "params": params}) + "\n").encode()
        )
        try:
            data = self._sock_recv_until(terminator_byte=b"\n")
        except ConnectionResetError as e:
            raise ResponseError("Connection reset: {}".format(e)) from e
        try:
            response = json.loads(data.decode())
        except Exception as e:
            raise ResponseError(
                "Error parsing RPC response {!r}".format(data)
            ) from e
        return response

    def _sock_recv_until(self, terminator_byte):
        """Read until the last received chunk contains *terminator_byte*.

        :return: all received bytes joined together
        :raises SocketError: if select() times out or reports the socket in
            another list, or if recv() returns no data
        """
        assert not self._sock.getblocking()
        sock_fd = self._sock.fileno()
        chunks = []
        while (not chunks) or (terminator_byte not in chunks[-1]):
            rwx_fdlist = select.select(
                [sock_fd],
                [],
                [],
                # naive timeout for one-shot response
                # scenario
                Config.CLIENT_TIMEOUT,
            )
            fdready_list = rwx_fdlist[0]
            if sock_fd not in fdready_list:
                if any(rwx_fdlist):
                    raise SocketError(
                        "select() = {!r} resulted in error".format(rwx_fdlist)
                    )
                else:
                    raise SocketError("request timeout")
            chunk = self._sock.recv(io.DEFAULT_BUFFER_SIZE)
            if not chunk:
                raise SocketError("Empty response from socket.recv()")
            chunks.append(chunk)
        return b"".join(chunks)


class _NoRpcImpl:
    """Client that executes endpoints in this process, without a socket."""

    def __init__(self, sink=None):
        self._sink = sink
        # suppress is for doing those things idempotent way
        # PSSST! simplification.run_in_executor() is main thread now! :-X
        # with suppress(tls_check.OverridingReset):
        #     tls_check.reset("main CLI thread for stopped agent")
        with suppress(tls_check.OverridingReset):
            loop = asyncio.get_event_loop()
            loop.run_until_complete(run_in_executor(loop, tls_check.reset))

    def dispatch(self, method, params):
        """Run the endpoint for *method* to completion and return the
        response dict built by ``_execute_request``.
        """
        loop = asyncio.get_event_loop()
        logger.info("Executing {}, params: {}".format(method, params))
        request = {"command": method, "params": params}
        cb = _apply_middleware(method, user=UserType.ROOT)
        return loop.run_until_complete(
            _execute_request(cb(request, self._sink), method)
        )


class RpcClient:
    """
    One RpcClient instance is suitable to use for multiple ipc calls

    :param RpcServiceState require_svc_is_running: whether to provide
        direct endpoints binding if the service is stopped.
    :param int reconnect_with_timeout: timeout in sec for reconnect retries
    :param int num_retries: number of reconnect retries
    """

    def __init__(
        self,
        *,
        require_svc_is_running=RpcServiceState.RUNNING,
        reconnect_with_timeout=None,
        num_retries=1
    ):
        self._impl = None
        self._socket_path = (
            Config.SOCKET_PATH
            if is_root_user()
            else Config.NON_ROOT_SOCKET_PATH
        )
        if (
            require_svc_is_running == RpcServiceState.STOPPED
            and rpc_is_running()
        ):
            raise ServiceStateError(RpcServiceState.RUNNING)

        if require_svc_is_running in (
            RpcServiceState.ANY,
            RpcServiceState.RUNNING,
        ):
            try:
                if reconnect_with_timeout:
                    self._impl = self._reconnect_with_timeout(
                        reconnect_with_timeout, num_retries
                    )
                else:
                    self._impl = _RpcClientImpl(self._socket_path)
                return
            except ServiceStateError:
                # RUNNING insists on the socket; ANY falls back below
                if require_svc_is_running == RpcServiceState.RUNNING:
                    raise

        if self._impl is None:
            # In other cases (ANY, STOPPED, DIRECT) need to use _NoRpcImpl
            assert (
                is_root_user()
            ), "_NoRpcImpl is not available for non root user"
            self._impl = _NoRpcImpl()

    def __getattr__(self, method):
        """Expose any other attribute name as an RPC call of that name."""
        return functools.partial(self._dispatch, method)

    def cmd(self, *command):
        """Return a callable that dispatches the multi-word *command*."""
        return functools.partial(self._dispatch, command)

    def _dispatch(self, method, **params):
        """Execute *method* and unwrap the response.

        For a list/tuple *method* a ``(result, payload)`` pair is returned,
        where payload is "messages" on ERROR/WARNING and "data" otherwise.
        For any other *method* "data" is returned.

        :raises ResponseError: for a non list/tuple *method* whose result is
            ERROR or WARNING
        """
        response = self._impl.dispatch(method, params)
        if isinstance(method, (list, tuple)):
            if response["result"] in (ERROR, WARNING):
                return response["result"], response["messages"]
            else:
                assert response["result"] == SUCCESS
                return response["result"], response["data"]
        else:
            if response["result"] in (ERROR, WARNING):
                raise ResponseError(response["messages"])
            return response["data"]

    def _reconnect_with_timeout(self, timeout, num_retries):
        """Connect to the socket, retrying up to *num_retries* times.

        :param timeout: seconds to sleep before every retry
        :param num_retries: number of retries after the first attempt
        :raises ServiceStateError: if the last attempt fails too
        """
        while True:
            try:
                return _RpcClientImpl(self._socket_path)
            except ServiceStateError:
                if num_retries:
                    logger.info(
                        "Waiting %d second(s) before retry...", timeout
                    )
                    time.sleep(timeout)
                    num_retries -= 1
                else:
                    raise