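Study Notes | PyTorch Tutorial 33
These notes are drawn mainly from the "深度之眼" course and are summarized here for easy reference.
PyTorch version used: 1.2
- What is image object detection?
- How does a model perform object detection?
- An overview of deep-learning object detection models
- Faster RCNN training in PyTorch
I. What is image object detection?
Object detection: locate the objects in an image and identify their classes.
The two elements of object detection:
- 1. Classification: a classification vector [p0, …, pn]
- 2. Regression: a regressed bounding box [x1, y1, x2, y2]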
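To make the two elements concrete, here is a minimal pure-Python sketch of a single detection: a class probability vector obtained with a softmax, plus one box in [x1, y1, x2, y2] form. The three class names, the raw scores, and the box coordinates are made-up values for illustration only.

```python
import math

def softmax(logits):
    """Turn raw class scores into a probability vector [p0, ..., pn]."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw scores for three classes (index 0 is the background).
class_names = ["__background__", "person", "car"]
probs = softmax([0.5, 3.0, 1.0])

# Hypothetical box in [x1, y1, x2, y2] pixel coordinates.
box = [21.0, 408.0, 56.0, 540.0]
width, height = box[2] - box[0], box[3] - box[1]

# The predicted class is the index with the highest probability.
best = max(range(len(probs)), key=probs.__getitem__)
print(class_names[best], round(probs[best], 3), width, height)
```

The width and height are derived from the two corner points, which is the same arithmetic the visualization code uses when it draws a rectangle.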
Test code:
import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# classes_coco
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
if __name__ == "__main__":
    # path_img = os.path.join(BASE_DIR, "demo_img1.png")
    path_img = os.path.join(BASE_DIR, "demo_img2.png")

    # config
    preprocess = transforms.Compose([
        transforms.ToTensor(),
    ])

    # 1. load data & model
    input_image = Image.open(path_img).convert("RGB")
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    model.eval()

    # 2. preprocess
    img_chw = preprocess(input_image)

    # 3. to device
    if torch.cuda.is_available():
        img_chw = img_chw.to('cuda')
        model.to('cuda')

    # 4. forward
    input_list = [img_chw]
    with torch.no_grad():
        tic = time.time()
        print("input img tensor shape:{}".format(input_list[0].shape))
        output_list = model(input_list)
        output_dict = output_list[0]
        print("pass: {:.3f}s".format(time.time() - tic))
        for k, v in output_dict.items():
            print("key:{}, value:{}".format(k, v))

    # 5. visualization
    out_boxes = output_dict["boxes"].cpu()
    out_scores = output_dict["scores"].cpu()
    out_labels = output_dict["labels"].cpu()

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(input_image, aspect='equal')

    num_boxes = out_boxes.shape[0]
    max_vis = 40
    thres = 0.5

    for idx in range(0, min(num_boxes, max_vis)):
        score = out_scores[idx].numpy()
        bbox = out_boxes[idx].numpy()
        class_name = COCO_INSTANCE_CATEGORY_NAMES[out_labels[idx]]

        if score < thres:
            continue

        ax.add_patch(plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1], fill=False,
                                   edgecolor='red', linewidth=3.5))
        ax.text(bbox[0], bbox[1] - 2, '{:s} {:.3f}'.format(class_name, score), bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')
    plt.show()
    plt.close()
# appendix
classes_pascal_voc = ['__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor']
Output:
input img tensor shape:torch.Size([3, 624, 1270])
pass: 13.661s
key:boxes, value:tensor([[2.1437e+01, 4.0840e+02, 5.6342e+01, 5.3993e+02],
[2.7507e+02, 4.1659e+02, 3.1846e+02, 5.2799e+02],
[3.3170e+02, 5.0658e+02, 3.8219e+02, 6.2113e+02],
[1.0627e+03, 5.6276e+02, 1.1684e+03, 6.2371e+02],
[8.8013e+02, 5.0102e+02, 9.3208e+02, 6.2317e+02],
[2.9642e+02, 5.2642e+02, 3.4381e+02, 6.2200e+02],
[1.5379e+02, 3.9273e+02, 1.9051e+02, 4.7901e+02],
[5.2459e+02, 5.5500e+02, 5.9428e+02, 6.2307e+02],
[4.3968e+02, 4.7425e+02, 4.9720e+02, 6.1554e+02],
[9.6592e+02, 4.4677e+02, 1.0049e+03, 5.7215e+02],
[1.0311e+03, 4.7703e+02, 1.0741e+03, 6.1917e+02],
[7.1520e+02, 5.5527e+02, 7.6435e+02, 6.2250e+02],
[1.9180e+02, 3.9129e+02, 2.1840e+02, 4.5502e+02],
[5.9519e+02, 5.6863e+02, 6.5838e+02, 6.2400e+02],
[9.2346e+02, 4.2539e+02, 9.6890e+02, 5.4164e+02],
[8.4545e+02, 4.2685e+02, 8.8473e+02, 5.3350e+02],
[5.5792e-01, 3.6247e+02, 1.9292e+01, 4.2037e+02],
[7.8786e+02, 4.5473e+02, 8.3009e+02, 5.5746e+02],
[5.9756e+02, 4.3980e+02, 6.4331e+02, 5.7260e+02],
[7.5372e+02, 5.4176e+02, 8.4086e+02, 6.2388e+02],
[1.0174e+03, 5.0093e+02, 1.0505e+03, 5.4634e+02],
[6.8192e+02, 5.3671e+02, 7.2875e+02, 6.2382e+02],
[8.1197e+02, 4.2305e+02, 8.4461e+02, 5.3224e+02],
[7.5444e+02, 3.9091e+02, 7.9017e+02, 4.9864e+02],
[5.3107e+02, 3.9075e+02, 5.6285e+02, 4.8725e+02],
[1.1842e+03, 5.6935e+02, 1.2687e+03, 6.2372e+02],
[9.0154e+02, 4.5109e+02, 9.1972e+02, 4.7041e+02],
[8.9092e+02, 4.1312e+02, 9.2181e+02, 5.0889e+02],
[4.9160e+02, 4.8394e+02, 5.1212e+02, 5.2896e+02],
[7.1178e+02, 4.7320e+02, 7.4839e+02, 5.6364e+02],
[1.1422e+03, 4.1846e+02, 1.1851e+03, 5.2725e+02],
[1.1044e+03, 4.1391e+02, 1.1432e+03, 5.1564e+02],
[4.8151e+02, 5.2476e+02, 4.9801e+02, 5.6309e+02],
[9.6673e+02, 4.7047e+02, 9.9382e+02, 5.1017e+02],
[1.5301e+02, 4.0614e+02, 1.7877e+02, 4.3976e+02],
[3.3971e+02, 3.4014e+02, 3.6640e+02, 4.1000e+02],
[1.1215e+01, 3.0503e+02, 2.5390e+01, 3.4648e+02],
[5.6783e+02, 4.4656e+02, 6.0336e+02, 5.6767e+02],
[1.0671e+03, 4.0842e+02, 1.1084e+03, 5.1374e+02],
[7.0506e+02, 4.0975e+02, 7.3976e+02, 4.9427e+02],
[1.1736e+03, 4.1151e+02, 1.2080e+03, 5.2507e+02],
[2.5137e+02, 3.2949e+02, 2.7344e+02, 3.9106e+02],
[1.6691e+02, 2.8140e+02, 1.8110e+02, 3.1285e+02],
[3.0369e+02, 4.6951e+02, 3.1904e+02, 5.0108e+02],
[1.3483e+02, 3.2507e+02, 1.5080e+02, 3.6978e+02],
[1.0107e+03, 4.4341e+02, 1.0458e+03, 5.4823e+02],
[9.8960e+02, 3.7219e+02, 1.0161e+03, 4.4772e+02],
[8.3098e+02, 3.9532e+02, 8.5813e+02, 4.6472e+02],
[6.6482e+02, 4.5117e+02, 6.8733e+02, 4.8071e+02],
[3.6000e+02, 3.9332e+02, 3.8890e+02, 4.8298e+02],
[1.0539e+03, 5.0596e+02, 1.0702e+03, 5.3395e+02],
[4.6973e+02, 4.5248e+02, 5.0838e+02, 5.7020e+02],
[1.5856e+02, 3.3735e+02, 1.7752e+02, 3.7869e+02],
[1.5349e+02, 4.0826e+02, 1.7287e+02, 4.4103e+02],
[3.8870e+02, 3.7187e+02, 4.2914e+02, 5.0503e+02],
[9.6698e+02, 4.7291e+02, 9.9026e+02, 5.0854e+02],
[5.5847e+02, 3.8495e+02, 5.8719e+02, 4.8738e+02],
[4.9743e+02, 3.8822e+02, 5.2440e+02, 4.8499e+02],
[6.0820e+01, 2.8248e+02, 7.6256e+01, 3.1529e+02],
[6.8791e+02, 4.9479e+02, 7.2739e+02, 5.5288e+02],
[6.5066e+02, 4.9294e+02, 7.0139e+02, 6.1894e+02],
[2.0727e+02, 3.9253e+02, 2.2674e+02, 4.5656e+02],
[3.3184e+02, 3.0833e+02, 3.4600e+02, 3.4902e+02],
[1.0159e+03, 4.9606e+02, 1.0568e+03, 5.4730e+02],
[6.0135e+01, 3.0983e+02, 7.7837e+01, 3.4432e+02],
[6.3866e+02, 4.2136e+02, 6.7227e+02, 5.2106e+02],
[4.6559e+02, 3.9241e+02, 4.8686e+02, 4.2766e+02],
[5.6188e+01, 3.1174e+02, 7.0932e+01, 3.4426e+02],
[4.3119e+02, 3.2678e+02, 4.6984e+02, 3.9337e+02],
[5.9947e+02, 3.9388e+02, 6.3010e+02, 4.5644e+02],
[1.1757e+03, 5.3985e+02, 1.2376e+03, 6.1111e+02],
[6.6622e+02, 4.1731e+02, 7.0462e+02, 4.9162e+02],
[1.7327e+02, 3.9034e+02, 1.9244e+02, 4.6006e+02],
[4.7853e+02, 4.7409e+02, 5.1007e+02, 5.2950e+02],
[2.8340e+02, 3.0262e+02, 2.9943e+02, 3.3402e+02],
[7.4611e+02, 3.5429e+02, 7.7780e+02, 4.1244e+02],
[7.4060e+02, 4.8190e+02, 7.6980e+02, 5.5989e+02],
[9.6401e+02, 3.4509e+02, 9.8135e+02, 3.9366e+02],
[4.1680e+02, 3.6824e+02, 4.4263e+02, 4.5963e+02],
[8.7578e+02, 3.5269e+02, 8.9966e+02, 4.2111e+02],
[1.0104e+03, 4.4886e+02, 1.0562e+03, 6.0378e+02],
[3.0327e+02, 4.4166e+02, 3.1852e+02, 5.0216e+02],
[4.4137e+02, 5.9189e+02, 4.8637e+02, 6.2351e+02],
[1.9031e+02, 3.3614e+02, 2.1258e+02, 3.8856e+02],
[1.8251e+02, 2.8241e+02, 1.9433e+02, 3.0937e+02],
[1.2041e+03, 4.6998e+02, 1.2527e+03, 5.5469e+02],
[1.0764e+03, 4.9217e+02, 1.1277e+03, 5.8620e+02],
[1.0449e+03, 3.4939e+02, 1.0654e+03, 4.0458e+02],
[1.0922e+03, 3.8381e+02, 1.1198e+03, 4.3049e+02],
[5.1150e+02, 3.8435e+02, 5.3621e+02, 4.8211e+02],
[3.1652e+02, 3.1374e+02, 3.3143e+02, 3.5101e+02],
[9.4753e+02, 3.4059e+02, 9.6551e+02, 3.9753e+02],
[5.1159e+02, 3.4381e+02, 5.3470e+02, 3.9003e+02],
[5.8443e+02, 3.9833e+02, 6.1497e+02, 4.8086e+02],
[7.4492e+02, 3.8018e+02, 7.6006e+02, 4.0452e+02],
[1.1097e+03, 3.0709e+02, 1.1257e+03, 3.4568e+02],
[6.6792e+02, 3.3659e+02, 6.8839e+02, 3.8192e+02],
[3.0073e+02, 3.0162e+02, 3.1883e+02, 3.4768e+02],
[1.0730e+03, 4.9895e+02, 1.1603e+03, 6.1907e+02],
[3.9530e+02, 4.2951e+02, 4.2193e+02, 4.6254e+02]])
key:labels, value:tensor([ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 31, 1, 1, 1, 1, 1, 31, 1, 27, 1, 1, 1, 31, 27, 27, 1,
1, 1, 1, 1, 1, 1, 1, 31, 1, 1, 1, 1, 31, 1, 31, 1, 1, 31,
1, 31, 1, 1, 1, 1, 1, 1, 1, 27, 1, 1, 31, 1, 1, 1, 1, 1,
1, 27, 1, 1, 1, 1, 1, 1, 1, 31, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 31, 1, 1, 1, 1, 31])
key:scores, value:tensor([0.9860, 0.9852, 0.9780, 0.9779, 0.9774, 0.9739, 0.9500, 0.9464, 0.9456,
0.9088, 0.8751, 0.8721, 0.8549, 0.8533, 0.8455, 0.8064, 0.8006, 0.7799,
0.7588, 0.7488, 0.7124, 0.7113, 0.6831, 0.6695, 0.6562, 0.6551, 0.6532,
0.6498, 0.6471, 0.6365, 0.6178, 0.5983, 0.5870, 0.5829, 0.5744, 0.5698,
0.5638, 0.5590, 0.5522, 0.5413, 0.5313, 0.5283, 0.5203, 0.4811, 0.4558,
0.4536, 0.4442, 0.4402, 0.4374, 0.4368, 0.4313, 0.4210, 0.4119, 0.4099,
0.3986, 0.3920, 0.3912, 0.3827, 0.3754, 0.3654, 0.3584, 0.3502, 0.3496,
0.3414, 0.3399, 0.3283, 0.3225, 0.3126, 0.3124, 0.3101, 0.3049, 0.3025,
0.3005, 0.2963, 0.2946, 0.2830, 0.2799, 0.2790, 0.2783, 0.2782, 0.2772,
0.2759, 0.2711, 0.2684, 0.2643, 0.2574, 0.2509, 0.2462, 0.2401, 0.2385,
0.2353, 0.2311, 0.2245, 0.2233, 0.2224, 0.2205, 0.2173, 0.2157, 0.2141,
0.2109])
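Stepping through the whole flow in a debugger:
- 1. Get the image path: path_img = os.path.join(BASE_DIR, "demo_img2.png")
- 2. Load the model with torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) and switch it to evaluation mode with model.eval().
- 3. Convert the data (the image) into the model's input format (a tensor): img_chw = preprocess(input_image). Its shape is torch.Size([3, 624, 1270]).
- 4. Forward pass: output_list = model(input_list). Only one image is passed in, so len(output_list) = 1. Each element is a dictionary with three entries: boxes, labels, and scores.
- 5. Take the output for the first image, output_dict. Inspecting its entries shows that the model returned 100 boxes, each with a corresponding score and label.
- 6. Store the results in out_boxes, out_scores, and out_labels. Printing out_scores shows that the scores are sorted in descending order, so to keep the visualization readable only the higher-scoring boxes are drawn. That is why the loop is written as for idx in range(0, min(num_boxes, max_vis)): it looks at no more than the first max_vis boxes, and any box with score < thres is skipped. With the boxes and the class names from class_name = COCO_INSTANCE_CATEGORY_NAMES[out_labels[idx]], the result can be visualized.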
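As an alternative to skipping low-scoring boxes inside the loop, the same filtering can be sketched with a boolean mask. The dictionary below is a synthetic stand-in for output_dict with made-up values, so the snippet runs without the model or the image; only the key names and the 0.5 threshold come from the code above.

```python
import torch

# Synthetic stand-in for output_dict; the numbers are made up for illustration.
output_dict = {
    "boxes": torch.tensor([[10., 20., 50., 80.],
                           [30., 40., 90., 120.],
                           [5., 5., 15., 25.]]),
    "labels": torch.tensor([1, 31, 27]),
    "scores": torch.tensor([0.98, 0.71, 0.21]),
}

thres = 0.5
keep = output_dict["scores"] >= thres  # boolean mask with one entry per box

# Index all three tensors with the same mask so they stay aligned.
kept_boxes = output_dict["boxes"][keep]
kept_labels = output_dict["labels"][keep]
kept_scores = output_dict["scores"][keep]
print(kept_boxes.shape, kept_labels.tolist())
```

Because the scores are already sorted, the mask keeps a prefix of the boxes; the max_vis cap from the original loop would still need to be applied separately if desired.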
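II. How does a model perform object detection?
The model maps a 3D tensor (the image) to two tensors:
- 1. Classification tensor: shape [N, C + 1]
- 2. Bounding-box tensor: shape [N, 4]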
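The two output shapes can be illustrated with a pair of linear layers applied to per-box features. This is only a shape sketch under assumed sizes (N = 8 candidate boxes, a 1024-dimensional feature per box, C = 20 classes); it is not the head of any particular detector, and the layer names are hypothetical.

```python
import torch
import torch.nn as nn

N, feat_dim, C = 8, 1024, 20  # assumed sizes, for illustration only

cls_head = nn.Linear(feat_dim, C + 1)  # +1 for the background class
box_head = nn.Linear(feat_dim, 4)      # one [x1, y1, x2, y2] per candidate box

features = torch.randn(N, feat_dim)    # random stand-in for per-box features
with torch.no_grad():
    cls_scores = torch.softmax(cls_head(features), dim=1)  # [N, C + 1]
    boxes = box_head(features)                             # [N, 4]
print(cls_scores.shape, boxes.shape)
```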
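"Recent Advances in Deep Learning for Object Detection" (2019)
How is the number of bounding boxes N determined?
Traditional approach: the sliding-window strategy
Drawbacks:
- 1. A large amount of repeated computation
- 2. The window size is hard to choose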
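A quick count shows why both drawbacks matter. The sketch below counts window positions for the 624 x 1270 image from the output above; the three square window sizes and the stride of 8 are assumptions chosen only for illustration.

```python
def count_windows(img_h, img_w, win_h, win_w, stride):
    """Number of positions when sliding a win_h x win_w window over the image."""
    if win_h > img_h or win_w > img_w:
        return 0
    rows = (img_h - win_h) // stride + 1
    cols = (img_w - win_w) // stride + 1
    return rows * cols

# Each window would need its own classifier pass, and neighbouring windows
# overlap heavily, so most of that work is repeated.
total = 0
for size in (64, 128, 256):  # assumed window sizes
    n = count_windows(624, 1270, size, size, stride=8)
    print(size, n)
    total += n
print("total windows:", total)
```

Even with only three window sizes the count runs into the tens of thousands, and objects with other sizes or aspect ratios would still be missed.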
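Using convolution to reduce repeated computation
Key concept:
- One pixel of the feature map corresponds to a region of the original image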
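The correspondence can be sketched as a simple grid mapping. The helper below, cell_to_region, is a hypothetical name, and the stride of 16 is an assumption (for example, four 2x downsampling stages). It returns the grid cell that a feature-map pixel is aligned with; the true receptive field of that pixel is usually larger than this cell.

```python
def cell_to_region(row, col, stride):
    """Image-space box [x1, y1, x2, y2] of the grid cell aligned with
    feature-map pixel (row, col), assuming a total downsampling of `stride`."""
    x1, y1 = col * stride, row * stride
    return [x1, y1, x1 + stride, y1 + stride]

# stride=16 is an assumed value, not taken from a specific backbone.
print(cell_to_region(0, 0, 16))
print(cell_to_region(2, 3, 16))
```

Since one convolutional pass computes every feature-map pixel at once, all of these regions are evaluated together instead of being cropped and processed one window at a time.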
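III. An overview of deep-learning object detection models
"Object Detection in 20 Years: A Survey" (2019)
By pipeline, detectors are divided into one-stage and two-stage.
Faster RCNN: a classic two-stage detection network
A. The Faster RCNN backbone extracts features from the image and produces a feature map.
B. The feature map is fed into the RPN, which generates hundreds of thousands of candidate boxes. Non-maximum suppression (NMS) then selects 2000 proposals. These 2000 proposals are overlaid on the feature map from the previous step to crop out the feature map of each sub-region.
C. The sub-region feature maps pass through the ROI layer, which pools them into feature maps of one fixed size.
D. The feature maps from the previous step pass through a series of fully connected layers, which perform bounding-box regression and class classification (C+1 Softmax).
E. Training detail: the 2000 proposals are further sampled down to 512, so 512 sub-region feature maps are fed into stage 2.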
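The NMS step in B can be sketched in pure Python as a greedy loop over boxes sorted by score. The boxes, scores, and the 0.5 IoU threshold are made-up values, and this is a minimal illustration rather than the implementation used inside Faster RCNN; library versions such as torchvision.ops.nms are vectorized.

```python
def iou(a, b):
    """Intersection over union of two boxes given as [x1, y1, x2, y2]."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def nms(boxes, scores, iou_thres):
    """Greedy NMS: return indices of the kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # Keep a box only if it does not overlap too much with any kept box.
        if all(iou(boxes[i], boxes[j]) <= iou_thres for j in keep):
            keep.append(i)
    return keep

# Boxes 0 and 1 overlap heavily; box 2 is far away from both.
boxes = [[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]]
scores = [0.9, 0.8, 0.7]
kept = nms(boxes, scores, iou_thres=0.5)
print(kept)
```

In the RPN setting, the kept list would additionally be truncated to the top 2000 entries to obtain the proposals.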
Faster RCNN data flow
- Feature map: [256, h_f, w_f]
- 2 Softmax: [num_anchors, h_f, w_f]
- Regressors: [num_anchors*4, h_f, w_f]
- NMS OUT: [n_proposals = 2000, 4]
- ROI Layer: [512, 256, 7, 7]
- FC1 FC2: [512, 1024]
- C+1 Softmax: [512, C + 1]
- Regressors: [512, (C+1)*4]
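The shapes in this list can be checked with a few stand-in layers. This is a shape sketch only: the feature-map size, num_anchors = 3, and C = 90 are assumed values, the inputs are random tensors, and the layers are simplified placeholders rather than the actual torchvision modules. The tensors carry an extra leading batch dimension of 1 where the list omits it.

```python
import torch
import torch.nn as nn

num_anchors, C = 3, 90  # assumed values, for illustration only
h_f, w_f = 39, 80       # hypothetical feature-map size

feature_map = torch.randn(1, 256, h_f, w_f)  # [256, h_f, w_f] plus a batch dim

# RPN heads: one objectness score and 4 box offsets per anchor per location.
rpn_cls = nn.Conv2d(256, num_anchors, kernel_size=1)
rpn_reg = nn.Conv2d(256, num_anchors * 4, kernel_size=1)

# Stage 2: pooled ROI features -> FC1, FC2 -> class scores and box offsets.
roi_feats = torch.randn(512, 256, 7, 7)  # stand-in for the ROI layer output
fc = nn.Sequential(nn.Flatten(),
                   nn.Linear(256 * 7 * 7, 1024), nn.ReLU(),
                   nn.Linear(1024, 1024), nn.ReLU())
cls_score = nn.Linear(1024, C + 1)
bbox_pred = nn.Linear(1024, (C + 1) * 4)

with torch.no_grad():
    objectness = rpn_cls(feature_map)     # [1, num_anchors, h_f, w_f]
    rpn_offsets = rpn_reg(feature_map)    # [1, num_anchors*4, h_f, w_f]
    x = fc(roi_feats)                     # [512, 1024]
    class_logits = cls_score(x)           # [512, C + 1]
    box_offsets = bbox_pred(x)            # [512, (C+1)*4]

for t in (objectness, rpn_offsets, x, class_logits, box_offsets):
    print(tuple(t.shape))
```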