# Copyright 2019 Nokia
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import os.path
import sys
import inspect
from yarf.iniloader import INILoader
from yarf.restresource import RestResource
from yarf.versionhandler import VersionHandler
from yarf.authentication.base_auth import BaseAuthMethod
from yarf.exceptions import ConfigError
import yarf.restfullogger as restlog


class PluginLoader(object):
    """Discover REST handler plugins on disk and register them with an API.

    Every sub-directory of ``path`` is scanned. Inside it, each ``*.ini``
    file defines a sub-area (named after the file); each section of that
    file is used as an API version and lists the handler class names to
    load from the ``*.py`` modules in the same directory.
    """

    def __init__(self, path, api, auth_method):
        """
        :param path: directory whose sub-directories contain the plugins.
        :param api: object exposing ``add_resource(resource, *urls)``
            (presumably a flask-restful ``Api`` -- verify against caller).
        :param auth_method: dotted ``"module.ClassName"`` of a
            ``BaseAuthMethod`` subclass.
        :raises ConfigError: if the authentication class cannot be resolved.
        """
        self.logger = restlog.get_logger()
        self.plugin_class_type = RestResource
        # Resolved to a class here; instantiated by get_auth_method()/get_modules().
        self.auth_method = self._get_auth_method(auth_method)
        self.path = path
        self.api = api

    def get_module_dirs(self):
        """Return the paths of all direct sub-directories of ``self.path``."""
        modules = []
        for name in os.listdir(self.path):
            candidate = os.path.join(self.path, name)
            if os.path.isdir(candidate):
                modules.append(candidate)
        return modules

    def _get_auth_method(self, authmethod):
        """Resolve a dotted ``"module.ClassName"`` string to an auth class.

        :raises ConfigError: if the string cannot be split into module and
            class name, or the module does not provide a matching
            ``BaseAuthMethod`` subclass.
        """
        try:
            auth_class_module, class_name = authmethod.rsplit('.', 1)
        except (ValueError, AttributeError):
            # ValueError: no '.' in the string; AttributeError: not a string
            # (e.g. the option is missing and None was passed).
            error = "Cannot decode the authentication method from configuration file"
            self.logger.error(error)
            raise ConfigError(error)
        auth_classes = self._get_classes_wanted_classes(auth_class_module,
                                                        [class_name],
                                                        BaseAuthMethod)
        # None means the import failed; [] means no matching class was found.
        if not auth_classes:
            error = ("Cannot find the authentication class in provided module %s %s"
                     % (auth_class_module, class_name))
            self.logger.error(error)
            raise ConfigError(error)
        return auth_classes[0]

    def _get_classes_wanted_classes(self, module_name, wanted_modules, class_type):
        """Import ``module_name`` and return its public classes named in
        ``wanted_modules`` that are subclasses of ``class_type``.

        The attribute whose name equals ``class_type.__name__`` (i.e. the
        base class itself, when imported into the module) is skipped.

        :returns: list of matching classes (possibly empty), or None if the
            module could not be imported.
        """
        try:
            __import__(module_name)
        except ImportError as err:
            self.logger.error("Failed import in %s, skipping: %s", module_name, err)
            return None
        # __import__ returns the top-level package for dotted names, so fetch
        # the actual (sub)module from sys.modules.
        module = sys.modules[module_name]
        classes = []
        for obj_name in dir(module):
            # Skip objects that are meant to be private.
            if obj_name.startswith('_'):
                continue
            # Skip the same name that the base class has.
            if obj_name == class_type.__name__:
                continue
            if obj_name not in wanted_modules:
                continue
            itm = getattr(module, obj_name)
            if inspect.isclass(itm) and issubclass(itm, class_type):
                classes.append(itm)
        return classes

    def get_classes_from_dir(self, directory, wanted_modules):
        """Load the wanted plugin classes from every ``*.py`` file in ``directory``.

        ``directory`` is appended to ``sys.path`` so that its files can be
        imported by bare module name. NOTE(review): bare names mean two
        plugin directories with equally named files share one cached module
        in ``sys.modules``.
        """
        classes = []
        if directory not in sys.path:
            sys.path.append(directory)
        for fname in os.listdir(directory):
            root, ext = os.path.splitext(fname)
            if ext != '.py' or root == '__init__':
                continue
            mod_classes = self._get_classes_wanted_classes(root,
                                                           wanted_modules,
                                                           self.plugin_class_type)
            if mod_classes:
                classes.extend(mod_classes)
        return classes

    def get_modules_from_dir(self, module_dir):
        """Parse every ``*.ini`` file in ``module_dir``.

        :returns: ``{ini_basename: {section: handlers}}``. Sections without
            handlers are logged and left out.
        """
        modules = {}
        for f in os.listdir(module_dir):
            if not f.endswith(".ini"):
                continue
            root, _ = os.path.splitext(f)
            loader = INILoader(os.path.join(module_dir, f))
            sections = loader.get_sections()
            modules[root] = {}
            for section in sections:
                handlers = loader.get_handlers(section)
                if handlers:
                    modules[root][section] = handlers
                else:
                    self.logger.error("Problem in the configuration file %s in section %s: No handlers found", f, section)
        return modules

    def get_auth_method(self):
        """Return a new instance of the configured authentication class."""
        return self.auth_method()

    def get_modules(self):
        """Discover all plugin handler classes and annotate them.

        Each returned class gets ``subarea`` (the ini file basename) and
        ``api_versions`` (the ini sections it was listed under). It also gets
        ``authentication_method`` (a shared auth instance) unless it already
        defines one.

        :returns: list of unique handler classes.
        """
        dirs = self.get_module_dirs()
        # One auth instance shared by every handler that does not set its own.
        auth_class = self.auth_method()
        modules = []
        for d in dirs:
            wanted_modules = self.get_modules_from_dir(d)
            for mod, versions in wanted_modules.items():
                for api_version, handler_names in versions.items():
                    classes = self.get_classes_from_dir(d, handler_names)
                    for c in classes:
                        setattr(c, "subarea", mod)
                        if getattr(c, "authentication_method", "EMPTY") == "EMPTY":
                            setattr(c, "authentication_method", auth_class)
                        existing = getattr(c, "api_versions", None)
                        if existing:
                            # Avoid listing (and later registering) a version twice.
                            if api_version not in existing:
                                existing.append(api_version)
                        else:
                            setattr(c, "api_versions", [api_version])
                    for cls in classes:
                        if cls not in modules:
                            modules.append(cls)
        return modules

    def create_endpoints(self, handler):
        """Register ``handler`` at ``/<subarea>/<version>/<endpoint>`` for
        every combination of its ``endpoints`` and ``api_versions``."""
        endpoint_list = []
        for endpoint in handler.endpoints:
            for api_version in handler.api_versions:
                self.logger.debug("Registering /%s/%s/%s for %s", handler.subarea, api_version, endpoint, handler.__name__)
                endpoint_list.append("/%s/%s/%s" % (handler.subarea, api_version, endpoint))
        self.api.add_resource(handler, *endpoint_list)

    def add_logger(self, handler):
        """Attach this loader's logger to ``handler`` as ``handler.logger``."""
        self.logger.info("Adding logger to: %s", handler.__name__)
        handler.logger = self.logger

    def init_handler(self, handler):
        """Prepare a handler class: attach the logger, call its
        ``add_wrappers()``, register its endpoints, then call its
        ``add_parser_arguments()``."""
        self.add_logger(handler)
        handler.add_wrappers()
        self.create_endpoints(handler)
        handler.add_parser_arguments()

    def create_api_versionhandlers(self, handlers):
        """Register ``VersionHandler`` at ``/<subarea>/apis`` for each sub-area.

        ``VersionHandler.versions`` is set to ``{subarea: [versions...]}``,
        merging the ``api_versions`` of all handlers in that sub-area.
        """
        apiversions = {}
        endpoint_list = []
        for handler in handlers:
            subarea = handler.subarea
            if subarea in apiversions:
                for hapiversion in handler.api_versions:
                    if hapiversion not in apiversions[subarea]:
                        apiversions[subarea].append(hapiversion)
            else:
                # Copy, so merging other handlers' versions below does not
                # mutate this handler's own api_versions list.
                apiversions[subarea] = list(handler.api_versions)
                self.logger.debug("Registering /%s/apis for %s", subarea, subarea)
                endpoint_list.append("/%s/apis" % subarea)
        setattr(VersionHandler, "versions", apiversions)
        # NOTE(review): clears decorators on the version listing, presumably
        # so it bypasses the auth wrappers other handlers get -- confirm.
        setattr(VersionHandler, "method_decorators", [])
        self.api.add_resource(VersionHandler, *endpoint_list)