import argparse
import ipaddress
import sys
from functools import lru_cache, partial
from itertools import chain
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple

from defence360agent.application import app
from defence360agent.contracts.config import Core as Config
from defence360agent.rpc_tools.utils import prepare_schema
from defence360agent.simple_rpc import RpcClient
from defence360agent.utils.cli import EXITCODE_NOT_FOUND


class SchemaToArgparse:
    """Convert one non-boolean schema parameter into ``add_argument`` kwargs.

    ``argument`` is the parameter name from the schema and ``options`` is its
    schema mapping.  The keys read are ``allowed``, ``default``, ``envvar``,
    ``help``, ``positional``, ``rename``, ``required`` and ``type``.

    Each ``OptionType`` method yields zero or one ``(keyword, value)`` pair;
    :attr:`options` merges all of them into a single dict.
    """

    # NOTE: 'default' is a normalization rule, 'required' is a validation rule
    OptionType = Iterator[Tuple[str, Any]]

    def __init__(self, argument, options):
        self._argument: str = argument
        self._allowed: Iterable = options.get("allowed")
        self._default: Any = options.get("default")
        # Env-var name when the parameter can come from the environment,
        # otherwise False.
        self._envvar: Any = options.get("envvar", False)
        self._help: str = options.get("help")
        self._positional: bool = options.get("positional", False)
        self._rename: str = options.get("rename")
        self._required: bool = options.get("required", False)
        self._type: str = options.get("type")

    @property
    def argname(self) -> str:
        """Return the bare name for positionals, ``--dashed-name`` otherwise."""
        if self._positional:
            return self._argument
        return "--" + self._argument.replace("_", "-")

    @property
    def options(self):
        """Return the keyword arguments for ``ArgumentParser.add_argument``."""
        argparse_options = dict(
            chain(
                self.choices(),
                self.default(),
                self.help(),
                self.metavar(),
                self.nargs(),
                self.required(),
            ),
        )
        return argparse_options

    def nargs(self) -> OptionType:
        """Yield ``nargs`` for list types and for optional positionals."""
        option = "nargs"
        if self._type == "list":
            # FIXME: positional list arguments always get nargs='*' (they are
            # never required); this is kept so that `rename` keeps working.
            if not self._positional and self._required and not self._envvar:
                yield option, "+"
            else:
                yield option, "*"
        elif self._positional and (self._envvar or self._default is None):
            # The positional may be omitted: it has an env-var fallback or no
            # default value.
            yield option, "?"

    def choices(self) -> OptionType:
        """Yield ``choices`` from the schema ``allowed`` list (may be None)."""
        yield "choices", self._allowed

    def help(self) -> OptionType:
        """Yield the schema help text (may be None)."""
        yield "help", self._help

    def metavar(self) -> OptionType:
        """Yield an upper-cased ``metavar`` for renamed or list arguments."""
        option = "metavar"
        if self._rename:
            yield option, self._rename.upper()
        elif self._type == "list":
            yield option, self._argument.upper()

    def default(self) -> OptionType:
        """Yield ``default`` unless an env var supplies the value instead.

        Positional non-list arguments get no argparse default here.
        """
        if (
            self._default is not None
            and not self._envvar
            and (self._type == "list" or not self._positional)
        ):
            yield "default", self._default

    def required(self) -> OptionType:
        """Yield ``required=True`` for required, non-list, non-env options."""
        if (
            self._required
            and self._type != "list"
            and not self._envvar
            # 'required' is an invalid argument for positionals
            and not self._positional
        ):
            yield "required", True


def schema_to_argparse(parser, argument, options):
    """Add the schema parameter ``argument`` to ``parser``.

    Booleans become a mutually exclusive ``--name`` / ``--no-name`` pair
    writing to the same ``dest``. The group is required only when the schema
    marks it required and no ``envvar`` fallback exists. All other types are
    translated through :class:`SchemaToArgparse`.
    """
    if options.get("type") == "boolean":
        required = options.get("required") and not options.get("envvar", False)
        bool_parser = parser.add_mutually_exclusive_group(required=required)
        bool_parser.add_argument(
            "--" + argument.replace("_", "-"),
            dest=argument,
            action="store_true",
        )
        bool_parser.add_argument(
            "--no-" + argument.replace("_", "-"),
            dest=argument,
            action="store_false",
        )
        bool_parser.set_defaults(**{argument: options.get("default")})
    else:
        converter = SchemaToArgparse(argument, options)
        parser.add_argument(converter.argname, **converter.options)


