pcb defect detetcion application
[ealt-edge.git] / example-apps / PDD / pcb-defect-detection / data / lib_coco / PythonAPI / pycocotools / coco.py
1 __author__ = 'tylin'
2 __version__ = '2.0'
3 # Interface for accessing the Microsoft COCO dataset.
4
5 # Microsoft COCO is a large image dataset designed for object detection,
6 # segmentation, and caption generation. pycocotools is a Python API that
7 # assists in loading, parsing and visualizing the annotations in COCO.
8 # Please visit http://mscoco.org/ for more information on COCO, including
9 # for the data, paper, and tutorials. The exact format of the annotations
10 # is also described on the COCO website. For example usage of the pycocotools
11 # please see pycocotools_demo.ipynb. In addition to this API, please download both
12 # the COCO images and annotations in order to run the demo.
13
14 # An alternative to using the API is to load the annotations directly
15 # into Python dictionary
16 # Using the API provides additional utility functions. Note that this API
17 # supports both *instance* and *caption* annotations. In the case of
18 # captions not all functions are defined (e.g. categories are undefined).
19
20 # The following API functions are defined:
21 #  COCO       - COCO api class that loads COCO annotation file and prepare data structures.
22 #  decodeMask - Decode binary mask M encoded via run-length encoding.
23 #  encodeMask - Encode binary mask M using run-length encoding.
24 #  getAnnIds  - Get ann ids that satisfy given filter conditions.
25 #  getCatIds  - Get cat ids that satisfy given filter conditions.
26 #  getImgIds  - Get img ids that satisfy given filter conditions.
27 #  loadAnns   - Load anns with the specified ids.
28 #  loadCats   - Load cats with the specified ids.
29 #  loadImgs   - Load imgs with the specified ids.
30 #  annToMask  - Convert segmentation in an annotation to binary mask.
31 #  showAnns   - Display the specified annotations.
32 #  loadRes    - Load algorithm results and create API for accessing them.
33 #  download   - Download COCO images from mscoco.org server.
34 # Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
35 # Help on each functions can be accessed by: "help COCO>function".
36
37 # See also COCO>decodeMask,
38 # COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds,
39 # COCO>getImgIds, COCO>loadAnns, COCO>loadCats,
40 # COCO>loadImgs, COCO>annToMask, COCO>showAnns
41
42 # Microsoft COCO Toolbox.      version 2.0
43 # Data, paper, and tutorials available at:  http://mscoco.org/
44 # Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
45 # Licensed under the Simplified BSD License [see bsd.txt]
46
47 import json
48 import time
49 import matplotlib.pyplot as plt
50 from matplotlib.collections import PatchCollection
51 from matplotlib.patches import Polygon
52 import numpy as np
53 import copy
54 import itertools
55 from . import mask as maskUtils
56 import os
57 from collections import defaultdict
58 import sys
59 PYTHON_VERSION = sys.version_info[0]
60 if PYTHON_VERSION == 2:
61     from urllib import urlretrieve
62 elif PYTHON_VERSION == 3:
63     from urllib.request import urlretrieve
64
65
66 def _isArrayLike(obj):
67     return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
68
69
70 class COCO:
71     def __init__(self, annotation_file=None):
72         """
73         Constructor of Microsoft COCO helper class for reading and visualizing annotations.
74         :param annotation_file (str): location of annotation file
75         :param image_folder (str): location to the folder that hosts images.
76         :return:
77         """
78         # load dataset
79         self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
80         self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
81         if not annotation_file == None:
82             print('loading annotations into memory...')
83             tic = time.time()
84             dataset = json.load(open(annotation_file, 'r'))
85             assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
86             print('Done (t={:0.2f}s)'.format(time.time()- tic))
87             self.dataset = dataset
88             self.createIndex()
89
90     def createIndex(self):
91         # create index
92         print('creating index...')
93         anns, cats, imgs = {}, {}, {}
94         imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
95         if 'annotations' in self.dataset:
96             for ann in self.dataset['annotations']:
97                 imgToAnns[ann['image_id']].append(ann)
98                 anns[ann['id']] = ann
99
100         if 'images' in self.dataset:
101             for img in self.dataset['images']:
102                 imgs[img['id']] = img
103
104         if 'categories' in self.dataset:
105             for cat in self.dataset['categories']:
106                 cats[cat['id']] = cat
107
108         if 'annotations' in self.dataset and 'categories' in self.dataset:
109             for ann in self.dataset['annotations']:
110                 catToImgs[ann['category_id']].append(ann['image_id'])
111
112         print('index created!')
113
114         # create class members
115         self.anns = anns
116         self.imgToAnns = imgToAnns
117         self.catToImgs = catToImgs
118         self.imgs = imgs
119         self.cats = cats
120
121     def info(self):
122         """
123         Print information about the annotation file.
124         :return:
125         """
126         for key, value in self.dataset['info'].items():
127             print('{}: {}'.format(key, value))
128
129     def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
130         """
131         Get ann ids that satisfy given filter conditions. default skips that filter
132         :param imgIds  (int array)     : get anns for given imgs
133                catIds  (int array)     : get anns for given cats
134                areaRng (float array)   : get anns for given area range (e.g. [0 inf])
135                iscrowd (boolean)       : get anns for given crowd label (False or True)
136         :return: ids (int array)       : integer array of ann ids
137         """
138         imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
139         catIds = catIds if _isArrayLike(catIds) else [catIds]
140
141         if len(imgIds) == len(catIds) == len(areaRng) == 0:
142             anns = self.dataset['annotations']
143         else:
144             if not len(imgIds) == 0:
145                 lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
146                 anns = list(itertools.chain.from_iterable(lists))
147             else:
148                 anns = self.dataset['annotations']
149             anns = anns if len(catIds)  == 0 else [ann for ann in anns if ann['category_id'] in catIds]
150             anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
151         if not iscrowd == None:
152             ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
153         else:
154             ids = [ann['id'] for ann in anns]
155         return ids
156
157     def getCatIds(self, catNms=[], supNms=[], catIds=[]):
158         """
159         filtering parameters. default skips that filter.
160         :param catNms (str array)  : get cats for given cat names
161         :param supNms (str array)  : get cats for given supercategory names
162         :param catIds (int array)  : get cats for given cat ids
163         :return: ids (int array)   : integer array of cat ids
164         """
165         catNms = catNms if _isArrayLike(catNms) else [catNms]
166         supNms = supNms if _isArrayLike(supNms) else [supNms]
167         catIds = catIds if _isArrayLike(catIds) else [catIds]
168
169         if len(catNms) == len(supNms) == len(catIds) == 0:
170             cats = self.dataset['categories']
171         else:
172             cats = self.dataset['categories']
173             cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name']          in catNms]
174             cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
175             cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id']            in catIds]
176         ids = [cat['id'] for cat in cats]
177         return ids
178
179     def getImgIds(self, imgIds=[], catIds=[]):
180         '''
181         Get img ids that satisfy given filter conditions.
182         :param imgIds (int array) : get imgs for given ids
183         :param catIds (int array) : get imgs with all given cats
184         :return: ids (int array)  : integer array of img ids
185         '''
186         imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
187         catIds = catIds if _isArrayLike(catIds) else [catIds]
188
189         if len(imgIds) == len(catIds) == 0:
190             ids = self.imgs.keys()
191         else:
192             ids = set(imgIds)
193             for i, catId in enumerate(catIds):
194                 if i == 0 and len(ids) == 0:
195                     ids = set(self.catToImgs[catId])
196                 else:
197                     ids &= set(self.catToImgs[catId])
198         return list(ids)
199
200     def loadAnns(self, ids=[]):
201         """
202         Load anns with the specified ids.
203         :param ids (int array)       : integer ids specifying anns
204         :return: anns (object array) : loaded ann objects
205         """
206         if _isArrayLike(ids):
207             return [self.anns[id] for id in ids]
208         elif type(ids) == int:
209             return [self.anns[ids]]
210
211     def loadCats(self, ids=[]):
212         """
213         Load cats with the specified ids.
214         :param ids (int array)       : integer ids specifying cats
215         :return: cats (object array) : loaded cat objects
216         """
217         if _isArrayLike(ids):
218             return [self.cats[id] for id in ids]
219         elif type(ids) == int:
220             return [self.cats[ids]]
221
222     def loadImgs(self, ids=[]):
223         """
224         Load anns with the specified ids.
225         :param ids (int array)       : integer ids specifying img
226         :return: imgs (object array) : loaded img objects
227         """
228         if _isArrayLike(ids):
229             return [self.imgs[id] for id in ids]
230         elif type(ids) == int:
231             return [self.imgs[ids]]
232
233     def showAnns(self, anns):
234         """
235         Display the specified annotations.
236         :param anns (array of object): annotations to display
237         :return: None
238         """
239         if len(anns) == 0:
240             return 0
241         if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
242             datasetType = 'instances'
243         elif 'caption' in anns[0]:
244             datasetType = 'captions'
245         else:
246             raise Exception('datasetType not supported')
247         if datasetType == 'instances':
248             ax = plt.gca()
249             ax.set_autoscale_on(False)
250             polygons = []
251             color = []
252             for ann in anns:
253                 c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
254                 if 'segmentation' in ann:
255                     if type(ann['segmentation']) == list:
256                         # polygon
257                         for seg in ann['segmentation']:
258                             poly = np.array(seg).reshape((int(len(seg)/2), 2))
259                             polygons.append(Polygon(poly))
260                             color.append(c)
261                     else:
262                         # mask
263                         t = self.imgs[ann['image_id']]
264                         if type(ann['segmentation']['counts']) == list:
265                             rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
266                         else:
267                             rle = [ann['segmentation']]
268                         m = maskUtils.decode(rle)
269                         img = np.ones( (m.shape[0], m.shape[1], 3) )
270                         if ann['iscrowd'] == 1:
271                             color_mask = np.array([2.0,166.0,101.0])/255
272                         if ann['iscrowd'] == 0:
273                             color_mask = np.random.random((1, 3)).tolist()[0]
274                         for i in range(3):
275                             img[:,:,i] = color_mask[i]
276                         ax.imshow(np.dstack( (img, m*0.5) ))
277                 if 'keypoints' in ann and type(ann['keypoints']) == list:
278                     # turn skeleton into zero-based index
279                     sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1
280                     kp = np.array(ann['keypoints'])
281                     x = kp[0::3]
282                     y = kp[1::3]
283                     v = kp[2::3]
284                     for sk in sks:
285                         if np.all(v[sk]>0):
286                             plt.plot(x[sk],y[sk], linewidth=3, color=c)
287                     plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
288                     plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
289             p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
290             ax.add_collection(p)
291             p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
292             ax.add_collection(p)
293         elif datasetType == 'captions':
294             for ann in anns:
295                 print(ann['caption'])
296
297     def loadRes(self, resFile):
298         """
299         Load result file and return a result api object.
300         :param   resFile (str)     : file name of result file
301         :return: res (obj)         : result api object
302         """
303         res = COCO()
304         res.dataset['images'] = [img for img in self.dataset['images']]
305
306         print('Loading and preparing results...')
307         tic = time.time()
308         if type(resFile) == str or type(resFile) == unicode:
309             anns = json.load(open(resFile))
310         elif type(resFile) == np.ndarray:
311             anns = self.loadNumpyAnnotations(resFile)
312         else:
313             anns = resFile
314         assert type(anns) == list, 'results in not an array of objects'
315         annsImgIds = [ann['image_id'] for ann in anns]
316         assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
317                'Results do not correspond to current coco set'
318         if 'caption' in anns[0]:
319             imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
320             res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
321             for id, ann in enumerate(anns):
322                 ann['id'] = id+1
323         elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
324             res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
325             for id, ann in enumerate(anns):
326                 bb = ann['bbox']
327                 x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]]
328                 if not 'segmentation' in ann:
329                     ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
330                 ann['area'] = bb[2]*bb[3]
331                 ann['id'] = id+1
332                 ann['iscrowd'] = 0
333         elif 'segmentation' in anns[0]:
334             res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
335             for id, ann in enumerate(anns):
336                 # now only support compressed RLE format as segmentation results
337                 ann['area'] = maskUtils.area(ann['segmentation'])
338                 if not 'bbox' in ann:
339                     ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
340                 ann['id'] = id+1
341                 ann['iscrowd'] = 0
342         elif 'keypoints' in anns[0]:
343             res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
344             for id, ann in enumerate(anns):
345                 s = ann['keypoints']
346                 x = s[0::3]
347                 y = s[1::3]
348                 x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y)
349                 ann['area'] = (x1-x0)*(y1-y0)
350                 ann['id'] = id + 1
351                 ann['bbox'] = [x0,y0,x1-x0,y1-y0]
352         print('DONE (t={:0.2f}s)'.format(time.time()- tic))
353
354         res.dataset['annotations'] = anns
355         res.createIndex()
356         return res
357
358     def download(self, tarDir = None, imgIds = [] ):
359         '''
360         Download COCO images from mscoco.org server.
361         :param tarDir (str): COCO results directory name
362                imgIds (list): images to be downloaded
363         :return:
364         '''
365         if tarDir is None:
366             print('Please specify target directory')
367             return -1
368         if len(imgIds) == 0:
369             imgs = self.imgs.values()
370         else:
371             imgs = self.loadImgs(imgIds)
372         N = len(imgs)
373         if not os.path.exists(tarDir):
374             os.makedirs(tarDir)
375         for i, img in enumerate(imgs):
376             tic = time.time()
377             fname = os.path.join(tarDir, img['file_name'])
378             if not os.path.exists(fname):
379                 urlretrieve(img['coco_url'], fname)
380             print('downloaded {}/{} images (t={:0.1f}s)'.format(i, N, time.time()- tic))
381
382     def loadNumpyAnnotations(self, data):
383         """
384         Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class}
385         :param  data (numpy.ndarray)
386         :return: annotations (python nested list)
387         """
388         print('Converting ndarray to lists...')
389         assert(type(data) == np.ndarray)
390         print(data.shape)
391         assert(data.shape[1] == 7)
392         N = data.shape[0]
393         ann = []
394         for i in range(N):
395             if i % 1000000 == 0:
396                 print('{}/{}'.format(i,N))
397             ann += [{
398                 'image_id'  : int(data[i, 0]),
399                 'bbox'  : [ data[i, 1], data[i, 2], data[i, 3], data[i, 4] ],
400                 'score' : data[i, 5],
401                 'category_id': int(data[i, 6]),
402                 }]
403         return ann
404
405     def annToRLE(self, ann):
406         """
407         Convert annotation which can be polygons, uncompressed RLE to RLE.
408         :return: binary mask (numpy 2D array)
409         """
410         t = self.imgs[ann['image_id']]
411         h, w = t['height'], t['width']
412         segm = ann['segmentation']
413         if type(segm) == list:
414             # polygon -- a single object might consist of multiple parts
415             # we merge all parts into one mask rle code
416             rles = maskUtils.frPyObjects(segm, h, w)
417             rle = maskUtils.merge(rles)
418         elif type(segm['counts']) == list:
419             # uncompressed RLE
420             rle = maskUtils.frPyObjects(segm, h, w)
421         else:
422             # rle
423             rle = ann['segmentation']
424         return rle
425
426     def annToMask(self, ann):
427         """
428         Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.
429         :return: binary mask (numpy 2D array)
430         """
431         rle = self.annToRLE(ann)
432         m = maskUtils.decode(rle)
433         return m