Seed code for yarf
[ta/yarf.git] / src / yarf / handlers / pluginhandler.py
1 # Copyright 2019 Nokia
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 #
15
16 import os
17 import os.path
18 import sys
19 import inspect
20 from yarf.iniloader import INILoader
21 from yarf.restresource import RestResource
22 from yarf.versionhandler import VersionHandler
23 from yarf.authentication.base_auth import BaseAuthMethod
24 from yarf.exceptions import ConfigError
25 import yarf.restfullogger as restlog
26
27 class PluginLoader(object):
28     def __init__(self, path, api, auth_method):
29         self.logger = restlog.get_logger()
30         self.plugin_class_type = RestResource
31         self.auth_method = self._get_auth_method(auth_method)
32         self.path = path
33         self.api = api
34
35     def get_module_dirs(self):
36         files = os.listdir(self.path)
37         modules = []
38         for f in files:
39             if os.path.isdir("%s/%s"%(self.path, f)):
40                 modules.append("%s/%s"%(self.path, f))
41         return modules
42
43     def _get_auth_method(self, authmethod):
44         auth_class_module = None
45         class_name = None
46         try:
47             auth_class_module, class_name = authmethod.rsplit('.', 1)
48         except ValueError:
49             error = "Cannot decode the authentication method from configuration file"
50             self.logger.error(error)
51             raise ConfigError(error)
52         auth_classes = self._get_classes_wanted_classes(auth_class_module, [class_name], BaseAuthMethod)
53         if auth_classes is None or auth_classes == []:
54             error = "Cannot find the authentication class in provided module %s %s" % (auth_class_module, class_name)
55             raise ConfigError(error)
56         return auth_classes[0]
57
58     def _get_classes_wanted_classes(self, module_name, wanted_modules, class_type):
59         classes = []
60         try:
61             __import__(module_name)
62         except ImportError:
63             self.logger.error("Failed import in %s, skipping", module_name)
64             return None
65         module = sys.modules[module_name]
66         for obj_name in dir(module):
67             # Skip objects that are meant to be private.
68             if obj_name.startswith('_'):
69                 continue
70             # Skip the same name that base class has
71             elif obj_name == class_type.__name__:
72                 continue
73             elif obj_name not in wanted_modules:
74                 continue
75             itm = getattr(module, obj_name)
76             if inspect.isclass(itm) and issubclass(itm, class_type):
77                 classes.append(itm)
78         return classes
79
80     def get_classes_from_dir(self, directory, wanted_modules):
81         classes = []
82         if directory not in sys.path:
83             sys.path.append(directory)
84         for fname in os.listdir(directory):
85             root, ext = os.path.splitext(fname)
86             if ext != '.py' or root == '__init__':
87                 continue
88             module_name = "%s" % (root)
89
90             mod_classes = self._get_classes_wanted_classes(module_name, wanted_modules, self.plugin_class_type)
91             if mod_classes:
92                 classes.extend(mod_classes)
93         return classes
94
95     def get_modules_from_dir(self, module_dir):
96         modules = {}
97         for f in os.listdir(module_dir):
98             if not f.endswith(".ini"):
99                 continue
100             root, _ = os.path.splitext(f)
101             loader = INILoader("%s/%s" %(module_dir, f))
102             sections = loader.get_sections()
103             modules[root] = {}
104             for section in sections:
105                 handlers = loader.get_handlers(section)
106                 if handlers:
107                     modules[root][section] = handlers
108                 else:
109                     self.logger.error("Problem in the configuration file %s in section %s: No handlers found", f, section)
110         return modules
111
112     def get_auth_method(self):
113         return self.auth_method()
114
115     def get_modules(self):
116         dirs = self.get_module_dirs()
117         auth_class = self.auth_method()
118         modules = []
119         for d in dirs:
120             wanted_modules = self.get_modules_from_dir(d)
121             for mod in wanted_modules.keys():
122                 for api_version in wanted_modules[mod].keys():
123                     classes = self.get_classes_from_dir(d, wanted_modules[mod][api_version])
124                     if not classes:
125                         continue
126                     for c in classes:
127                         setattr(c, "subarea", mod)
128                         if getattr(c, "authentication_method", "EMPTY") == "EMPTY":
129                             setattr(c, "authentication_method", auth_class)
130                         if getattr(c, "api_versions", None):
131                             c.api_versions.append(api_version)
132                         else:
133                             setattr(c, "api_versions", [api_version])
134                     for cls in classes:
135                         if cls not in modules:
136                             modules.append(cls)
137         return modules
138
139     def create_endpoints(self, handler):
140         endpoint_list = []
141         for endpoint in handler.endpoints:
142             for api_version in handler.api_versions:
143                 self.logger.debug("Registering /%s/%s/%s for %s", handler.subarea, api_version, endpoint, handler.__name__)
144                 endpoint_list.append("/%s/%s/%s"% (handler.subarea, api_version, endpoint))
145         self.api.add_resource(handler, *(endpoint_list))
146
147     def add_logger(self, handler):
148         self.logger.info("Adding logger to: %s", handler.__name__)
149         handler.logger = self.logger
150
151     def init_handler(self, handler):
152
153         self.add_logger(handler)
154         handler.add_wrappers()
155         self.create_endpoints(handler)
156         handler.add_parser_arguments()
157
158     def create_api_versionhandlers(self, handlers):
159         apiversions = {}
160         endpoint_list = []
161         for handler in handlers:
162             subarea = handler.subarea
163             if apiversions.get(subarea, False):
164                 for hapiversion in handler.api_versions:
165                     if hapiversion not in apiversions[subarea]:
166                         apiversions[subarea].append(hapiversion)
167             else:
168                 apiversions[subarea] = handler.api_versions
169                 self.logger.debug("Registering /%s/apis for %s", subarea, subarea)
170                 endpoint_list.append("/%s/apis" % subarea)
171         setattr(VersionHandler, "versions", apiversions)
172         setattr(VersionHandler, "method_decorators", [])
173         self.api.add_resource(VersionHandler, *(endpoint_list))