--- /dev/null
+__author__ = 'tylin'
+__version__ = '2.0'
+# Interface for accessing the Microsoft COCO dataset.
+
+# Microsoft COCO is a large image dataset designed for object detection,
+# segmentation, and caption generation. pycocotools is a Python API that
+# assists in loading, parsing and visualizing the annotations in COCO.
+# Please visit http://mscoco.org/ for more information on COCO, including
+# for the data, paper, and tutorials. The exact format of the annotations
+# is also described on the COCO website. For example usage of the pycocotools
+# please see pycocotools_demo.ipynb. In addition to this API, please download both
+# the COCO images and annotations in order to run the demo.
+
+# An alternative to using the API is to load the annotations directly
+# into Python dictionary
+# Using the API provides additional utility functions. Note that this API
+# supports both *instance* and *caption* annotations. In the case of
+# captions not all functions are defined (e.g. categories are undefined).
+
+# The following API functions are defined:
+# COCO - COCO api class that loads COCO annotation file and prepare data structures.
+# decodeMask - Decode binary mask M encoded via run-length encoding.
+# encodeMask - Encode binary mask M using run-length encoding.
+# getAnnIds - Get ann ids that satisfy given filter conditions.
+# getCatIds - Get cat ids that satisfy given filter conditions.
+# getImgIds - Get img ids that satisfy given filter conditions.
+# loadAnns - Load anns with the specified ids.
+# loadCats - Load cats with the specified ids.
+# loadImgs - Load imgs with the specified ids.
+# annToMask - Convert segmentation in an annotation to binary mask.
+# showAnns - Display the specified annotations.
+# loadRes - Load algorithm results and create API for accessing them.
+# download - Download COCO images from mscoco.org server.
+# Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
+# Help on each functions can be accessed by: "help COCO>function".
+
+# See also COCO>decodeMask,
+# COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds,
+# COCO>getImgIds, COCO>loadAnns, COCO>loadCats,
+# COCO>loadImgs, COCO>annToMask, COCO>showAnns
+
+# Microsoft COCO Toolbox. version 2.0
+# Data, paper, and tutorials available at: http://mscoco.org/
+# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
+# Licensed under the Simplified BSD License [see bsd.txt]
+
+import json
+import time
+import matplotlib.pyplot as plt
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon
+import numpy as np
+import copy
+import itertools
+from . import mask as maskUtils
+import os
+from collections import defaultdict
+import sys
+PYTHON_VERSION = sys.version_info[0]
+if PYTHON_VERSION == 2:
+ from urllib import urlretrieve
+elif PYTHON_VERSION == 3:
+ from urllib.request import urlretrieve
+
+
+def _isArrayLike(obj):
+ return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
+
+
+class COCO:
+ def __init__(self, annotation_file=None):
+ """
+ Constructor of Microsoft COCO helper class for reading and visualizing annotations.
+ :param annotation_file (str): location of annotation file
+ :param image_folder (str): location to the folder that hosts images.
+ :return:
+ """
+ # load dataset
+ self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
+ self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
+ if not annotation_file == None:
+ print('loading annotations into memory...')
+ tic = time.time()
+ dataset = json.load(open(annotation_file, 'r'))
+ assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
+ print('Done (t={:0.2f}s)'.format(time.time()- tic))
+ self.dataset = dataset
+ self.createIndex()
+
+ def createIndex(self):
+ # create index
+ print('creating index...')
+ anns, cats, imgs = {}, {}, {}
+ imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
+ if 'annotations' in self.dataset:
+ for ann in self.dataset['annotations']:
+ imgToAnns[ann['image_id']].append(ann)
+ anns[ann['id']] = ann
+
+ if 'images' in self.dataset:
+ for img in self.dataset['images']:
+ imgs[img['id']] = img
+
+ if 'categories' in self.dataset:
+ for cat in self.dataset['categories']:
+ cats[cat['id']] = cat
+
+ if 'annotations' in self.dataset and 'categories' in self.dataset:
+ for ann in self.dataset['annotations']:
+ catToImgs[ann['category_id']].append(ann['image_id'])
+
+ print('index created!')
+
+ # create class members
+ self.anns = anns
+ self.imgToAnns = imgToAnns
+ self.catToImgs = catToImgs
+ self.imgs = imgs
+ self.cats = cats
+
+ def info(self):
+ """
+ Print information about the annotation file.
+ :return:
+ """
+ for key, value in self.dataset['info'].items():
+ print('{}: {}'.format(key, value))
+
+ def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
+ """
+ Get ann ids that satisfy given filter conditions. default skips that filter
+ :param imgIds (int array) : get anns for given imgs
+ catIds (int array) : get anns for given cats
+ areaRng (float array) : get anns for given area range (e.g. [0 inf])
+ iscrowd (boolean) : get anns for given crowd label (False or True)
+ :return: ids (int array) : integer array of ann ids
+ """
+ imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
+ catIds = catIds if _isArrayLike(catIds) else [catIds]
+
+ if len(imgIds) == len(catIds) == len(areaRng) == 0:
+ anns = self.dataset['annotations']
+ else:
+ if not len(imgIds) == 0:
+ lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
+ anns = list(itertools.chain.from_iterable(lists))
+ else:
+ anns = self.dataset['annotations']
+ anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds]
+ anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
+ if not iscrowd == None:
+ ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
+ else:
+ ids = [ann['id'] for ann in anns]
+ return ids
+
+ def getCatIds(self, catNms=[], supNms=[], catIds=[]):
+ """
+ filtering parameters. default skips that filter.
+ :param catNms (str array) : get cats for given cat names
+ :param supNms (str array) : get cats for given supercategory names
+ :param catIds (int array) : get cats for given cat ids
+ :return: ids (int array) : integer array of cat ids
+ """
+ catNms = catNms if _isArrayLike(catNms) else [catNms]
+ supNms = supNms if _isArrayLike(supNms) else [supNms]
+ catIds = catIds if _isArrayLike(catIds) else [catIds]
+
+ if len(catNms) == len(supNms) == len(catIds) == 0:
+ cats = self.dataset['categories']
+ else:
+ cats = self.dataset['categories']
+ cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms]
+ cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
+ cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds]
+ ids = [cat['id'] for cat in cats]
+ return ids
+
+ def getImgIds(self, imgIds=[], catIds=[]):
+ '''
+ Get img ids that satisfy given filter conditions.
+ :param imgIds (int array) : get imgs for given ids
+ :param catIds (int array) : get imgs with all given cats
+ :return: ids (int array) : integer array of img ids
+ '''
+ imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
+ catIds = catIds if _isArrayLike(catIds) else [catIds]
+
+ if len(imgIds) == len(catIds) == 0:
+ ids = self.imgs.keys()
+ else:
+ ids = set(imgIds)
+ for i, catId in enumerate(catIds):
+ if i == 0 and len(ids) == 0:
+ ids = set(self.catToImgs[catId])
+ else:
+ ids &= set(self.catToImgs[catId])
+ return list(ids)
+
+ def loadAnns(self, ids=[]):
+ """
+ Load anns with the specified ids.
+ :param ids (int array) : integer ids specifying anns
+ :return: anns (object array) : loaded ann objects
+ """
+ if _isArrayLike(ids):
+ return [self.anns[id] for id in ids]
+ elif type(ids) == int:
+ return [self.anns[ids]]
+
+ def loadCats(self, ids=[]):
+ """
+ Load cats with the specified ids.
+ :param ids (int array) : integer ids specifying cats
+ :return: cats (object array) : loaded cat objects
+ """
+ if _isArrayLike(ids):
+ return [self.cats[id] for id in ids]
+ elif type(ids) == int:
+ return [self.cats[ids]]
+
+ def loadImgs(self, ids=[]):
+ """
+ Load anns with the specified ids.
+ :param ids (int array) : integer ids specifying img
+ :return: imgs (object array) : loaded img objects
+ """
+ if _isArrayLike(ids):
+ return [self.imgs[id] for id in ids]
+ elif type(ids) == int:
+ return [self.imgs[ids]]
+
+ def showAnns(self, anns):
+ """
+ Display the specified annotations.
+ :param anns (array of object): annotations to display
+ :return: None
+ """
+ if len(anns) == 0:
+ return 0
+ if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
+ datasetType = 'instances'
+ elif 'caption' in anns[0]:
+ datasetType = 'captions'
+ else:
+ raise Exception('datasetType not supported')
+ if datasetType == 'instances':
+ ax = plt.gca()
+ ax.set_autoscale_on(False)
+ polygons = []
+ color = []
+ for ann in anns:
+ c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
+ if 'segmentation' in ann:
+ if type(ann['segmentation']) == list:
+ # polygon
+ for seg in ann['segmentation']:
+ poly = np.array(seg).reshape((int(len(seg)/2), 2))
+ polygons.append(Polygon(poly))
+ color.append(c)
+ else:
+ # mask
+ t = self.imgs[ann['image_id']]
+ if type(ann['segmentation']['counts']) == list:
+ rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
+ else:
+ rle = [ann['segmentation']]
+ m = maskUtils.decode(rle)
+ img = np.ones( (m.shape[0], m.shape[1], 3) )
+ if ann['iscrowd'] == 1:
+ color_mask = np.array([2.0,166.0,101.0])/255
+ if ann['iscrowd'] == 0:
+ color_mask = np.random.random((1, 3)).tolist()[0]
+ for i in range(3):
+ img[:,:,i] = color_mask[i]
+ ax.imshow(np.dstack( (img, m*0.5) ))
+ if 'keypoints' in ann and type(ann['keypoints']) == list:
+ # turn skeleton into zero-based index
+ sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1
+ kp = np.array(ann['keypoints'])
+ x = kp[0::3]
+ y = kp[1::3]
+ v = kp[2::3]
+ for sk in sks:
+ if np.all(v[sk]>0):
+ plt.plot(x[sk],y[sk], linewidth=3, color=c)
+ plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
+ plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
+ p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
+ ax.add_collection(p)
+ p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
+ ax.add_collection(p)
+ elif datasetType == 'captions':
+ for ann in anns:
+ print(ann['caption'])
+
+ def loadRes(self, resFile):
+ """
+ Load result file and return a result api object.
+ :param resFile (str) : file name of result file
+ :return: res (obj) : result api object
+ """
+ res = COCO()
+ res.dataset['images'] = [img for img in self.dataset['images']]
+
+ print('Loading and preparing results...')
+ tic = time.time()
+ if type(resFile) == str or type(resFile) == unicode:
+ anns = json.load(open(resFile))
+ elif type(resFile) == np.ndarray:
+ anns = self.loadNumpyAnnotations(resFile)
+ else:
+ anns = resFile
+ assert type(anns) == list, 'results in not an array of objects'
+ annsImgIds = [ann['image_id'] for ann in anns]
+ assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
+ 'Results do not correspond to current coco set'
+ if 'caption' in anns[0]:
+ imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
+ res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
+ for id, ann in enumerate(anns):
+ ann['id'] = id+1
+ elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
+ res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
+ for id, ann in enumerate(anns):
+ bb = ann['bbox']
+ x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]]
+ if not 'segmentation' in ann:
+ ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
+ ann['area'] = bb[2]*bb[3]
+ ann['id'] = id+1
+ ann['iscrowd'] = 0
+ elif 'segmentation' in anns[0]:
+ res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
+ for id, ann in enumerate(anns):
+ # now only support compressed RLE format as segmentation results
+ ann['area'] = maskUtils.area(ann['segmentation'])
+ if not 'bbox' in ann:
+ ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
+ ann['id'] = id+1
+ ann['iscrowd'] = 0
+ elif 'keypoints' in anns[0]:
+ res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
+ for id, ann in enumerate(anns):
+ s = ann['keypoints']
+ x = s[0::3]
+ y = s[1::3]
+ x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y)
+ ann['area'] = (x1-x0)*(y1-y0)
+ ann['id'] = id + 1
+ ann['bbox'] = [x0,y0,x1-x0,y1-y0]
+ print('DONE (t={:0.2f}s)'.format(time.time()- tic))
+
+ res.dataset['annotations'] = anns
+ res.createIndex()
+ return res
+
+ def download(self, tarDir = None, imgIds = [] ):
+ '''
+ Download COCO images from mscoco.org server.
+ :param tarDir (str): COCO results directory name
+ imgIds (list): images to be downloaded
+ :return:
+ '''
+ if tarDir is None:
+ print('Please specify target directory')
+ return -1
+ if len(imgIds) == 0:
+ imgs = self.imgs.values()
+ else:
+ imgs = self.loadImgs(imgIds)
+ N = len(imgs)
+ if not os.path.exists(tarDir):
+ os.makedirs(tarDir)
+ for i, img in enumerate(imgs):
+ tic = time.time()
+ fname = os.path.join(tarDir, img['file_name'])
+ if not os.path.exists(fname):
+ urlretrieve(img['coco_url'], fname)
+ print('downloaded {}/{} images (t={:0.1f}s)'.format(i, N, time.time()- tic))
+
+ def loadNumpyAnnotations(self, data):
+ """
+ Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class}
+ :param data (numpy.ndarray)
+ :return: annotations (python nested list)
+ """
+ print('Converting ndarray to lists...')
+ assert(type(data) == np.ndarray)
+ print(data.shape)
+ assert(data.shape[1] == 7)
+ N = data.shape[0]
+ ann = []
+ for i in range(N):
+ if i % 1000000 == 0:
+ print('{}/{}'.format(i,N))
+ ann += [{
+ 'image_id' : int(data[i, 0]),
+ 'bbox' : [ data[i, 1], data[i, 2], data[i, 3], data[i, 4] ],
+ 'score' : data[i, 5],
+ 'category_id': int(data[i, 6]),
+ }]
+ return ann
+
+ def annToRLE(self, ann):
+ """
+ Convert annotation which can be polygons, uncompressed RLE to RLE.
+ :return: binary mask (numpy 2D array)
+ """
+ t = self.imgs[ann['image_id']]
+ h, w = t['height'], t['width']
+ segm = ann['segmentation']
+ if type(segm) == list:
+ # polygon -- a single object might consist of multiple parts
+ # we merge all parts into one mask rle code
+ rles = maskUtils.frPyObjects(segm, h, w)
+ rle = maskUtils.merge(rles)
+ elif type(segm['counts']) == list:
+ # uncompressed RLE
+ rle = maskUtils.frPyObjects(segm, h, w)
+ else:
+ # rle
+ rle = ann['segmentation']
+ return rle
+
+ def annToMask(self, ann):
+ """
+ Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.
+ :return: binary mask (numpy 2D array)
+ """
+ rle = self.annToRLE(ann)
+ m = maskUtils.decode(rle)
+ return m
\ No newline at end of file