5834a3eab7e59eb6207e49abae5d56d3d175af30
[ta/config-manager.git] / cmdependencysort.py
1 # Copyright 2019 Nokia
2
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 from cmframework.apis.cmerror import CMError
16
17
18 class CMDependencySort(object):
19     def __init__(self, after=None, before=None):
20         if not after:
21             after = {}
22         if not before:
23             before = {}
24
25         self._entries = set()
26         self._before = before
27
28         self._convert_after_to_before(after)
29         self._find_all_entries()
30
31         self._sorted_entries = []
32
33     def _convert_after_to_before(self, after):
34         for entry, deps in after.iteritems():
35             self._entries.add(entry)
36             for dep in deps:
37                 dep_before = self._before.get(dep, None)
38                 if not dep_before:
39                     dep_before = []
40                     self._before[dep] = dep_before
41                 if entry not in dep_before:
42                     dep_before.append(entry)
43
44     def _find_all_entries(self):
45         for entry, deps in self._before.iteritems():
46             self._entries.add(entry)
47             for dep in deps:
48                 self._entries.add(dep)
49
50     def sort(self):
51         self._sort_entries()
52         return self._sorted_entries
53
54     def _sort_entries(self):
55         sorted_list = []
56         permanent_mark_list = []
57         for entry in self._entries:
58             if entry not in permanent_mark_list:
59                 self._visit(entry, sorted_list, permanent_mark_list)
60         self._sorted_entries = sorted_list
61
62     def _visit(self, entry, sorted_list, permanent_mark_list, temporary_mark_list=None):
63         if not temporary_mark_list:
64             temporary_mark_list = []
65
66         if entry in permanent_mark_list:
67             return
68
69         if entry in temporary_mark_list:
70             raise CMError('Cycle detected in dependencies ({})'.format(entry))
71
72         temporary_mark_list.append(entry)
73         for dep in self._before.get(entry, []):
74             self._visit(dep, sorted_list, permanent_mark_list, temporary_mark_list)
75         permanent_mark_list.append(entry)
76         sorted_list.insert(0, entry)