# Authors: # Jason Gerard DeRose # # Copyright (C) 2008 Red Hat # see file 'COPYING' for use and warranty information # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . """ Plugin framework. The classes in this module make heavy use of Python container emulation. If you are unfamiliar with this Python feature, see http://docs.python.org/ref/sequence-types.html """ from collections.abc import Mapping import logging import operator import re import sys import threading import os from os import path import optparse # pylint: disable=deprecated-module import textwrap import collections import importlib import six from ipalib import errors from ipalib.config import Env from ipalib.text import _ from ipalib.util import classproperty from ipalib.base import ReadOnly, lock, islocked from ipalib.constants import DEFAULT_CONFIG from ipapython import ipa_log_manager, ipautil from ipapython.ipa_log_manager import ( LOGGING_FORMAT_FILE, LOGGING_FORMAT_STDERR) from ipapython.version import VERSION, API_VERSION, DEFAULT_PLUGINS if six.PY3: unicode = str logger = logging.getLogger(__name__) # FIXME: Updated constants.TYPE_ERROR to use this clearer format from wehjit: TYPE_ERROR = '%s: need a %r; got a %r: %r' # FIXME: This function has no unit test def find_modules_in_dir(src_dir): """ Iterate through module names found in ``src_dir``. """ if not (os.path.abspath(src_dir) == src_dir and os.path.isdir(src_dir)): return if os.path.islink(src_dir): return suffix = '.py' for name in sorted(os.listdir(src_dir)): if not name.endswith(suffix): continue pyfile = os.path.join(src_dir, name) if not os.path.isfile(pyfile): continue module = name[:-len(suffix)] if module == '__init__': continue yield module class Registry: """A decorator that makes plugins available to the API Usage:: register = Registry() @register() class obj_mod(...): ... For forward compatibility, make sure that the module-level instance of this object is named "register". """ def __init__(self): self.__registry = collections.OrderedDict() def __call__(self, **kwargs): def register(plugin): """ Register the plugin ``plugin``. :param plugin: A subclass of `Plugin` to attempt to register. """ if not callable(plugin): raise TypeError('plugin must be callable; got %r' % plugin) # Raise DuplicateError if this exact class was already registered: if plugin in self.__registry: raise errors.PluginDuplicateError(plugin=plugin) # The plugin is okay, add to __registry: self.__registry[plugin] = dict(kwargs, plugin=plugin) return plugin return register def __iter__(self): return iter(self.__registry.values()) class Plugin(ReadOnly): """ Base class for all plugins. """ version = '1' def __init__(self, api): assert api is not None self.__api = api self.__finalize_called = False self.__finalized = False self.__finalize_lock = threading.RLock() @classmethod def __name_getter(cls): # pylint: disable=unused-private-member, #4756 return cls.__name__ # you know nothing, pylint name = classproperty(__name_getter) @classmethod def __full_name_getter(cls): # pylint: disable=unused-private-member return '{}/{}'.format(cls.name, cls.version) full_name = classproperty(__full_name_getter) @classmethod def __bases_getter(cls): # pylint: disable=unused-private-member, #4756 return cls.__bases__ bases = classproperty(__bases_getter) @classmethod def __doc_getter(cls): # pylint: disable=unused-private-member, #4756 return cls.__doc__ doc = classproperty(__doc_getter) @classmethod def __summary_getter(cls): # pylint: disable=unused-private-member, #4756 doc = cls.doc if not _(doc).msg: return '<%s.%s>' % (cls.__module__, cls.__name__) else: return unicode(doc).split('\n\n', 1)[0].strip() summary = classproperty(__summary_getter) @property def api(self): """ Return `API` instance passed to `__init__()`. """ return self.__api # FIXME: for backward compatibility only @property def env(self): return self.__api.env # FIXME: for backward compatibility only @property def Backend(self): return self.__api.Backend # FIXME: for backward compatibility only @property def Command(self): return self.__api.Command def finalize(self): """ Finalize plugin initialization. This method calls `_on_finalize()` and locks the plugin object. Subclasses should not override this method. Custom finalization is done in `_on_finalize()`. """ with self.__finalize_lock: assert self.__finalized is False if self.__finalize_called: # No recursive calls! return self.__finalize_called = True self._on_finalize() self.__finalized = True if not self.__api.is_production_mode(): lock(self) def _on_finalize(self): """ Do custom finalization. This method is called from `finalize()`. Subclasses can override this method in order to add custom finalization. """ def ensure_finalized(self): """ Finalize plugin initialization if it has not yet been finalized. """ with self.__finalize_lock: if not self.__finalized: self.finalize() class finalize_attr: """ Create a stub object for plugin attribute that isn't set until the finalization of the plugin initialization. When the stub object is accessed, it calls `ensure_finalized()` to make sure the plugin initialization is finalized. The stub object is expected to be replaced with the actual attribute value during the finalization (preferably in `_on_finalize()`), otherwise an `AttributeError` is raised. This is used to implement on-demand finalization of plugin initialization. """ __slots__ = ('name', 'value') def __init__(self, name, value=None): self.name = name self.value = value def __get__(self, obj, cls): if obj is None or obj.api is None: return self.value obj.ensure_finalized() try: return getattr(obj, self.name) except RuntimeError: # If the actual attribute value is not set in _on_finalize(), # getattr() calls __get__() again, which leads to infinite # recursion. This can happen only if the plugin is written # badly, so advise the developer about that instead of giving # them a generic "maximum recursion depth exceeded" error. raise AttributeError( "attribute '%s' of plugin '%s' was not set in finalize()" % (self.name, obj.name) ) def __repr__(self): """ Return 'module_name.class_name()' representation. This representation could be used to instantiate this Plugin instance given the appropriate environment. """ return '%s.%s()' % ( self.__class__.__module__, self.__class__.__name__ ) class APINameSpace(Mapping): def __init__(self, api, base): self.__api = api self.__base = base self.__plugins = None self.__plugins_by_key = None def __enumerate(self): if self.__plugins is not None and self.__plugins_by_key is not None: return default_map = self.__api._API__default_map plugins = set() key_dict = self.__plugins_by_key = {} for plugin in self.__api._API__plugins: if not any(issubclass(b, self.__base) for b in plugin.bases): continue plugins.add(plugin) key_dict[plugin] = plugin key_dict[plugin.name, plugin.version] = plugin key_dict[plugin.full_name] = plugin if plugin.version == default_map.get(plugin.name, '1'): key_dict[plugin.name] = plugin self.__plugins = sorted(plugins, key=operator.attrgetter('full_name')) def __len__(self): self.__enumerate() return len(self.__plugins) def __contains__(self, key): self.__enumerate() return key in self.__plugins_by_key def __iter__(self): self.__enumerate() return iter(self.__plugins) def __dir__(self): # include plugins for readline tab completion and in dir() self.__enumerate() names = super().__dir__() names.extend(p.name for p in self) names.sort() return names def get_plugin(self, key): self.__enumerate() return self.__plugins_by_key[key] def __getitem__(self, key): plugin = self.get_plugin(key) return self.__api._get(plugin) def __call__(self): return six.itervalues(self) def __getattr__(self, key): try: return self[key] except KeyError: raise AttributeError(key) class API(ReadOnly): """ Dynamic API object through which `Plugin` instances are accessed. """ def __init__(self): super(API, self).__init__() self.__plugins = set() self.__plugins_by_key = {} self.__default_map = {} self.__instances = {} self.__next = {} self.__done = set() self.env = Env() @property def bases(self): raise NotImplementedError @property def packages(self): raise NotImplementedError def __len__(self): """ Return the number of plugin namespaces in this API object. """ return len(self.bases) def __iter__(self): """ Iterate (in ascending order) through plugin namespace names. """ return (base.__name__ for base in self.bases) def __contains__(self, name): """ Return True if this API object contains plugin namespace ``name``. :param name: The plugin namespace name to test for membership. """ return name in set(self) def __getitem__(self, name): """ Return the plugin namespace corresponding to ``name``. :param name: The name of the plugin namespace you wish to retrieve. """ if name in self: try: return getattr(self, name) except AttributeError: pass raise KeyError(name) def __call__(self): """ Iterate (in ascending order by name) through plugin namespaces. """ for name in self: try: yield getattr(self, name) except AttributeError: raise KeyError(name) def is_production_mode(self): """ If the object has self.env.mode defined and that mode is production return True, otherwise return False. """ return getattr(self.env, 'mode', None) == 'production' def __doing(self, name): if name in self.__done: raise Exception( '%s.%s() already called' % (self.__class__.__name__, name) ) self.__done.add(name) def __do_if_not_done(self, name): if name not in self.__done: getattr(self, name)() def isdone(self, name): return name in self.__done def bootstrap(self, parser=None, **overrides): """ Initialize environment variables and logging. """ self.__doing('bootstrap') self.env._bootstrap(**overrides) self.env._finalize_core(**dict(DEFAULT_CONFIG)) # Add the argument parser if not parser: parser = self.build_global_parser() self.parser = parser root_logger = logging.getLogger() # If logging has already been configured somewhere else (like in the # installer), don't add handlers or change levels: if root_logger.handlers or self.env.validate_api: return if self.env.debug: # pylint: disable=using-constant-test level = logging.DEBUG else: level = logging.INFO root_logger.setLevel(level) for attr in self.env: match = re.match(r'^log_logger_level_' r'(debug|info|warn|warning|error|critical|\d+)$', attr) if not match: continue level = ipa_log_manager.convert_log_level(match.group(1)) value = getattr(self.env, attr) regexps = re.split(r'\s*,\s*', value) # Add the regexp, it maps to the configured level for regexp in regexps: root_logger.addFilter(ipa_log_manager.Filter(regexp, level)) # Add stderr handler: level = logging.INFO if self.env.debug: # pylint: disable=using-constant-test level = logging.DEBUG else: if self.env.context == 'cli': if self.env.verbose > 0: level = logging.INFO else: level = logging.WARNING handler = logging.StreamHandler() handler.setLevel(level) handler.setFormatter(ipa_log_manager.Formatter(LOGGING_FORMAT_STDERR)) root_logger.addHandler(handler) # check after logging is set up but before we create files. fse = sys.getfilesystemencoding() if fse.lower() not in {'utf-8', 'utf8'}: raise errors.SystemEncodingError(encoding=fse) # Add file handler: if self.env.mode in ('dummy', 'unit_test'): return # But not if in unit-test mode if self.env.log is None: return log_dir = path.dirname(self.env.log) if not path.isdir(log_dir): try: os.makedirs(log_dir) except OSError: logger.error('Could not create log_dir %r', log_dir) return level = logging.INFO if self.env.debug: # pylint: disable=using-constant-test level = logging.DEBUG if self.env.log is not None: try: handler = logging.FileHandler(self.env.log) except IOError as e: logger.error('Cannot open log file %r: %s', self.env.log, e) else: handler.setLevel(level) handler.setFormatter( ipa_log_manager.Formatter(LOGGING_FORMAT_FILE) ) root_logger.addHandler(handler) def build_global_parser(self, parser=None, context=None): """ Add global options to an optparse.OptionParser instance. """ def config_file_callback(option, opt, value, parser): if not os.path.isfile(value): parser.error( _("%(filename)s: file not found") % dict(filename=value)) parser.values.conf = value if parser is None: parser = optparse.OptionParser( add_help_option=False, formatter=IPAHelpFormatter(), usage='%prog [global-options] COMMAND [command-options]', description='Manage an IPA domain', version=('VERSION: %s, API_VERSION: %s' % (VERSION, API_VERSION)), epilog='\n'.join([ 'See "ipa help topics" for available help topics.', 'See "ipa help " for more information on ' + 'a specific topic.', 'See "ipa help commands" for the full list of commands.', 'See "ipa --help" for more information on ' + 'a specific command.', ])) parser.disable_interspersed_args() parser.add_option("-h", "--help", action="help", help='Show this help message and exit') parser.add_option('-e', dest='env', metavar='KEY=VAL', action='append', help='Set environment variable KEY to VAL', ) parser.add_option('-c', dest='conf', metavar='FILE', action='callback', callback=config_file_callback, type='string', help='Load configuration from FILE.', ) parser.add_option('-d', '--debug', action='store_true', help='Produce full debuging output', ) parser.add_option('--delegate', action='store_true', help='Delegate the TGT to the IPA server', ) parser.add_option('-v', '--verbose', action='count', help='Produce more verbose output. A second -v displays the XML-RPC request', ) if context == 'cli': parser.add_option('-a', '--prompt-all', action='store_true', help='Prompt for ALL values (even if optional)' ) parser.add_option('-n', '--no-prompt', action='store_false', dest='interactive', help='Prompt for NO values (even if required)' ) parser.add_option('-f', '--no-fallback', action='store_false', dest='fallback', help='Only use the server configured in /etc/ipa/default.conf' ) return parser def bootstrap_with_global_options(self, parser=None, context=None): parser = self.build_global_parser(parser, context) (options, args) = parser.parse_args() overrides = {} if options.env is not None: assert type(options.env) is list for item in options.env: try: values = item.split('=', 1) except ValueError: # FIXME: this should raise an IPA exception with an # error code. # --Jason, 2008-10-31 pass if len(values) == 2: (key, value) = values overrides[str(key.strip())] = value.strip() else: raise errors.OptionError(_('Unable to parse option {item}' .format(item=item))) for key in ('conf', 'debug', 'verbose', 'prompt_all', 'interactive', 'fallback', 'delegate'): value = getattr(options, key, None) if value is not None: overrides[key] = value if hasattr(options, 'prod'): overrides['webui_prod'] = options.prod if context is not None: overrides['context'] = context self.bootstrap(parser, **overrides) return (options, args) def load_plugins(self): """ Load plugins from all standard locations. `API.bootstrap` will automatically be called if it hasn't been already. """ self.__doing('load_plugins') self.__do_if_not_done('bootstrap') if self.env.mode in ('dummy', 'unit_test'): return for package in self.packages: self.add_package(package) # FIXME: This method has no unit test def add_package(self, package): """ Add plugin modules from the ``package``. :param package: A package from which to add modules. """ package_name = package.__name__ package_file = package.__file__ package_dir = path.dirname(path.abspath(package_file)) parent = sys.modules[package_name.rpartition('.')[0]] parent_dir = path.dirname(path.abspath(parent.__file__)) if parent_dir == package_dir: raise errors.PluginsPackageError( name=package_name, file=package_file ) logger.debug("importing all plugin modules in %s...", package_name) modules = getattr(package, 'modules', find_modules_in_dir(package_dir)) modules = ['.'.join((package_name, mname)) for mname in modules] for name in modules: logger.debug("importing plugin module %s", name) try: module = importlib.import_module(name) except errors.SkipPluginModule as e: logger.debug("skipping plugin module %s: %s", name, e.reason) continue except Exception as e: tb = self.env.startup_traceback if tb: # pylint: disable=using-constant-test logger.exception("could not load plugin module %s", name) raise try: self.add_module(module) except errors.PluginModuleError as e: logger.debug("%s", e) def add_module(self, module): """ Add plugins from the ``module``. :param module: A module from which to add plugins. """ try: register = module.register except AttributeError: pass else: if isinstance(register, Registry): for kwargs in register: self.add_plugin(**kwargs) return raise errors.PluginModuleError(name=module.__name__) def add_plugin(self, plugin, override=False, no_fail=False): """ Add the plugin ``plugin``. :param plugin: A subclass of `Plugin` to attempt to add. :param override: If true, override an already added plugin. """ if not callable(plugin): raise TypeError('plugin must be callable; got %r' % plugin) # Find the base class or raise SubclassError: for base in plugin.bases: if issubclass(base, self.bases): break else: raise errors.PluginSubclassError( plugin=plugin, bases=self.bases, ) # Check override: prev = self.__plugins_by_key.get(plugin.full_name) if prev: if not override: if no_fail: return else: # Must use override=True to override: raise errors.PluginOverrideError( base=base.__name__, name=plugin.name, plugin=plugin, ) self.__plugins.remove(prev) self.__next[plugin] = prev else: if override: if no_fail: return else: # There was nothing already registered to override: raise errors.PluginMissingOverrideError( base=base.__name__, name=plugin.name, plugin=plugin, ) # The plugin is okay, add to sub_d: self.__plugins.add(plugin) self.__plugins_by_key[plugin.full_name] = plugin def finalize(self): """ Finalize the registration, instantiate the plugins. `API.bootstrap` will automatically be called if it hasn't been already. """ self.__doing('finalize') self.__do_if_not_done('load_plugins') if self.env.env_confdir is not None: if self.env.env_confdir == self.env.confdir: logger.info( "IPA_CONFDIR env sets confdir to '%s'.", self.env.confdir) for plugin in self.__plugins: if not self.env.validate_api: if plugin.full_name not in DEFAULT_PLUGINS: continue else: try: default_version = self.__default_map[plugin.name] except KeyError: pass else: # Technicall plugin.version is not an API version. The # APIVersion class can handle plugin versions. It's more # lean than pkg_resource.parse_version(). version = ipautil.APIVersion(plugin.version) default_version = ipautil.APIVersion(default_version) if version < default_version: continue self.__default_map[plugin.name] = plugin.version production_mode = self.is_production_mode() for base in self.bases: for plugin in self.__plugins: if not any(issubclass(b, base) for b in plugin.bases): continue if not self.env.plugins_on_demand: self._get(plugin) name = base.__name__ if not production_mode: assert not hasattr(self, name) setattr(self, name, APINameSpace(self, base)) for instance in six.itervalues(self.__instances): if not production_mode: assert instance.api is self if not self.env.plugins_on_demand: instance.ensure_finalized() if not production_mode: assert islocked(instance) if not production_mode: lock(self) def _get(self, plugin): if not callable(plugin): raise TypeError('plugin must be callable; got %r' % plugin) if plugin not in self.__plugins: raise KeyError(plugin) try: instance = self.__instances[plugin] except KeyError: instance = self.__instances[plugin] = plugin(self) return instance def get_plugin_next(self, plugin): if not callable(plugin): raise TypeError('plugin must be callable; got %r' % plugin) return self.__next[plugin] class IPAHelpFormatter(optparse.IndentedHelpFormatter): def format_epilog(self, text): text_width = self.width - self.current_indent indent = " " * self.current_indent lines = text.splitlines() result = '\n'.join( textwrap.fill(line, text_width, initial_indent=indent, subsequent_indent=indent) for line in lines) return '\n%s\n' % result