class EnvParser:
    """Resolve command parameters that may be supplied via env variables."""

    @staticmethod
    def format_help(envvar_parameter_options: Mapping):
        """Return the "environment variables" help section, or "" if none.

        Each line is the env-var name, followed by its help text when the
        schema provides one.
        """
        if not envvar_parameter_options:
            return ""

        def format_arg(options):
            if "help" in options:
                return f"{options['envvar']}\t\t{options['help']}"
            return options["envvar"]

        return "\nenvironment variables: \n {}".format(
            "\n ".join(
                format_arg(options)
                for options in envvar_parameter_options.values()
            )
        )

    @staticmethod
    def _validate(envvar, value, options) -> Optional[str]:
        """Return an error message if ``value`` is invalid, else ``None``.

        Only the ``isascii`` rule is checked. When that key is present in
        ``options``, the value must contain ASCII characters only.
        """
        if "isascii" in options and not value.isascii():
            # Return a plain string: a tuple here would be rendered as its
            # repr by _format_error.
            return f"error: {envvar}={value} must only contain ascii symbols"
        return None

    @classmethod
    def parse(
        cls,
        environ: Mapping,
        command,
        envvar_parameter_options,
        exclude: Iterable[str],
    ) -> Dict[str, Any]:
        """Collect values of env-var-backed parameters from ``environ``.

        Parameters listed in ``exclude`` are skipped. A missing variable falls
        back to the schema ``default`` when there is one. Otherwise the
        parameter is ignored, unless it is ``required``.

        In that required case, or when a present value fails
        :meth:`_validate`, an error is printed to stderr. The error includes
        the env-var help. The process then exits with ``EXITCODE_NOT_FOUND``.
        """
        kwargs = {}
        for parameter, options in envvar_parameter_options.items():
            if parameter in exclude:
                continue
            envvar_name = options["envvar"]
            try:
                value = kwargs[parameter] = environ[envvar_name]
            except KeyError:
                if "default" in options:
                    kwargs[parameter] = options["default"]
                    continue
                if not options.get("required"):
                    continue
                msg = cls._format_error(
                    command,
                    envvar_parameter_options,
                    "error: environment variable {} is not defined".format(
                        envvar_name
                    ),
                )
                print(msg, file=sys.stderr)
                sys.exit(EXITCODE_NOT_FOUND)
            else:
                if err := cls._validate(envvar_name, value, options):
                    msg = cls._format_error(
                        command, envvar_parameter_options, err
                    )
                    print(msg, file=sys.stderr)
                    sys.exit(EXITCODE_NOT_FOUND)
        return kwargs

    @classmethod
    def _format_error(cls, command, envvar_parameter_options, msg):
        """Build "<command>:\\n<env help>\\n\\n<msg>" for error output."""
        return "{command}:\n{help}\n\n{message}".format(
            command=" ".join(command),
            help=cls.format_help(envvar_parameter_options),
            message=msg,
        )


def is_valid_ipv4_addr(addr):
    """Return True if ``addr`` parses as an IPv4 address."""
    try:
        ipaddress.IPv4Address(addr)
    except ipaddress.AddressValueError:
        return False
    return True


def _filter_user(schema, user):
    """Yield ``(command, values)`` pairs whose ``cli.users`` include ``user``."""
    for key, values in schema.items():
        if user in values.get("cli", {}).get("users", []):
            yield key, values


def rpc_endpoint(command, require_rpc, **params):
    """Create an ``RpcClient`` and invoke ``command`` with ``params``.

    Returns whatever the RPC call returns.
    """
    return RpcClient(require_svc_is_running=require_rpc).cmd(*command)(
        **params
    )


def generate_endpoint_params(arg_parser_namespace, arguments):
    """Collect the non-None namespace values for the schema ``arguments``.

    argparse stores attributes with underscores, so dashes are translated for
    the lookup. The returned keys keep the original schema spelling.
    """
    kwargs = {}
    for argument in arguments:
        arg_parser_argument = argument.replace("-", "_")
        value = getattr(arg_parser_namespace, arg_parser_argument, None)
        if value is not None:
            kwargs[argument] = value
    return kwargs


