Visual Genome R-CNN feature extraction (Part 2): Extraction

At this point Caffe is configured and the dependencies cython, opencv, pyyaml and easydict are installed.

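First, a note on an easydict error (to stress it again: this library must be installed in an older version!). The following error went away after reinstalling easydict with pip install easydict: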

    File "./tools/generate_tsv.py", line 221, in <module>
        assert cfg.TEST.HAS_RPN
    AssertionError: cfg.TEST.HAS_RPN == False
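A plausible reading of why a library version shows up as this assertion: the yml passed via --cfg is meant to override the defaults defined in code, and if that override never reaches cfg.TEST, the default value False survives. The sketch below is a simplified, hypothetical re-implementation of such a recursive override using plain dicts; it is not the repository's code, and `DEFAULTS`, `YML_OVERRIDES`, `merge_into` and the values in them are assumptions chosen for illustration.

```python
import copy

# Hypothetical defaults in the spirit of a Faster R-CNN config module:
# RPN proposals are off unless the experiment yml turns them on.
DEFAULTS = {"TEST": {"HAS_RPN": False, "NMS": 0.3}}

# Hypothetical result of parsing the --cfg yml file.
YML_OVERRIDES = {"TEST": {"HAS_RPN": True}}


def merge_into(src, dst):
    """Recursively copy values from src into dst; unknown keys are an error."""
    for key, value in src.items():
        if key not in dst:
            raise KeyError("unknown config key: %s" % key)
        if isinstance(value, dict):
            merge_into(value, dst[key])  # descend into sections such as TEST
        else:
            dst[key] = value


def load_cfg(apply_merge=True):
    """Build a config from the defaults, optionally applying the yml overrides."""
    cfg = copy.deepcopy(DEFAULTS)
    if apply_merge:
        merge_into(YML_OVERRIDES, cfg)
    return cfg


if __name__ == "__main__":
    print(load_cfg(True)["TEST"]["HAS_RPN"])   # True: overrides applied
    print(load_cfg(False)["TEST"]["HAS_RPN"])  # False: the value the assert sees
```

If the merge step is silently skipped, everything else still runs and only the assertion reveals the problem, which matches the symptom above.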

Now to the main topic: extracting the Genome features.

  1. Download the Visual Genome images. They come in two parts, VG_100K and VG_100K_2, whose image_ids do not overlap; put all of them into the bottom-up-attention/data/VGdata directory (data: http://visualgenome.org/).
     
  2. Make a small change to tools/generate_tsv.py: essentially only the directories in lines 58-65 need to change. My version is:
    # excerpt; `split` is the list of (filepath, image_id) pairs defined earlier in the same function
    with open('./data/visualgenome/image_data.json') as f:
        for item in json.load(f):
            image_id = int(item['image_id'])
            # filepath = os.path.join('./data/VGdata/', item['url'].split('rak248/')[-1])
            filepath = os.path.join('./data/VGdata/', str(image_id) + '.jpg')  # the author's original line above also works here
            # print(filepath, os.path.exists(filepath))
            split.append((filepath, image_id))

     

  3. Then run the script as in the author's example. I used the following arguments with the author's pretrained model:
    python ./tools/generate_tsv.py --gpu 0 --cfg experiments/cfgs/faster_rcnn_end2end_resnet.yml --def ./models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt --out /home/share/bierone/genome_resnet101_faster_rcnn_genome.tsv --net data/faster_rcnn_models/resnet101_faster_rcnn_final.caffemodel --split genome
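Step 1 asks for both image parts in one directory, and the modified path in step 2 expects flat `<image_id>.jpg` files. A minimal stdlib sketch of that merge, assuming the two archives were unpacked into folders named `VG_100K` and `VG_100K_2`; the helper name `merge_vg_dirs` and the choice to copy rather than move are assumptions. It refuses to overwrite, which also double-checks that the two parts really share no image_id.

```python
import os
import shutil


def merge_vg_dirs(src_dirs, dst_dir):
    """Copy every .jpg from src_dirs into one flat dst_dir; return the number copied."""
    if not os.path.isdir(dst_dir):
        os.makedirs(dst_dir)
    copied = 0
    for src in src_dirs:
        for name in sorted(os.listdir(src)):
            if not name.lower().endswith(".jpg"):
                continue  # skip stray non-image files
            target = os.path.join(dst_dir, name)
            if os.path.exists(target):
                # The same file name in both parts would mean a duplicated image_id.
                raise ValueError("duplicate image file: %s" % name)
            shutil.copy2(os.path.join(src, name), target)
            copied += 1
    return copied
```

Usage, for example: `merge_vg_dirs(['VG_100K', 'VG_100K_2'], 'bottom-up-attention/data/VGdata')`. Copying doubles disk usage for the whole dataset; replacing `shutil.copy2` with `shutil.move` avoids that at the cost of emptying the original folders.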

     
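Before starting a long extraction run it is worth confirming that every path built by the step 2 change exists; the commented-out print in that snippet hints at the same check. A small sketch, assuming `image_data.json` is a list of objects with an `image_id` field and that images are stored flat as `<image_id>.jpg`; `find_missing_images` is a hypothetical helper.

```python
import json
import os


def find_missing_images(json_path, image_dir):
    """Return (total, missing_ids) for the images listed in image_data.json."""
    with open(json_path) as f:
        items = json.load(f)
    missing = []
    for item in items:
        image_id = int(item["image_id"])
        # Same path rule as the modified generate_tsv.py snippet.
        path = os.path.join(image_dir, str(image_id) + ".jpg")
        if not os.path.exists(path):
            missing.append(image_id)
    return len(items), missing
```

Usage, for example: `total, missing = find_missing_images('./data/visualgenome/image_data.json', './data/VGdata/')`; an empty `missing` list means every listed image can be found.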

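Once step 3 finishes (or if it is interrupted), the image_ids in the output tsv can be compared with the ids you expected to process. A stdlib sketch, assuming the column order of FIELDNAMES in the verification script (image_id first, tab-separated, no header row); `extracted_ids` and `missing_from_tsv` are hypothetical names.

```python
import csv

# Rows carry base64 blobs far longer than csv's default field size limit.
csv.field_size_limit(2**31 - 1)


def extracted_ids(tsv_path):
    """Collect the image_id column (first field) of every row in the tsv."""
    ids = set()
    with open(tsv_path) as f:
        for row in csv.reader(f, delimiter="\t"):
            if row:
                ids.add(int(row[0]))
    return ids


def missing_from_tsv(tsv_path, expected_ids):
    """Sorted image ids that were expected but have no row in the tsv."""
    return sorted(set(expected_ids) - extracted_ids(tsv_path))
```

The expected ids can come from the same `image_data.json` used in step 2.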
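Verification: after extracting the tsv file you will want to check it and make sure the box positions are reasonable. Here is my script show.py (it may well contain mistakes; corrections are welcome):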

import os
import sys
import csv
import base64

import matplotlib
matplotlib.use('Agg')  # render straight to files; no display is needed on a server
import matplotlib.pyplot as plt
# set display defaults
# plt.rcParams['figure.figsize'] = (10, 10)        # large images
# plt.rcParams['image.interpolation'] = 'nearest'  # don't interpolate: show square pixels
# plt.rcParams['image.cmap'] = 'gray'  # use grayscale output rather than a (potentially misleading) color heatmap
import numpy as np
import cv2

csv.field_size_limit(sys.maxsize)  # the base64 'features' field exceeds csv's default field size limit

FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features']
infile = '/home/share/lyb/genome_resnet101_faster_rcnn_genome.tsv'
data_root = '/home/lyb/bottom-up-attention/data/VGdata/'
out_dir = 'demo/'


def get_detections_from_tsv(nums=5):
    """Read the first `nums` rows of the tsv and decode boxes/features into numpy arrays."""
    in_data = {}
    with open(infile, 'r') as tsv_in_file:
        reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES)
        for i, item in enumerate(reader):
            if i >= nums:  # stop after exactly `nums` rows
                break
            # image_id stays a string: it is only used to build file names
            item['image_h'] = int(item['image_h'])
            item['image_w'] = int(item['image_w'])
            item['num_boxes'] = int(item['num_boxes'])
            for field in ['boxes', 'features']:
                # b64decode works on Python 2 and 3; decodestring was removed in Python 3.9
                item[field] = np.frombuffer(base64.b64decode(item[field].encode('utf8')),
                                            dtype=np.float32).reshape((item['num_boxes'], -1))
            in_data[i] = item
    return in_data


def show_features(ax, boxes):
    """Draw every (x1, y1, x2, y2) box as a red rectangle on `ax`."""
    for bbox in boxes:
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=2, alpha=0.8)
        )
    ax.axis('off')


if __name__ == '__main__':
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    in_data = get_detections_from_tsv()
    for key, item in in_data.items():
        im_file = data_root + item['image_id'] + '.jpg'
        im = cv2.imread(im_file)
        if im is None:  # cv2.imread returns None instead of raising on a missing file
            print('cannot read ' + im_file)
            continue
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
        fig, ax = plt.subplots(figsize=(20, 20))
        ax.imshow(im)
        show_features(ax, item['boxes'])
        fig.savefig(os.path.join(out_dir, item['image_id'] + '.jpg'))
        plt.close(fig)  # release memory when looping over many images
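Drawing boxes is a visual spot check; a numeric pass can flag suspicious rows across the whole file. A stdlib-only sketch for Python 3 (no numpy or OpenCV), assuming boxes are stored as native-endian float32 (x1, y1, x2, y2), as the `np.frombuffer(..., dtype=np.float32)` decoding above implies; `decode_float32`, `check_row` and the 1-pixel tolerance `eps` are hypothetical.

```python
import base64
from array import array


def decode_float32(b64_text, rows):
    """Decode a base64 float32 blob into `rows` equal-length lists of floats."""
    values = array("f")
    values.frombytes(base64.b64decode(b64_text))
    if rows == 0 or len(values) % rows != 0:
        raise ValueError("blob length does not divide into %d rows" % rows)
    width = len(values) // rows
    return [values[i * width:(i + 1) * width].tolist() for i in range(rows)]


def check_row(row, eps=1.0):
    """Return (feature_dim, problems) for one tsv row read with csv.DictReader."""
    w, h, n = int(row["image_w"]), int(row["image_h"]), int(row["num_boxes"])
    boxes = decode_float32(row["boxes"], n)
    if len(boxes[0]) != 4:
        return 0, ["boxes have %d values each, expected 4" % len(boxes[0])]
    problems = []
    for i, (x1, y1, x2, y2) in enumerate(boxes):
        if x2 < x1 or y2 < y1:
            problems.append("box %d has negative size" % i)
        if x1 < -eps or y1 < -eps or x2 > w + eps or y2 > h + eps:
            problems.append("box %d lies outside the %dx%d image" % (i, w, h))
    # Decoding the features also verifies that the blob splits into num_boxes rows.
    feature_dim = len(decode_float32(row["features"], n)[0])
    return feature_dim, problems
```

Usage: iterate `csv.DictReader(f, delimiter='\t', fieldnames=FIELDNAMES)` as in show.py and print any row whose `problems` list is non-empty; the returned `feature_dim` should be the same for every row.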

 

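Summary: the author's code is fairly complex and I only examined parts of it closely, so I won't include an analysis here. The extraction itself is not complicated; most of the time goes into setting up the environment.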

 
