Code: https://github.com/xiguanlezz/Faster-RCNN
I. Backpropagation
Faster-RCNN's loss consists of two parts: the first is the loss between the anchors (the prior boxes) and their anchor targets; the second is the loss between the proposals and their proposal targets.
1. The anchor loss
The idea is to first discard the anchors that fall outside the image and then assign labels based on IoU; the __call__ function computes the regression targets together with the label information.
import numpy as np
from utils.util import calculate_iou, get_inside_index, box2loc


class AnchorTargetCreator:
    def __init__(self,
                 n_sample=256,
                 pos_iou_thresh=0.7,
                 neg_iou_thresh=0.3,
                 pos_ratio=0.5):
        """
        function description: AnchorTargetCreator constructor
        :param n_sample: 256, total number of targets
        :param pos_iou_thresh: IoU threshold against the ground-truth boxes; above it the anchor is a "positive" sample and its label is set to 1
        :param neg_iou_thresh: IoU threshold against the ground-truth boxes; below it the anchor is a "negative" sample and its label is set to 0
        :param pos_ratio: ratio of "positive" samples among all targets
        """
        self.n_sample = n_sample
        self.pos_iou_thresh = pos_iou_thresh
        self.neg_iou_thresh = neg_iou_thresh
        self.pos_ratio = pos_ratio  # ratio of positives among all targets; if there are not enough positives, fill up with negatives

    def __call__(self, boxes, anchors, img_size):
        """
        function description: compute the regression targets and labels of the anchors
        :param boxes: top-left and bottom-right coordinates of the ground-truth boxes in the image, shape: [boxes_num, 4]
        :param anchors: coordinates of all anchors generated from the feature map, shape: [anchors_num, 4]
        :param img_size: size of the original image, used to filter out anchors that cross the border
        :return:
            anchor_locs: final regression targets, shape: [anchors_num, 4]
            anchor_labels: final labels, shape: [anchors_num]
        """
        img_width, img_height = img_size
        inside_index = get_inside_index(anchors, img_width, img_height)
        # keep only the anchors that lie inside the image
        inside_anchors = anchors[inside_index]
        # both returned arrays have shape [inside_anchors_num]: for each anchor, the index of the
        # ground-truth box with the highest IoU, and the assigned label
        argmax_ious, labels = self._create_label(inside_anchors, boxes)
        # regression targets between inside_anchors and their highest-IoU ground-truth boxes
        locs = box2loc(inside_anchors, boxes[argmax_ious])
        anchors_num = len(anchors)
        # expand the labels of inside_anchors back onto the full anchor set to ease
        # computing the first (anchor) part of the loss
        anchor_labels = np.empty((anchors_num,), dtype=labels.dtype)
        anchor_labels.fill(-1)
        anchor_labels[inside_index] = labels
        # likewise expand locs back onto the full anchor set
        anchor_locs = np.empty((anchors_num,) + locs.shape[1:], dtype=locs.dtype)
        anchor_locs.fill(0)
        anchor_locs[inside_index, :] = locs
        return anchor_locs, anchor_labels

    def _create_label(self, inside_anchors, boxes):
        """
        function description: create a label for every inside anchor, where 1 means positive, 0 means negative and -1 means ignored
        labelling rules:
            1. the anchor with the highest IoU for each ground-truth box is positive;
            2. anchors whose IoU with some ground-truth box exceeds pos_iou_thresh are positive;
            3. anchors whose IoU with every ground-truth box is below neg_iou_thresh are negative
        :param inside_anchors: anchors inside the image, shape: [inside_anchors_num, 4]
        :param boxes: ground-truth boxes in the image, shape: [boxes_num, 4]
        :return:
            argmax_ious: for each anchor, the index of the ground-truth box with the highest IoU, shape: [inside_anchors_num]
            label: the label created for each inside anchor, shape: [inside_anchors_num]
        """
        # one label per anchor inside the image
        label = np.empty((len(inside_anchors),), dtype=np.int32)
        # initialize all labels to -1, i.e. ignored by default
        label.fill(-1)
        # argmax_ious and max_ious have shape [inside_anchors_num]
        argmax_ious, max_ious, gt_argmax_ious = self._calculate_iou(inside_anchors, boxes)
        # order matters below: negatives are assigned first so that the per-box best
        # anchors of rule 1 are never overwritten back to 0
        # below the negative threshold: label 0; rule 3
        label[max_ious < self.neg_iou_thresh] = 0
        # the anchors with the highest IoU for some ground-truth box become positive
        # (so every ground-truth box gets at least one anchor); rule 1
        label[gt_argmax_ious] = 1
        # above the positive threshold: label 1; rule 2
        label[max_ious >= self.pos_iou_thresh] = 1
        # the rest balances positives and negatives so that the total is n_sample (256); anchors labelled -1 are ignored
        pos_standard = int(self.pos_ratio * self.n_sample)
        pos_index = np.where(label == 1)[0]
        if len(pos_index) > pos_standard:
            # replace=False: indices are sampled without repetition
            disable_index = np.random.choice(pos_index, size=(len(pos_index) - pos_standard), replace=False)
            label[disable_index] = -1
        neg_standard = self.n_sample - np.sum(label == 1)  # number of negatives to keep
        neg_index = np.where(label == 0)[0]
        if len(neg_index) > neg_standard:
            disable_index = np.random.choice(neg_index, size=(len(neg_index) - neg_standard), replace=False)
            label[disable_index] = -1
        return argmax_ious, label

    def _calculate_iou(self, inside_anchors, boxes):
        """
        function description: from the 2-D IoU matrix, extract for each anchor the index of the ground-truth box with the highest IoU and the corresponding IoU values
        :param inside_anchors: anchors inside the image
        :param boxes: ground-truth boxes in the image
        :return:
            argmax_ious: for each inside anchor, the index of the box with the highest IoU, shape: [inside_anchors_num]
            max_ious: for each inside anchor, the highest IoU over all boxes, shape: [inside_anchors_num]
            gt_argmax_ious: indices of the anchors that achieve the per-box maximum IoU (ties included)
        """
        # first dimension is the number of anchors (inside_anchors_num), second is the number of boxes (boxes_num)
        ious = calculate_iou(inside_anchors, boxes)
        argmax_ious = ious.argmax(axis=1)  # shape: [inside_anchors_num]
        # the highest IoU of each anchor over all ground-truth boxes
        max_ious = ious[np.arange(len(ious)), argmax_ious]
        gt_argmax_ious = ious.argmax(axis=0)  # shape: [boxes_num]
        # the highest IoU of each ground-truth box over all anchors
        gt_max_ious = ious[gt_argmax_ious, np.arange(ious.shape[1])]
        # expand to all anchors that tie for the per-box maximum IoU
        gt_argmax_ious = np.where(ious == gt_max_ious)[0]
        return argmax_ious, max_ious, gt_argmax_ious
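To make the inputs and outputs concrete, here is a toy invocation; the box and anchor values below are made up, while in the real pipeline the anchors come from the RPN:

import numpy as np

creator = AnchorTargetCreator()
boxes = np.array([[50., 50., 200., 200.]], dtype=np.float32)        # one ground-truth box
top_left = np.random.uniform(0, 450, size=(1000, 2)).astype(np.float32)
anchors = np.hstack([top_left, top_left + 100.])                    # 1000 toy 100x100 anchors
anchor_locs, anchor_labels = creator(boxes, anchors, img_size=(600, 600))
print(anchor_locs.shape, anchor_labels.shape)                       # (1000, 4) (1000,)
print(((anchor_labels == 0) | (anchor_labels == 1)).sum())          # at most 256 sampled targets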
2. The proposal loss
The main job of this part is to keep positive and negative samples balanced; the __call__ function computes the regression targets and assigns each roi its label.
import numpy as np
from utils.util import calculate_iou, box2loc


class ProposalTargetCreator:
    def __init__(self,
                 n_sample=128,
                 pos_ratio=0.25,
                 pos_iou_thresh=0.5,
                 neg_iou_thresh_hi=0.5,
                 neg_iou_thresh_lo=0.0):
        """
        function description: sample n_sample positive/negative rois to feed into the FastRCNN head
        :param n_sample: number of samples to draw
        :param pos_ratio: ratio of positive samples
        :param pos_iou_thresh: IoU threshold for positive samples
        :param neg_iou_thresh_hi: upper IoU bound for negative samples
        :param neg_iou_thresh_lo: lower IoU bound for negative samples
        """
        self.n_sample = n_sample
        self.pos_ratio = pos_ratio
        self.pos_iou_thresh = pos_iou_thresh
        self.neg_iou_thresh_hi = neg_iou_thresh_hi
        self.neg_iou_thresh_lo = neg_iou_thresh_lo

    def __call__(self,
                 rois,
                 boxes,
                 labels,
                 loc_normalize_mean=(0., 0., 0., 0.),
                 loc_normalize_std=(0.1, 0.1, 0.2, 0.2)):
        """
        function description: get the sampled rois together with their labels and regression targets
        :param rois: rois produced by the rpn
        :param boxes: ground-truth box annotations of one image
        :param labels: class annotations of one image
        :param loc_normalize_mean: mean for loc normalization
        :param loc_normalize_std: standard deviation for loc normalization
        :return:
            sample_rois: the sampled regions of interest
            gt_roi_labels: labels of the sampled rois
            gt_roi_locs: regression targets between sample_rois and boxes
        """
        n_bbox, _ = boxes.shape
        # number of positive samples (rounded)
        pos_num = np.round(self.n_sample * self.pos_ratio)
        ious = calculate_iou(rois, boxes)
        gt_assignment = ious.argmax(axis=1)  # shape: [rois_num]
        max_iou = ious.max(axis=1)
        gt_roi_labels = labels[gt_assignment]  # shape: [rois_num]
        # keep the rois whose IoU reaches the positive threshold
        pos_index = np.where(max_iou >= self.pos_iou_thresh)[0]
        pos_num_for_this_image = int(min(pos_num, pos_index.size))
        if pos_index.size > 0:
            pos_index = np.random.choice(pos_index, size=pos_num_for_this_image, replace=False)
        # keep the rois whose IoU falls into the negative interval
        neg_index = np.where((max_iou < self.neg_iou_thresh_hi) & (max_iou >= self.neg_iou_thresh_lo))[0]
        neg_num = self.n_sample - pos_num_for_this_image
        neg_num_for_this_image = int(min(neg_index.size, neg_num))
        if neg_index.size > 0:
            neg_index = np.random.choice(neg_index, size=neg_num_for_this_image, replace=False)
        keep_index = np.append(pos_index, neg_index)
        gt_roi_labels = gt_roi_labels[keep_index]
        gt_roi_labels[pos_num_for_this_image:] = 0  # background is labelled 0: every index from pos_num_for_this_image onward
        sample_rois = rois[keep_index]
        # note: loc_normalize_mean / loc_normalize_std are accepted but not applied here;
        # if you enable the normalization, predict() must apply the inverse transform
        gt_roi_locs = box2loc(sample_rois, boxes[gt_assignment[keep_index]])
        return sample_rois, gt_roi_labels, gt_roi_locs
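As with the anchors, a toy call makes the shapes concrete; the rois, box and labels below are made up:

import numpy as np

creator = ProposalTargetCreator()
top_left = np.random.uniform(0, 400, size=(2000, 2)).astype(np.float32)
rois = np.hstack([top_left, top_left + 150.])                 # 2000 toy 150x150 rois
boxes = np.array([[100., 100., 260., 260.]], dtype=np.float32)
labels = np.array([3], dtype=np.int32)                        # class index of the single box
sample_rois, gt_roi_labels, gt_roi_locs = creator(rois, boxes, labels)
print(sample_rois.shape)                                      # at most (128, 4)
print((gt_roi_labels > 0).sum())                              # at most 32 positives (n_sample * pos_ratio)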
3. The total loss
Let's first look at how the paper defines the total loss:
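In the paper's notation, the total objective is the multi-task loss

$$L(\{p_i\}, \{t_i\}) = \frac{1}{N_{cls}} \sum_i L_{cls}(p_i, p_i^*) + \lambda \frac{1}{N_{reg}} \sum_i p_i^* L_{reg}(t_i, t_i^*)$$

where $p_i$ is the predicted probability that anchor $i$ contains an object, $p_i^*$ is its ground-truth label, $t_i$ and $t_i^*$ are the predicted and ground-truth box regression values, and $\lambda$ balances the two terms.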
The implementation in the code effectively applies different weights to the terms (the total loss is dominated by the second, proposal, part), and the loc loss is computed only for positive labels (label 0 marks the background, whose loc loss is skipped).
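Concretely, the smooth_l1_loss below implements, for each coordinate difference $d$,

$$\text{smooth}_{L1}(d) = \begin{cases} \frac{\sigma^2}{2} d^2, & |d| < \frac{1}{\sigma^2} \\ |d| - \frac{0.5}{\sigma^2}, & \text{otherwise} \end{cases}$$

which reduces to the paper's $L_{reg}$ when $\sigma = 1$; the RPN uses $\sigma = 3$, which moves the quadratic-to-linear switch point closer to zero.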
import torch


def smooth_l1_loss(x, t, in_weight, sigma):
    """
    function description: compute the smooth L1 loss
    :param x: predicted locations
    :param t: annotated locations
    :param in_weight: selection matrix, zero everywhere except at positive samples
    :param sigma: scale factor of the smooth L1 transition point
    :return:
    """
    sigma2 = sigma ** 2
    diff = in_weight * (x - t)
    abs_diff = diff.abs()
    # flag is 1 where the quadratic branch applies and 0 where the linear branch applies
    flag = (abs_diff.data < (1. / sigma2)).float()
    y = (flag * (sigma2 / 2.) * (diff ** 2) + (1 - flag) * (abs_diff - 0.5 / sigma2))
    return y.sum()


def loc_loss(pred_loc, gt_loc, gt_label, sigma):
    """
    function description: compute the loc loss over the positive samples only
    :param pred_loc: predicted locations
    :param gt_loc: annotated locations
    :param gt_label: annotated classes
    :param sigma: scale factor passed through to smooth_l1_loss
    :return:
    """
    in_weight = torch.zeros(gt_loc.shape).cuda()
    # selection matrix of shape [gt_label_num, 4]: rows of positive samples are set to 1
    in_weight[(gt_label > 0).view(-1, 1).expand_as(in_weight)] = 1
    loss = smooth_l1_loss(pred_loc, gt_loc, in_weight.detach(), sigma)
    # normalize by the number of non-ignored samples (label >= 0)
    loss /= ((gt_label >= 0).sum().float())
    return loss
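A quick sanity check of loc_loss with toy tensors; a CUDA device is required because the mask is allocated with .cuda():

import torch

pred = torch.zeros(4, 4).cuda()
gt = torch.ones(4, 4).cuda()
labels = torch.tensor([1, 0, -1, 2]).cuda()  # one background (0) and one ignored (-1) sample
# only rows 0 and 3 contribute: 8 coords with |d| = 1 hit the linear branch (0.5 each, 4.0 total),
# then the sum is divided by the 3 samples with label >= 0
print(loc_loss(pred, gt, labels, sigma=1.))  # tensor(1.3333, ...)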
II. Faster-RCNN code
You may want to revisit the networks from the previous article; this part simply wires those networks together and computes the losses.
from torch import nn
import torch.nn.functional as F
from nets.vgg16 import decom_VGG16
from nets.rpn import RPN
from nets.anchor_target_creator import AnchorTargetCreator
from nets.proposal_target_creator import ProposalTargetCreator
from nets.fast_rcnn import FastRCNN
from utils.util import loc_loss, loc2box, non_maximum_suppression
from collections import namedtuple
import torch
import numpy as np
from configs.config import class_num, device_name

LossTuple = namedtuple('LossTuple',
                       ['rpn_loc_loss',
                        'rpn_cls_loss',
                        'roi_loc_loss',
                        'roi_cls_loss',
                        'total_loss'
                        ])

device = torch.device(device_name)


class FasterRCNN(nn.Module):
    def __init__(self, path):
        super(FasterRCNN, self).__init__()
        self.extractor, classifier = decom_VGG16(path)
        self.rpn = RPN()
        self.anchor_target_creator = AnchorTargetCreator()
        self.proposal_target_creator = ProposalTargetCreator()
        self.fast_rcnn = FastRCNN(n_class=class_num, roi_size=7, spatial_scale=1. / 16, classifier=classifier)
        # scale factors used by l1_smooth_loss
        self.rpn_sigma = 3.
        self.roi_sigma = 1.

    def forward(self, x, gt_boxes, labels):
        # -----------------part 1: feature extraction----------------------
        h = self.extractor(x)

        # -----------------part 2: rpn (output_1)----------------------
        img_size = (x.size(2), x.size(3))
        # rpn_locs shape: [batch_size, w*h*k, 4], a pytorch tensor
        # rpn_scores shape: [batch_size, w*h*k, 2], a pytorch tensor
        # anchors shape: [w*h*k, 4], a numpy array
        # rois shape: [rois_num, 4]
        rpn_locs, rpn_scores, anchors, rois = self.rpn(h, img_size)
        # gt_anchor_locs shape: [anchors_num, 4], gt_anchor_labels shape: [anchors_num]
        # a gt_anchor_label of 1 means the anchor contains an object, 0 means it does not
        gt_anchor_locs, gt_anchor_labels = self.anchor_target_creator(gt_boxes[0].detach().cpu().numpy(),
                                                                      anchors,
                                                                      img_size)

        # ----------------part 3: roi sampling----------------------------
        # gt_roi_labels holds the class each sampled roi belongs to
        sample_rois, gt_roi_labels, gt_roi_locs = self.proposal_target_creator(rois,
                                                                               gt_boxes[0].detach().cpu().numpy(),
                                                                               labels[0].detach().cpu().numpy())

        # ---------------part 4: fast rcnn (roi) head (output_2)------------
        # roi_locs shape: [sample_rois_num, n_class * 4], roi_scores shape: [sample_rois_num, n_class]
        roi_locs, roi_scores = self.fast_rcnn(h, sample_rois)

        # RPN LOSS
        gt_anchor_locs = torch.from_numpy(gt_anchor_locs).to(device)
        gt_anchor_labels = torch.from_numpy(gt_anchor_labels).long().to(device)
        # the [0] index picks the predictions of the single image in the batch;
        # anchors labelled -1 do not take part in the loss (ignore_index=-1)
        rpn_cls_loss = F.cross_entropy(rpn_scores[0], gt_anchor_labels, ignore_index=-1)
        rpn_loc_loss = loc_loss(rpn_locs[0], gt_anchor_locs, gt_anchor_labels, self.rpn_sigma)

        # ROI LOSS
        gt_roi_labels = torch.from_numpy(gt_roi_labels).long().to(device)
        gt_roi_locs = torch.from_numpy(gt_roi_locs).float().to(device)
        roi_cls_loss = F.cross_entropy(roi_scores, gt_roi_labels)
        n_sample = roi_locs.shape[0]  # number of sampled rois
        roi_cls_locs = roi_locs.view(n_sample, -1, 4)
        # pick, for each sampled roi, the loc predicted for its ground-truth class
        roi_locs = roi_cls_locs[torch.arange(0, n_sample).long(), gt_roi_labels]
        roi_loc_loss = loc_loss(roi_locs.contiguous(), gt_roi_locs, gt_roi_labels, self.roi_sigma)

        losses = [rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss]
        losses = losses + [sum(losses)]
        return LossTuple(*losses)

    @torch.no_grad()
    def predict(self, x):
        # switch to eval mode, which lowers the rpn's n_post_nms threshold to 300
        self.eval()
        # -----------------part 1: feature extraction----------------------
        h = self.extractor(x)
        img_size = (x.size(2), x.size(3))

        # ----------------------part 2: rpn--------------------------
        rpn_locs, rpn_scores, anchors, rois = self.rpn(h, img_size)

        # ------------------part 3: fast rcnn (roi) head-------------------
        # first the RoI pooling layer, then the two fully connected layers
        roi_locs, roi_scores = self.fast_rcnn(h, np.asarray(rois))
        n_sample = roi_locs.shape[0]

        # --------------------part 4: box decoding-----------------------
        roi_cls_locs = roi_locs.view(n_sample, -1, 4)
        rois = torch.from_numpy(rois).to(device)
        rois = rois.view(-1, 1, 4).expand_as(roi_cls_locs)
        boxes = loc2box(rois.cpu().numpy().reshape((-1, 4)), roi_cls_locs.cpu().numpy().reshape((-1, 4)))
        boxes = torch.from_numpy(boxes).to(device)
        # clip box coordinates so they lie inside the image
        boxes[:, [0, 2]] = (boxes[:, [0, 2]]).clamp(min=0, max=img_size[0])
        boxes[:, [1, 3]] = (boxes[:, [1, 3]]).clamp(min=0, max=img_size[1])
        boxes = boxes.view(n_sample, -1)
        # turn roi_scores into probabilities; prob shape: [rois_num, class_num]
        prob = F.softmax(roi_scores, dim=1)

        # ----------------part 5: filtering------------------------
        raw_boxes = boxes.cpu().numpy()
        raw_prob = prob.cpu().numpy()
        final_boxes, labels, scores = self._suppress(raw_boxes, raw_prob)
        self.train()
        return final_boxes, labels, scores

    def _suppress(self, raw_boxes, raw_prob):
        score_thresh = 0.7
        nms_thresh = 0.3
        n_class = class_num
        box = list()
        label = list()
        score = list()
        # class 0 is the background, so start from 1
        for i in range(1, n_class):
            box_i = raw_boxes.reshape((-1, n_class, 4))
            box_i = box_i[:, i, :]  # shape: [rois_num, 4]
            prob_i = raw_prob[:, i]  # shape: [rois_num]
            mask = prob_i > score_thresh
            box_i = box_i[mask]
            prob_i = prob_i[mask]
            # sort by score from high to low; boxes and scores must stay aligned
            order = prob_i.argsort()[::-1]
            box_i = box_i[order]
            prob_i = prob_i[order]
            box_i_after_nms, keep = non_maximum_suppression(box_i, nms_thresh)
            box.append(box_i_after_nms)
            # returned labels are 0-based with the background class removed
            label_i = (i - 1) * np.ones((len(keep),))
            label.append(label_i)
            score.append(prob_i[keep])
        box = np.concatenate(box, axis=0).astype(np.float32)
        label = np.concatenate(label, axis=0).astype(np.int32)
        score = np.concatenate(score, axis=0).astype(np.float32)
        return box, label, score
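To show how the pieces fit together, here is a minimal training-step sketch; the weight path, the train_loader and the optimizer settings are placeholders, not taken from the repo:

import torch

device = torch.device('cuda')
model = FasterRCNN('vgg16.pth').to(device)  # 'vgg16.pth' is a placeholder for the pretrained VGG16 weights
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

for batch in train_loader:  # a DataLoader over the ImageDataset from section III, batch_size=1
    img = batch['img_tensor'].to(device)
    gt_boxes = batch['img_gt_boxes'].to(device)
    labels = batch['img_classes'].to(device)
    optimizer.zero_grad()
    losses = model(img, gt_boxes, labels)  # a LossTuple
    losses.total_loss.backward()           # total_loss is the sum of the four partial losses
    optimizer.step()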
III. The dataset
1. Generating the txt files
I wrote the txt-generation code in two variants.
① The dataset ships txt annotations
The first variant assumes the dataset ships its annotations as txt files; the functions below generate the four split txt files and write the corresponding xml annotations.
from lxml import etree as ET
import glob
import cv2
import random
# train_label_path is assumed to be defined in configs.config alongside the other path constants
from configs.config import classes_for_label, xml_root_dir, img_root_dir, txt_root_dir, pic_format, train_label_path
import numpy as np
from PIL import Image


def write_xml(filename, saveimg, typename, boxes, xmlpath):
    """
    function description: convert a txt annotation file into xml
    :param filename: image name
    :param saveimg: image tensor read by opencv
    :param typename: class names
    :param boxes: top-left and bottom-right coordinates
    :param xmlpath: name of the xml file to save
    """
    # root node
    root = ET.Element("annotation")
    # folder node
    folder_node = ET.SubElement(root, 'folder')
    folder_node.text = 'kitti'
    # filename node
    filename_node = ET.SubElement(root, 'filename')
    filename_node.text = filename
    # source node
    source_node = ET.SubElement(root, 'source')
    database_node = ET.SubElement(source_node, 'database')
    database_node.text = 'kitti Database'
    annotation_node = ET.SubElement(source_node, 'annotation')
    annotation_node.text = 'kitti'
    image_node = ET.SubElement(source_node, 'image')
    image_node.text = 'flickr'
    flickrid_node = ET.SubElement(source_node, 'flickrid')
    flickrid_node.text = '-1'
    # owner node
    owner_node = ET.SubElement(root, 'owner')
    flickrid_node = ET.SubElement(owner_node, 'flickrid')
    flickrid_node.text = 'muke'
    name_node = ET.SubElement(owner_node, 'name')
    name_node.text = 'muke'
    # size node
    size_node = ET.SubElement(root, 'size')
    width_node = ET.SubElement(size_node, 'width')
    width_node.text = str(saveimg.shape[1])
    height_node = ET.SubElement(size_node, 'height')
    height_node.text = str(saveimg.shape[0])
    depth_node = ET.SubElement(size_node, 'depth')
    depth_node.text = str(saveimg.shape[2])
    # segmented node (used for image segmentation)
    segmented_node = ET.SubElement(root, 'segmented')
    segmented_node.text = '0'
    # object nodes (added in a loop)
    for i in range(len(typename)):
        object_node = ET.SubElement(root, 'object')
        name_node = ET.SubElement(object_node, 'name')
        name_node.text = typename[i]
        pose_node = ET.SubElement(object_node, 'pose')
        pose_node.text = 'Unspecified'
        # whether the object is truncated
        truncated_node = ET.SubElement(object_node, 'truncated')
        truncated_node.text = '1'
        difficult_node = ET.SubElement(object_node, 'difficult')
        difficult_node.text = '0'
        bndbox_node = ET.SubElement(object_node, 'bndbox')
        xmin_node = ET.SubElement(bndbox_node, 'xmin')
        xmin_node.text = str(boxes[i][0])
        ymin_node = ET.SubElement(bndbox_node, 'ymin')
        ymin_node.text = str(boxes[i][1])
        xmax_node = ET.SubElement(bndbox_node, 'xmax')
        xmax_node.text = str(boxes[i][2])
        ymax_node = ET.SubElement(bndbox_node, 'ymax')
        ymax_node.text = str(boxes[i][3])
    tree = ET.ElementTree(root)
    tree.write(xmlpath, pretty_print=True)
def split_dataset_byTXT():
    """
    function description: split the dataset into train, val and test sets based on the txt annotations of the full training set, and write the corresponding xml annotations
    """
    trainval = open(txt_root_dir + 'trainval.txt', 'w')
    train = open(txt_root_dir + 'train.txt', 'w')
    val = open(txt_root_dir + 'val.txt', 'w')
    test = open(txt_root_dir + 'train_test.txt', 'w')
    list_anno_files = glob.glob(train_label_path + "*")
    random.shuffle(list_anno_files)
    index = 0
    for anno_file in list_anno_files:
        with open(anno_file) as file:
            boxes = []
            typename = []
            anno_infos = file.readlines()
            for anno_item in anno_infos:
                anno_new_infos = anno_item.split(" ")
                # drop the Misc and DontCare classes
                if anno_new_infos[0] == "Misc" or anno_new_infos[0] == "DontCare":
                    continue
                else:
                    box = (int(float(anno_new_infos[4])), int(float(anno_new_infos[5])),
                           int(float(anno_new_infos[6])), int(float(anno_new_infos[7])))
                    boxes.append(box)
                    typename.append(anno_new_infos[0])
            filename = anno_file.split("\\")[-1].replace(".txt", pic_format)
            xmlpath = xml_root_dir + filename.replace(pic_format, ".xml")
            imgpath = img_root_dir + 'training/' + filename
            print(imgpath)
            saveimg = cv2.imread(imgpath)
            write_xml(filename, saveimg, typename, boxes, xmlpath)
            index += 1
            # the last 10% go to the test split; the rest go to trainval,
            # which is further divided into train and val at the 70% mark
            if index > len(list_anno_files) * 0.9:
                test.write(filename.replace(pic_format, "\n"))
            else:
                trainval.write(filename.replace(pic_format, "\n"))
                if index > len(list_anno_files) * 0.7:
                    val.write(filename.replace(pic_format, "\n"))
                else:
                    train.write(filename.replace(pic_format, "\n"))
    trainval.close()
    train.close()
    val.close()
    test.close()
② The dataset ships xml annotations
The second variant assumes the annotations are already xml files, so it only needs to generate the txt split files from the file names.
def split_dataset_byXML():
    """
    function description: split the dataset into train, val and test sets based on the xml annotation files of the full training set
    """
    trainval = open(txt_root_dir + 'trainval.txt', 'w')
    train = open(txt_root_dir + 'train.txt', 'w')
    val = open(txt_root_dir + 'val.txt', 'w')
    train_test = open(txt_root_dir + 'train_test.txt', 'w')
    list_anno_files = glob.glob(xml_root_dir + "*")
    random.shuffle(list_anno_files)
    index = 0
    for anno_file in list_anno_files:
        # keep only the file name, as in the txt variant above
        filename = anno_file.split("\\")[-1].replace(".xml", pic_format)
        index += 1
        if index > len(list_anno_files) * 0.9:
            train_test.write(filename.replace(pic_format, "\n"))
        else:
            trainval.write(filename.replace(pic_format, "\n"))
            if index > len(list_anno_files) * 0.7:
                val.write(filename.replace(pic_format, "\n"))
            else:
                train.write(filename.replace(pic_format, "\n"))
    trainval.close()
    train.close()
    val.close()
    train_test.close()
2. Implementing the Dataset class
Since a real test set has no annotations, __getitem__ cannot return the same content for training and testing. To reuse as much code as possible between the two, the constructor takes a flag that marks whether the dataset is train or test. Faster-RCNN also constrains the input size: the shortest side must be at least 600px, otherwise RoI pooling runs into trouble. Resizing every image to a fixed size made predictions inaccurate, so the code uses a reshape helper instead.
Note: simply reshaping the training images would be badly wrong; you must scale the annotated boxes by the same ratio as the image! The code below does this. It loads all tensors and annotations into memory up front, so the memory footprint is large.
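The reshape helper is imported from data.process_data and is not listed in this post; a minimal sketch of what it must do, assuming it returns the resized image and the scaled boxes in that order, could look like this:

import numpy as np
from PIL import Image


def reshape(img, boxes, min_size=600):
    # scale factor that brings the shortest side to min_size
    w, h = img.size
    scale = min_size / min(w, h)
    img = img.resize((int(w * scale), int(h * scale)), Image.BILINEAR)
    # box coordinates scale by the same factor as the image
    boxes = np.asarray(boxes, dtype=np.float32) * scale
    return img, boxes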
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from data.process_data import parse_xml, reshape
import numpy as np
from PIL import Image
from configs.config import pic_format


class ImageDataset(Dataset):
    def __init__(self, xml_root_dir, img_root_dir, txt_root_dir, txt_file, isTest=False, transform=None):
        """
        class description: this class already rescales the shortest side to 600px and scales the training annotations by the same ratio
        :param xml_root_dir: root path of the xml annotation files
        :param img_root_dir: root path of the images
        :param txt_root_dir: root path of the txt files
        :param txt_file: txt file name
        :param isTest: flags whether this is the test set
        :param transform: transforms to apply
        """
        super(ImageDataset, self).__init__()
        self.xml_root_dir = xml_root_dir
        self.img_root_dir = img_root_dir
        self.txt_root_dir = txt_root_dir
        self.txt_file = txt_file
        self.isTest = isTest
        if transform is None:
            self.transform = transforms.Compose([
                # the Resize below was the source of a bug: forcing a fixed size to fit
                # vgg16's input hurt the predictions, so it stays disabled
                # transforms.Resize((int(224), int(224))),
                transforms.ToTensor(),
                transforms.Normalize(
                    [0.485, 0.456, 0.406],
                    [0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transform
        if not self.isTest:
            boxes, labels, images = self.load_txt(self.txt_file)
            self.boxes = boxes
            self.labels = labels
            self.images = images
        else:
            self.images = self.load_txt(self.txt_file)
        id_list_files = os.path.join(txt_root_dir, txt_file)
        self.ids = [id_.strip() for id_ in open(id_list_files)]

    def load_txt(self, filename):
        """
        function description: load the information listed in the txt file into lists held in memory
        :param filename: txt file name
        """
        print('-------------the file name is ', filename)
        boxes = []
        labels = []
        images = []
        print(os.path.join(self.txt_root_dir, filename))
        with open(os.path.join(self.txt_root_dir, filename), mode='r') as f:
            lines = f.readlines()
            for line in lines:
                line = line.strip()
                if not self.isTest:
                    box, label, image = self.load_xml(line + ".xml")
                    boxes.append(box)
                    labels.append(label)
                else:
                    image = (line + pic_format)
                # the image (loaded tensor or file name) is recorded for both branches
                images.append(image)
        if not self.isTest:
            print('the length of boxes is ', len(boxes))
            print('the length of labels is ', len(labels))
            print('the length of images is ', len(images))
            return boxes, labels, images
        return images

    def load_xml(self, filename):
        """
        function description: load the needed attributes from the xml file and rescale the shortest side to 600
        :param filename: xml file name
        """
        path = os.path.join(self.xml_root_dir, filename)
        if not os.path.exists(path):
            return
        boxes, labels = parse_xml(path)
        img_name = filename.replace(".xml", pic_format)
        images, boxes = reshape(Image.open(self.img_root_dir + img_name), boxes)
        return np.stack(boxes).astype(np.float32), \
               np.stack(labels).astype(np.int32), \
               images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        if not self.isTest:
            id = self.ids[index]
            box, label, image = self.load_xml('{0}.xml'.format(id))
            img_tensor = self.transform(image)
            # [channel, height, width] -> [channel, width, height]
            img_tensor = img_tensor.permute(0, 2, 1)
            return {
                "img_name": id + pic_format,
                "img_tensor": img_tensor,
                "img_classes": label,
                "img_gt_boxes": box
            }
        else:
            img = Image.open(self.img_root_dir + self.images[index])
            img_tensor = self.transform(img)
            img_tensor = img_tensor.permute(0, 2, 1)
            return {
                "img_name": self.images[index],
                "img_tensor": img_tensor,
            }
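A sketch of wiring the dataset into a DataLoader; the directory constants are assumed to come from configs.config, and batch_size stays 1 because the model processes one image per forward pass:

from torch.utils.data import DataLoader
from configs.config import xml_root_dir, img_root_dir, txt_root_dir

train_set = ImageDataset(xml_root_dir, img_root_dir, txt_root_dir, 'train.txt')
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
batch = next(iter(train_loader))
print(batch['img_tensor'].shape)  # [1, 3, width, height] after the permute in __getitem__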
IV. Testing
The figure below shows predictions on the KITTI dataset after training this code for one epoch.
The results look decent.