def apply_parser(subparsers, schema):
    """Register one argparse sub-command per ``schema`` entry.

    ``schema`` maps command paths (tuples of words) to their definition.
    Intermediate words share a single sub-parser. The last word gets the
    parser that receives the schema arguments, ``--json``, ``--verbose``, and
    defaults used later to dispatch the RPC call.
    """
    _subparsers = {}
    commands = sorted(schema.keys())
    for methods in commands:
        values = schema[methods]
        assert isinstance(methods, (tuple, list))
        parser = None
        # generate subparsers
        subparser = subparsers
        for i, command in enumerate(methods):
            # last element
            if i == len(methods) - 1:
                parser = subparser.add_parser(
                    name=command,
                    help=values.get("help"),
                    formatter_class=argparse.RawDescriptionHelpFormatter,
                )
                # If longer commands extend this one, it also needs its own
                # subparsers group for them to attach to.
                if any(
                    (c != methods and methods == c[: len(methods)])
                    for c in commands
                ):
                    _subparsers[tuple(methods)] = parser.add_subparsers(
                        help="Available commands"
                    )
            else:
                # Need to reuse created subparsers for sub-commands, otherwise
                # they will be overwritten.
                #
                # Example:
                # For both of the commands:
                # * malware on-demand queue put
                # * malware on-demand queue remove
                # only one subparser is created. We should add `queue`
                # subparser only once in order to keep both `put` and `remove`.
                hashable = tuple(methods[: i + 1])
                exists_subparser = _subparsers.get(hashable)
                if exists_subparser is None:
                    subparser = _subparsers[hashable] = subparser.add_parser(
                        name=command,
                        help=values.get("help"),
                    ).add_subparsers(help="Available commands")
                else:
                    subparser = exists_subparser
        assert parser, "parser is not defined"
        # generate arguments
        envvar_parameter_options = {}
        for argument, options in values.get("schema", {}).items():
            if "envvar" in options:
                envvar_parameter_options[argument] = options
                if options.get("envvar_only", False):
                    # Only settable via the environment: no CLI flag.
                    continue
            if "rename" in options:
                # Inherit the target parameter's definition, but keep this
                # alias optional and non-positional. Note: mutates the schema
                # dict in place.
                options.update(values["schema"][options["rename"]])
                options["required"] = False
                options["positional"] = False
            schema_to_argparse(parser, argument, options)
        parser.epilog = EnvParser.format_help(envvar_parameter_options)
        parser.add_argument(
            "--json", action="store_true", help="return data in JSON format"
        )
        parser.add_argument("--verbose", "-v", action="count")
        require_rpc = values.get("cli", {}).get("require_rpc", "running")
        parser.set_defaults(
            # Initializing `RpcClient` here for each command will
            # inevitably lead to the `ServiceStateError`,
            # because some endpoints require the agent to be stopped and
            # some require it to be running. So we use `partial` to
            # defer initialization until the command is selected.
            endpoint=partial(rpc_endpoint, methods, require_rpc),
            generate_endpoint_params=partial(
                generate_endpoint_params,
                arguments=values.get("schema", {}).keys(),
            ),
            envvar_parameter_options=envvar_parameter_options,
            command=methods,
        )


def _apply_subparsers(subparsers, user):
    """Add sub-commands from the app schema that are available to ``user``."""
    schema = dict(_filter_user(prepare_schema(app.SCHEMA_PATHS), user))
    apply_parser(subparsers, schema)


@lru_cache(maxsize=1)
def create_cli_parser():
    """Build the top-level CLI parser with the "root" user's sub-commands.

    The result is cached, so the schema is processed only once per process.
    """
    parser = argparse.ArgumentParser(description="CLI for %s." % Config.NAME)
    parser.add_argument("--log-config", help="logging config filename")
    parser.add_argument(
        "--console-log-level",
        choices=["ERROR", "WARNING", "INFO", "DEBUG"],
        help="Level of logging input to the console",
    )
    parser.add_argument(
        "--remote-addr",
        # Invalid addresses become None rather than raising an argparse error.
        type=lambda ip: ip if is_valid_ipv4_addr(ip) else None,
        help="Client's IP address for adding it to the whitelist",
    )
    subparsers = parser.add_subparsers(help="Available commands")
    _apply_subparsers(subparsers, "root")
    return parser