3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
20 from yarf.iniloader import INILoader
21 from yarf.restresource import RestResource
22 from yarf.versionhandler import VersionHandler
23 from yarf.authentication.base_auth import BaseAuthMethod
24 from yarf.exceptions import ConfigError
25 import yarf.restfullogger as restlog
class PluginLoader(object):
    """Discover REST handler plugins on disk and register them with an API.

    Every sub-directory of ``path`` is a plugin directory.  Inside it, each
    ``*.ini`` file names a sub-area (its file stem), each INI section of that
    file is an API version, and the section's handlers are the names of the
    classes to load from the directory's ``*.py`` modules.
    """

    def __init__(self, path, api, auth_method):
        """Create a loader.

        :param path: directory whose sub-directories contain the plugins.
        :param api: object providing ``add_resource(resource, *urls)``; the
            handlers and the version handler are registered on it.
        :param auth_method: dotted ``"module.ClassName"`` string naming a
            ``BaseAuthMethod`` subclass.
        :raises ConfigError: if ``auth_method`` cannot be resolved.
        """
        self.logger = restlog.get_logger()
        self.plugin_class_type = RestResource
        self.auth_method = self._get_auth_method(auth_method)
        self.path = path
        self.api = api

    def get_module_dirs(self):
        """Return the plugin directories found directly under ``self.path``.

        The result is sorted because ``os.listdir`` order is arbitrary, and
        plugin loading order should be reproducible between runs.
        """
        return [
            os.path.join(self.path, entry)
            for entry in sorted(os.listdir(self.path))
            if os.path.isdir(os.path.join(self.path, entry))
        ]

    def _get_auth_method(self, authmethod):
        """Resolve a dotted ``"module.ClassName"`` string to an auth class.

        :raises ConfigError: if the string cannot be split into a non-empty
            module and class name, or the module provides no
            ``BaseAuthMethod`` subclass of that name.
        """
        try:
            auth_class_module, class_name = authmethod.rsplit('.', 1)
        except (AttributeError, ValueError):
            # ValueError: no '.' in the string.  AttributeError: the value is
            # not a string at all (e.g. None when the option is missing).
            auth_class_module = class_name = None
        if not auth_class_module or not class_name:
            error = "Cannot decode the authentication method from configuration file"
            self.logger.error(error)
            raise ConfigError(error)
        auth_classes = self._get_classes_wanted_classes(auth_class_module, [class_name], BaseAuthMethod)
        if not auth_classes:
            error = "Cannot find the authentication class in provided module %s %s" % (auth_class_module, class_name)
            self.logger.error(error)
            raise ConfigError(error)
        return auth_classes[0]

    def _get_classes_wanted_classes(self, module_name, wanted_modules, class_type):
        """Import ``module_name`` and return its requested subclasses of ``class_type``.

        Returned classes are public names contained in ``wanted_modules``,
        excluding a name equal to ``class_type.__name__``, in ``dir()`` order.
        A module that fails to import is logged and yields an empty list.
        """
        try:
            __import__(module_name)
        except Exception as err:  # plugin boundary: one broken module is skipped, not fatal
            self.logger.error("Failed import in %s, skipping: %s", module_name, err)
            return []
        module = sys.modules[module_name]
        classes = []
        for obj_name in dir(module):
            # Skip private names, the base class itself (typically imported
            # into the plugin module) and anything that was not requested.
            if (obj_name.startswith('_')
                    or obj_name == class_type.__name__
                    or obj_name not in wanted_modules):
                continue
            itm = getattr(module, obj_name)
            if inspect.isclass(itm) and issubclass(itm, class_type):
                classes.append(itm)
        return classes

    def get_classes_from_dir(self, directory, wanted_modules):
        """Return the plugin classes named in ``wanted_modules`` from ``directory``.

        ``directory`` is appended to ``sys.path`` (once) so that its ``*.py``
        files can be imported by bare module name; ``__init__.py`` is ignored.

        NOTE(review): modules are imported by bare name, so a same-named
        ``.py`` file in another plugin directory resolves to whichever module
        is already cached in ``sys.modules``.
        """
        if directory not in sys.path:
            sys.path.append(directory)
        classes = []
        for fname in sorted(os.listdir(directory)):
            module_name, ext = os.path.splitext(fname)
            if ext != '.py' or module_name == '__init__':
                continue
            classes.extend(
                self._get_classes_wanted_classes(module_name, wanted_modules, self.plugin_class_type))
        return classes

    def get_modules_from_dir(self, module_dir):
        """Read every ``*.ini`` file in ``module_dir``.

        :returns: ``{ini_file_stem: {section: handlers}}``.  Sections for
            which the loader reports no handlers are logged and omitted.
        """
        modules = {}
        for fname in sorted(os.listdir(module_dir)):
            if not fname.endswith(".ini"):
                continue
            root, _ = os.path.splitext(fname)
            loader = INILoader(os.path.join(module_dir, fname))
            sections = modules.setdefault(root, {})
            for section in loader.get_sections():
                handlers = loader.get_handlers(section)
                if handlers:
                    sections[section] = handlers
                else:
                    self.logger.error("Problem in the configuration file %s in section %s: No handlers found", fname, section)
        return modules

    def get_auth_method(self):
        """Return a new instance of the configured authentication class."""
        return self.auth_method()

    def get_modules(self):
        """Load every configured handler class and annotate it for registration.

        Each returned class gets ``subarea`` (the INI file stem), an
        ``authentication_method`` (only when the class does not already
        define one) and ``api_versions`` (every INI section it was listed in).

        :returns: the handler classes, without duplicates, in discovery order.
        """
        dirs = self.get_module_dirs()
        auth_instance = self.auth_method()
        modules = []
        for directory in dirs:
            wanted_modules = self.get_modules_from_dir(directory)
            for mod, versions in wanted_modules.items():
                for api_version, handler_names in versions.items():
                    classes = self.get_classes_from_dir(directory, handler_names)
                    for cls in classes:
                        # NOTE(review): a class listed under several sub-areas
                        # keeps only the last one seen here.
                        cls.subarea = mod
                        if getattr(cls, "authentication_method", "EMPTY") == "EMPTY":
                            cls.authentication_method = auth_instance
                        current = getattr(cls, "api_versions", None) or []
                        if api_version not in current:
                            # Rebind rather than append: an api_versions list
                            # inherited from a parent class must not be mutated.
                            cls.api_versions = list(current) + [api_version]
                        if cls not in modules:
                            modules.append(cls)
        return modules

    def create_endpoints(self, handler):
        """Register ``handler`` at ``/<subarea>/<api_version>/<endpoint>``.

        One URL is produced for every combination of ``handler.endpoints``
        and ``handler.api_versions``, and all of them are passed in a single
        ``add_resource`` call.
        """
        endpoint_list = []
        for endpoint in handler.endpoints:
            for api_version in handler.api_versions:
                self.logger.debug("Registering /%s/%s/%s for %s", handler.subarea, api_version, endpoint, handler.__name__)
                endpoint_list.append("/%s/%s/%s" % (handler.subarea, api_version, endpoint))
        self.api.add_resource(handler, *endpoint_list)

    def add_logger(self, handler):
        """Attach this loader's logger to ``handler`` as ``handler.logger``."""
        self.logger.info("Adding logger to: %s", handler.__name__)
        handler.logger = self.logger

    def init_handler(self, handler):
        """Prepare ``handler`` and register its endpoints on the API.

        NOTE(review): the steps run in a fixed sequence (logger, wrappers,
        endpoints, parser arguments); keep it unless the handler hooks are
        known to be order-independent.
        """
        self.add_logger(handler)
        handler.add_wrappers()
        self.create_endpoints(handler)
        handler.add_parser_arguments()

    def create_api_versionhandlers(self, handlers):
        """Register ``VersionHandler`` once per sub-area at ``/<subarea>/apis``.

        ``VersionHandler.versions`` is set to ``{subarea: [api_version, ...]}``,
        merged over all ``handlers``, and its ``method_decorators`` are cleared.
        """
        apiversions = {}
        endpoint_list = []
        for handler in handlers:
            subarea = handler.subarea
            if subarea in apiversions:
                known = apiversions[subarea]
                for hapiversion in handler.api_versions:
                    if hapiversion not in known:
                        known.append(hapiversion)
            else:
                # Copy: the merged list is extended above and must not alias
                # (and thereby mutate) this handler's own api_versions.
                apiversions[subarea] = list(handler.api_versions)
                self.logger.debug("Registering /%s/apis for %s", subarea, subarea)
                endpoint_list.append("/%s/apis" % subarea)
        VersionHandler.versions = apiversions
        VersionHandler.method_decorators = []
        self.api.add_resource(VersionHandler, *endpoint_list)