With Caffe configured and the dependency libraries cython, opencv, pyyaml, and easydict installed, let me first record an easydict error (to stress it again: this library must be an older version!). It was resolved after reinstalling via `pip install easydict`:
```
File "./tools/generate_tsv.py", line 221, in <module>
    assert cfg.TEST.HAS_RPN
AssertionError: cfg.TEST.HAS_RPN == False
```
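For intuition on why this assertion fires: the YAML override for `TEST.HAS_RPN` never lands in the global `cfg` when the easydict version misbehaves. The sketch below uses a minimal stand-in for easydict's attribute-style dict (not the real library, and not py-faster-rcnn's actual merge code) to show what a successful config merge should leave behind:

```python
# Minimal stand-in for easydict.EasyDict: a dict with attribute access,
# applied recursively so nested config sections also support dot lookup.
class EDict(dict):
    def __init__(self, d=None):
        super().__init__()
        for k, v in (d or {}).items():
            self[k] = EDict(v) if isinstance(v, dict) else v
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

# Defaults as in the repo's config, then a YAML-style override merged in,
# mirroring what cfg_from_file() is supposed to accomplish.
cfg = EDict({'TEST': {'HAS_RPN': False}})
yaml_cfg = {'TEST': {'HAS_RPN': True}}

def merge_into(src, dst):
    # Recursively copy scalar values from src into the matching dst sections.
    for k, v in src.items():
        if isinstance(v, dict):
            merge_into(v, dst[k])
        else:
            dst[k] = v

merge_into(yaml_cfg, cfg)
assert cfg.TEST.HAS_RPN  # the check at generate_tsv.py line 221 now passes
print(cfg.TEST.HAS_RPN)  # → True
```

If the merge silently fails (as with a bad easydict version), `cfg.TEST.HAS_RPN` stays `False` and the assertion above is exactly the error you see.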
Now to the main topic: extracting the genome features.
- Download the visual-genome images. They come in two sets, VG_100K and VG_100K_2, with no overlap in image_id between them; put them all under the bottom-up-attention/data/VG_data directory (data link: http://visualgenome.org/).
- Make a small change to tools/generate_tsv.py, mainly updating the directories around lines 58-65. My change is as follows:
```python
with open('./data/visualgenome/image_data.json') as f:
    for item in json.load(f):
        image_id = int(item['image_id'])
        # filepath = os.path.join('./data/VGdata/', item['url'].split('rak248/')[-1])
        filepath = os.path.join('./data/VGdata/', str(image_id) + '.jpg')  # the author's original line above also works here
        # print(filepath, os.path.exists(filepath))
        split.append((filepath, image_id))
```
- Run the code following the author's example. My hyperparameters are given below, using the pretrained model the author provides:
```shell
python ./tools/generate_tsv.py --gpu 0 --cfg experiments/cfgs/faster_rcnn_end2end_resnet.yml --def ./models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt --out /home/share/bierone/genome_resnet101_faster_rcnn_genome.tsv --net data/faster_rcnn_models/resnet101_faster_rcnn_final.caffemodel --split genome
```
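The path-mapping change above can be exercised on its own. Here is a sketch with a stubbed-out image_data.json record list (the records and `data_root` are placeholders for illustration, not the real file):

```python
import json
import os

# Stub of Visual Genome's image_data.json: each record carries an
# image_id and the original download URL.
records = [
    {"image_id": 1, "url": "https://cs.stanford.edu/people/rak248/VG_100K/1.jpg"},
    {"image_id": 2, "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/2.jpg"},
]

data_root = './data/VGdata/'  # assumed flat directory holding all VG images
split = []
for item in records:
    image_id = int(item['image_id'])
    # The original code keyed off the URL suffix (VG_100K/... vs VG_100K_2/...);
    # with both sets merged into one directory, the id alone names the file.
    filepath = os.path.join(data_root, str(image_id) + '.jpg')
    split.append((filepath, image_id))

print(split[0])  # → ('./data/VGdata/1.jpg', 1)
```

Because the two VG_100K sets have disjoint image_ids, flattening them into one directory loses nothing.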
Verification: after extracting the tsv file, it is worth checking that the box positions are reasonable. Here is my show.py (there may well be oversights; corrections are welcome):
```python
import matplotlib
matplotlib.use('Agg')  # headless backend: render straight to files, no display needed
import matplotlib.pyplot as plt
import numpy as np
import cv2, base64
import csv, sys

csv.field_size_limit(sys.maxsize)  # feature cells are far larger than the default limit

FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features']
infile = '/home/share/lyb/genome_resnet101_faster_rcnn_genome.tsv'
data_root = '/home/lyb/bottom-up-attention/data/VGdata/'

def get_detections_from_tsv(nums=5):
    in_data = {}
    with open(infile, "r") as tsv_in_file:
        reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES)
        for i, item in enumerate(reader):
            # image_id stays a string: it is reused below to build the output filename
            item['image_h'] = int(item['image_h'])
            item['image_w'] = int(item['image_w'])
            item['num_boxes'] = int(item['num_boxes'])
            for field in ['boxes', 'features']:
                # each cell is a base64-encoded raw float32 buffer
                item[field] = np.frombuffer(base64.b64decode(item[field]),
                                            dtype=np.float32).reshape((item['num_boxes'], -1))
            in_data[i] = item
            if i + 1 >= nums:
                break
    return in_data

def show_features(ax, boxes):
    # draw every box as a red rectangle; boxes are (x1, y1, x2, y2)
    for i in range(boxes.shape[0]):
        bbox = boxes[i]
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=2, alpha=0.8))
    plt.axis('off')
    plt.draw()

if __name__ == '__main__':
    in_data = get_detections_from_tsv()
    for key, item in in_data.items():
        im_file = data_root + item['image_id'] + '.jpg'
        im = cv2.imread(im_file)
        im = im[:, :, (2, 1, 0)]  # cv2 loads BGR; flip to RGB for matplotlib
        fig, ax = plt.subplots(figsize=(20, 20))
        ax.imshow(im)
        show_features(ax, item['boxes'])
        plt.savefig('demo/' + item['image_id'] + '.jpg')  # demo/ must already exist
        plt.close(fig)
```
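The `boxes`/`features` columns of the tsv are base64-encoded raw float32 buffers. A minimal stdlib-only round trip (no Caffe or numpy needed, values are made-up sample boxes) shows the encode/decode contract that show.py relies on:

```python
import base64
import struct

# Pack two boxes (x1, y1, x2, y2) as raw little-endian float32,
# the same bytes generate_tsv.py produces via ndarray.tostring().
boxes = [(10.0, 20.0, 110.0, 220.0), (5.0, 5.0, 50.0, 60.0)]
raw = b''.join(struct.pack('<4f', *b) for b in boxes)
encoded = base64.b64encode(raw).decode('ascii')   # this string goes into the tsv cell

# Decoding side: base64 -> bytes -> float32 tuples, mirroring
# np.frombuffer(...).reshape((num_boxes, -1)) in show.py.
decoded = base64.b64decode(encoded)
n = len(decoded) // (4 * 4)                        # 4 floats * 4 bytes each per box
restored = [struct.unpack('<4f', decoded[i * 16:(i + 1) * 16]) for i in range(n)]
print(restored[0])  # → (10.0, 20.0, 110.0, 220.0)
```

If the decoded array's size is not a multiple of `num_boxes`, the reshape in show.py will raise, which is itself a useful sanity check on a freshly generated tsv.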
To wrap up: the author's code is fairly involved, and I only examined parts of it closely, so I won't attach a full analysis. The extraction process itself is not complicated; most of the time goes into setting up the environment.