coco_label

import json
import os
import numpy as np

# Print numpy floats in fixed-point notation rather than scientific notation
# (this affects any str()/print of arrays in this script).
np.set_printoptions(suppress=True)
# Directory containing the COCO ``instances_<split><year>.json`` files.
json_dir = '/home/wxrui/DATA/coco/coco/annotations'
# Output root; one sub-directory named <split><year> is created per pass.
out_dir = '/home/wxrui/DATA/coco/test'
# Order matters: the first pass (train + 2014) is the one that builds the
# label map that every later pass relies on.
type_list = ['train', 'val']
year_list = ['2014', '2017']
# List of COCO category dicts, each tagged with an 'index' key; populated by
# create_label_map(). A list (not a dict) because that is what it becomes.
label_map = []


def create_label_map(data_dic, output_dir=None):
    """Build the global category list and save it as ``label_map.json``.

    Every category dict in ``data_dic['categories']`` is tagged, in place,
    with an ``'index'`` key equal to its position in that list. The list is
    then stored in the module-level ``label_map`` and written to
    ``label_map.json`` with ``indent=4``.

    Args:
        data_dic: parsed COCO instances JSON; only ``'categories'`` is read.
        output_dir: directory that receives ``label_map.json``. Defaults to
            the module-level ``out_dir``. Created if it does not exist.
    """
    global label_map
    label_map = data_dic['categories']
    for index, item in enumerate(label_map):
        item['index'] = index
    target_dir = out_dir if output_dir is None else output_dir
    os.makedirs(target_dir, exist_ok=True)
    # `with` guarantees the JSON is flushed and the handle closed; previously
    # the file was left open until garbage collection.
    with open(os.path.join(target_dir, 'label_map.json'), 'w') as label_map_file:
        json.dump(label_map, label_map_file, indent=4)


def id_2_index(id_list, categories=None):
    """Map COCO category ids to contiguous class indices.

    Args:
        id_list: iterable of category ids. Each entry is converted with
            ``int()``, so float-valued ids are accepted.
        categories: optional list of category dicts carrying ``'id'`` and
            ``'index'`` keys. Defaults to the module-level ``label_map``
            populated by ``create_label_map``.

    Returns:
        ``np.array`` with one class index per entry of ``id_list``.
    """
    if categories is None:
        categories = label_map
    # A dict accepts any id value; the former fixed 100-slot list raised
    # IndexError for ids >= 100 and was rebuilt on every call anyway.
    id_to_index = {item['id']: item['index'] for item in categories}
    # NOTE(review): unknown ids fall back to 0 (kept for backward
    # compatibility), but 0 is a real class index; consider raising instead.
    return np.array([id_to_index.get(int(cat_id), 0) for cat_id in id_list])


def format_label(label, img_width, img_height):
    """Turn raw annotation rows into class-indexed, normalized centre boxes.

    Args:
        label: 2-D array whose column 0 is the category id and whose
            columns 1-4 are treated as (left, top, width, height).
        img_width: image width used to normalize x and width.
        img_height: image height used to normalize y and height.

    Returns:
        A new float array with rows ``[class_index, cx, cy, w, h]``, where the
        box centre is ``left + width / 2`` / ``top + height / 2`` and every
        geometric column is divided by the matching image dimension. The
        input array is not modified.
    """
    left, top, box_w, box_h = (label[:, col] for col in range(1, 5))
    class_idx = id_2_index(label[:, 0])
    # Column 0 (class index) is divided by 1, i.e. left unchanged.
    scale = np.array([[1, img_width, img_height, img_width, img_height]])
    rows = np.column_stack(
        [class_idx, left + box_w / 2, top + box_h / 2, box_w, box_h]
    )
    rows /= scale
    return rows


def _group_annotations(annotations):
    """Group annotations by image: str(image_id) -> list of [category_id, *bbox]."""
    grouped = {}
    for anno in annotations:
        grouped.setdefault(str(anno['image_id']), []).append(
            [anno['category_id']] + list(anno['bbox']))
    return grouped


def _write_label_file(file_name, label):
    """Write one row per line: integer class index, then four box values."""
    # Explicit formatting replaces str(ndarray), whose output depended on the
    # global numpy print options and rendered the class index as e.g. "0.".
    with open(file_name, 'w') as out_file:
        for row in label:
            box = ' '.join('%.8f' % value for value in row[1:])
            out_file.write('%d %s\n' % (int(row[0]), box))


# Convert every <split><year> annotation file into one .txt label file per
# annotated image under out_dir/<split><year>/.
for split in type_list:
    for year in year_list:
        path = os.path.join(out_dir, split + year)
        # makedirs also creates out_dir itself when it does not exist yet.
        os.makedirs(path, exist_ok=True)
        # read data
        json_file = os.path.join(json_dir, 'instances_%s%s.json' % (split, year))
        with open(json_file) as json_fh:
            data_dic = json.load(json_fh)
        if split == 'train' and year == '2014':
            # The label map is built once, from train2014 only; later passes
            # reuse it via id_2_index.
            print('creating labelmap')
            create_label_map(data_dic)
            print('created labelmap')
        anno_dict = _group_annotations(data_dic['annotations'])
        for image in data_dic['images']:
            image_name = image['file_name']
            rows = anno_dict.get(str(image['id']))
            if rows is None:
                # Images without annotations get no label file.
                continue
            label = format_label(label=np.array(rows),
                                 img_width=image['width'],
                                 img_height=image['height'])
            # splitext swaps only the extension; the old .replace('jpg', 'txt')
            # touched the whole path and ignored non-jpg names.
            file_name = os.path.join(path, os.path.splitext(image_name)[0] + '.txt')
            _write_label_file(file_name, label)
            print(image_name + '  finished!')

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章