Faster-RCNN Complete Walkthrough (Hands-On Code Analysis): Final Part

Code: https://github.com/xiguanlezz/Faster-RCNN



I. Backpropagation

       Faster-RCNN's loss consists of two parts: the first is the loss between the prior boxes (anchors) and their corresponding targets (anchors_target); the second is the loss between the proposals and their corresponding targets (proposals_target).


1. Anchor loss

       The idea is to first discard the anchors that fall outside the image, then assign labels based on IoU; the `__call__` method computes both the regression targets and the labels.

	import numpy as np
	from utils.util import calculate_iou, get_inside_index, box2loc


	class AnchorTargetCreator:
	    def __init__(self,
	                 n_sample=256,
	                 pos_iou_thresh=0.7,
	                 neg_iou_thresh=0.3,
	                 pos_ratio=0.5):
	        """
	        function description: AnchorTargetCreator constructor

	        :param n_sample: 256, total number of targets
	        :param pos_iou_thresh: IoU threshold against the ground-truth boxes; above it an anchor is a "positive" sample (label 1)
	        :param neg_iou_thresh: IoU threshold against the ground-truth boxes; below it an anchor is a "negative" sample (label 0)
	        :param pos_ratio: fraction of "positive" samples among the targets
	        """
	        self.n_sample = n_sample
	        self.pos_iou_thresh = pos_iou_thresh
	        self.neg_iou_thresh = neg_iou_thresh
	        self.pos_ratio = pos_ratio  # fraction of positives; if there are not enough positives, negatives fill the quota

	    def __call__(self, boxes, anchors, img_size):
	        """
	        function description: compute the regression targets and labels for the anchors

	        :param boxes: top-left and bottom-right coordinates of the ground-truth boxes, shape: [boxes_num, 4]
	        :param anchors: coordinates of all anchors generated from the feature map, shape: [anchors_num, 4]
	        :param img_size: size of the original image, used to filter out anchors that cross the image boundary
	        :return:
	            anchor_locs: final regression targets, unmapped back to shape [anchors_num, 4]
	            anchor_labels: final labels, unmapped back to shape [anchors_num]
	        """
	        img_width, img_height = img_size

	        inside_index = get_inside_index(anchors, img_width, img_height)
	        # keep only the anchors that lie fully inside the image
	        inside_anchors = anchors[inside_index]
	        # both returned arrays have shape [inside_anchors_num]: for each anchor, the index of
	        # its highest-IoU ground-truth box, plus the assigned label
	        argmax_ious, labels = self._create_label(inside_anchors, boxes)

	        # regression targets between inside_anchors and their highest-IoU ground-truth boxes
	        locs = box2loc(inside_anchors, boxes[argmax_ious])

	        anchors_num = len(anchors)

	        # unmap the labels back to the full anchor set, which makes it easy to compute
	        # the first (anchor) part of the loss
	        anchor_labels = np.empty((anchors_num,), dtype=labels.dtype)
	        anchor_labels.fill(-1)
	        anchor_labels[inside_index] = labels
	        # likewise unmap locs back to the full anchor set
	        anchor_locs = np.empty((anchors_num,) + locs.shape[1:], dtype=locs.dtype)
	        anchor_locs.fill(0)
	        anchor_locs[inside_index, :] = locs
	        return anchor_locs, anchor_labels

	    def _create_label(self, inside_anchors, boxes):
	        """
	        function description: create a label for every inside_anchor: 1 is positive, 0 is negative, -1 is ignored
	                              labelling rules:
	                                1. the anchor with the highest IoU for each ground-truth box is positive;
	                                2. anchors whose IoU with some ground-truth box exceeds pos_iou_thresh are positive;
	                                3. anchors whose IoU with every ground-truth box is below neg_iou_thresh are negative

	        :param inside_anchors: anchors inside the image, shape: [inside_anchors_num, 4]
	        :param boxes: ground-truth boxes in the image, shape: [boxes_num, 4]
	        :return:
	            argmax_ious: for each anchor, the index of its highest-IoU ground-truth box, shape: [inside_anchors_num]
	            label: the label created for each inside_anchor, shape: [inside_anchors_num]
	        """
	        # one label per anchor inside the image
	        label = np.empty((len(inside_anchors)), dtype=np.int32)
	        # initialise all labels to -1 (ignored by default)
	        label.fill(-1)

	        # argmax_ious and max_ious have shape [inside_anchors_num]
	        argmax_ious, max_ious, gt_argmax_ious = self._calculate_iou(inside_anchors, boxes)

	        # assign negatives first so the later assignments can override them (order matters):
	        # anchors below the negative threshold get label 0; rule (3)
	        label[max_ious < self.neg_iou_thresh] = 0
	        # anchors achieving the highest IoU for some ground-truth box are positive, so every
	        # ground-truth box gets at least one anchor; rule (1)
	        label[gt_argmax_ious] = 1
	        # anchors above the positive threshold get label 1; rule (2)
	        label[max_ious >= self.pos_iou_thresh] = 1

	        # the code below balances positives and negatives so the total stays at 256
	        # (anchors labelled -1 are ignored)
	        pos_standard = int(self.pos_ratio * self.n_sample)
	        pos_num = np.where(label == 1)[0]  # indices of positive anchors
	        if len(pos_num) > pos_standard:
	            # replace=False means indices are sampled without repetition
	            disable_index = np.random.choice(pos_num, size=(len(pos_num) - pos_standard), replace=False)
	            label[disable_index] = -1
	        neg_standard = self.n_sample - np.sum(label == 1)  # number of negatives to keep
	        neg_num = np.where(label == 0)[0]
	        if len(neg_num) > neg_standard:
	            disable_index = np.random.choice(neg_num, size=(len(neg_num) - neg_standard), replace=False)
	            label[disable_index] = -1
	        return argmax_ious, label

	    def _calculate_iou(self, inside_anchors, boxes):
	        """
	        function description: from the 2-D IoU matrix, get for each anchor the index of its
	                              highest-IoU ground-truth box together with that IoU value

	        :param inside_anchors: anchors inside the image
	        :param boxes: ground-truth boxes in the image
	        :return:
	            argmax_ious: for each inside_anchor, the index of the box with the highest IoU, shape: [inside_anchors_num]
	            max_ious: for each inside_anchor, the highest IoU over all boxes, shape: [inside_anchors_num]
	            gt_argmax_ious: indices of the anchors achieving each box's highest IoU
	                            (at least boxes_num entries; more when there are ties)
	        """
	        # first dimension is the number of anchors (inside_anchors_num), second is the number of boxes (boxes_num)
	        ious = calculate_iou(inside_anchors, boxes)

	        argmax_ious = ious.argmax(axis=1)  # shape: [inside_anchors_num]
	        # the highest IoU each anchor has with any ground-truth box
	        # (note: len(ious) == len(inside_anchors))
	        max_ious = ious[np.arange(len(ious)), argmax_ious]

	        gt_argmax_ious = ious.argmax(axis=0)  # shape: [boxes_num]
	        # the highest IoU each ground-truth box has with any anchor
	        gt_max_ious = ious[gt_argmax_ious, np.arange(ious.shape[1])]

	        # expand to all anchors that tie with the per-box maximum IoU
	        gt_argmax_ious = np.where(ious == gt_max_ious)[0]
	        return argmax_ious, max_ious, gt_argmax_ious
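
       The two helpers imported from utils.util, calculate_iou and box2loc, are not shown in this post. For readers without the repo open, here is a self-contained sketch of what they presumably compute; box2loc is assumed to implement the standard Faster-RCNN box parameterization (tx, ty, tw, th) from the paper, and the _sketch names are mine, not the repo's:

	import numpy as np

	def calculate_iou_sketch(anchors, boxes):
	    """Pairwise IoU between anchors [N, 4] and boxes [K, 4] in corner format (x1, y1, x2, y2)."""
	    tl = np.maximum(anchors[:, None, :2], boxes[None, :, :2])  # top-left of each intersection
	    br = np.minimum(anchors[:, None, 2:], boxes[None, :, 2:])  # bottom-right of each intersection
	    inter = np.prod(np.clip(br - tl, 0, None), axis=2)         # intersection areas, [N, K]
	    area_a = np.prod(anchors[:, 2:] - anchors[:, :2], axis=1)  # anchor areas, [N]
	    area_b = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)      # box areas, [K]
	    return inter / (area_a[:, None] + area_b[None, :] - inter)

	def box2loc_sketch(anchors, boxes):
	    """Encode boxes relative to anchors as (tx, ty, tw, th), the paper's parameterization."""
	    w_a = anchors[:, 2] - anchors[:, 0]
	    h_a = anchors[:, 3] - anchors[:, 1]
	    cx_a = anchors[:, 0] + 0.5 * w_a
	    cy_a = anchors[:, 1] + 0.5 * h_a
	    w_b = boxes[:, 2] - boxes[:, 0]
	    h_b = boxes[:, 3] - boxes[:, 1]
	    cx_b = boxes[:, 0] + 0.5 * w_b
	    cy_b = boxes[:, 1] + 0.5 * h_b
	    tx = (cx_b - cx_a) / w_a           # centre offsets, normalised by anchor size
	    ty = (cy_b - cy_a) / h_a
	    tw = np.log(w_b / w_a)             # log-space scale factors
	    th = np.log(h_b / h_a)
	    return np.stack([tx, ty, tw, th], axis=1)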



2. Proposal loss

       The main logic here is keeping positive and negative samples balanced; the `__call__` method computes the regression targets and attaches class labels to the RoIs.

	import numpy as np
	from utils.util import calculate_iou, box2loc


	class ProposalTargetCreator:
	    def __init__(self,
	                 n_sample=128,
	                 pos_ratio=0.25,
	                 pos_iou_thresh=0.5,
	                 neg_iou_thresh_hi=0.5,
	                 neg_iou_thresh_lo=0.0):
	        """
	        function description: sample 128 positive/negative RoIs to feed into the FastRCNN head

	        :param n_sample: number of samples to draw
	        :param pos_ratio: fraction of positive samples
	        :param pos_iou_thresh: IoU threshold for positive samples
	        :param neg_iou_thresh_hi: upper IoU bound for negative samples
	        :param neg_iou_thresh_lo: lower IoU bound for negative samples
	        """
	        self.n_sample = n_sample
	        self.pos_ratio = pos_ratio
	        self.pos_iou_thresh = pos_iou_thresh
	        self.neg_iou_thresh_hi = neg_iou_thresh_hi
	        self.neg_iou_thresh_lo = neg_iou_thresh_lo

	    def __call__(self,
	                 rois,
	                 boxes,
	                 labels,
	                 loc_normalize_mean=(0., 0., 0., 0.),
	                 loc_normalize_std=(0.1, 0.1, 0.2, 0.2)):
	        """
	        function description: get the sampled rois together with their labels and regression targets

	        :param rois: rois produced by the RPN
	        :param boxes: ground-truth box annotations of one image
	        :param labels: class annotations of one image
	        :param loc_normalize_mean: mean (currently unused here; the reference implementation normalizes gt_roi_locs with it)
	        :param loc_normalize_std: standard deviation (likewise currently unused here)
	        :return:
	            sample_rois: the sampled regions of interest
	            gt_roi_labels: class labels assigned to the sampled rois
	            gt_roi_locs: regression targets between sample_rois and boxes
	        """
	        n_bbox, _ = boxes.shape

	        # number of positive samples to aim for (rounded)
	        pos_num = np.round(self.n_sample * self.pos_ratio)

	        ious = calculate_iou(rois, boxes)
	        gt_assignment = ious.argmax(axis=1)  # shape: [rois_num]
	        max_iou = ious.max(axis=1)

	        gt_roi_labels = labels[gt_assignment]  # shape: [rois_num]

	        # keep the rois whose IoU clears the positive threshold
	        pos_index = np.where(max_iou >= self.pos_iou_thresh)[0]
	        pos_num_for_this_image = int(min(pos_num, pos_index.size))
	        if pos_index.size > 0:
	            pos_index = np.random.choice(pos_index, size=pos_num_for_this_image, replace=False)
	        # keep the rois whose IoU falls inside the negative interval
	        neg_index = np.where((max_iou < self.neg_iou_thresh_hi) & (max_iou >= self.neg_iou_thresh_lo))[0]
	        neg_num = self.n_sample - pos_num_for_this_image
	        neg_num_for_this_image = int(min(neg_index.size, neg_num))
	        if neg_index.size > 0:
	            neg_index = np.random.choice(neg_index, size=neg_num_for_this_image, replace=False)

	        keep_index = np.append(pos_index, neg_index)
	        gt_roi_labels = gt_roi_labels[keep_index]
	        gt_roi_labels[pos_num_for_this_image:] = 0  # background is class 0: every index from pos_num_for_this_image onward
	        sample_rois = rois[keep_index]

	        gt_roi_locs = box2loc(sample_rois, boxes[gt_assignment[keep_index]])

	        return sample_rois, gt_roi_labels, gt_roi_locs
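
       To make the sampling budget concrete, the snippet below (my own standalone demo, with made-up IoU values) replays the balancing arithmetic: with n_sample=128 and pos_ratio=0.25, at most round(128 * 0.25) = 32 RoIs can be positive, and negatives fill whatever remains of the 128:

	import numpy as np

	np.random.seed(0)
	n_sample, pos_ratio = 128, 0.25
	max_iou = np.random.rand(2000)          # pretend per-RoI best IoUs for 2000 RoIs

	pos_index = np.where(max_iou >= 0.5)[0]
	pos_kept = min(int(np.round(n_sample * pos_ratio)), pos_index.size)  # capped at 32

	neg_index = np.where((max_iou < 0.5) & (max_iou >= 0.0))[0]
	neg_kept = min(n_sample - pos_kept, neg_index.size)                  # fills up to 128

	print(pos_kept, neg_kept)  # 32 96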



3. Total loss

       First, here is how the paper defines the total loss:

$$L(\{p_i\},\{t_i\}) = \frac{1}{N_{cls}}\sum_i L_{cls}(p_i, p_i^*) + \lambda\,\frac{1}{N_{reg}}\sum_i p_i^*\,L_{reg}(t_i, t_i^*)$$

where $p_i$ is the predicted objectness of anchor $i$, $p_i^*$ its ground-truth label, $t_i$ and $t_i^*$ the predicted and target box offsets, and $L_{reg}$ is the smooth L1 loss.


       The implementation effectively applies different weights to the two terms (the sigma coefficients below, which mainly rescale the second, location part of the loss). Also, the location loss is computed only for positive samples: label 0 means background, and background boxes are skipped in the regression loss.

$$\mathrm{smooth}_{L1}(d) = \begin{cases} \dfrac{\sigma^2}{2}\,d^2, & |d| < \dfrac{1}{\sigma^2} \\[4pt] |d| - \dfrac{0.5}{\sigma^2}, & \text{otherwise} \end{cases}$$


	import torch


	def smooth_l1_loss(x, t, in_weight, sigma):
	    """
	    function description: compute the smooth L1 loss

	    :param x: predicted locations
	    :param t: annotated (target) locations
	    :param in_weight: selection mask, zero everywhere except at positive samples
	    :param sigma: controls where the loss switches from quadratic to linear
	    :return:
	    """
	    sigma2 = sigma ** 2
	    diff = in_weight * (x - t)
	    abs_diff = diff.abs()
	    flag = (abs_diff.data < (1. / sigma2)).float()
	    # piecewise: quadratic for |d| < 1/sigma^2, linear beyond that point
	    y = (flag * (sigma2 / 2.) * (diff ** 2) + (1 - flag) * (abs_diff - 0.5 / sigma2))
	    return y.sum()


	def loc_loss(pred_loc, gt_loc, gt_label, sigma):
	    """
	    function description: compute the loc loss over positive samples only

	    :param pred_loc: predicted locations
	    :param gt_loc: annotated (target) locations
	    :param gt_label: annotated classes
	    :param sigma: passed through to smooth_l1_loss
	    :return:
	    """
	    in_weight = torch.zeros(gt_loc.shape).cuda()
	    # selection mask of shape [gt_label_num, 4]: rows of positive samples are set to 1
	    in_weight[(gt_label > 0).view(-1, 1).expand_as(in_weight)] = 1
	    loc_loss = smooth_l1_loss(pred_loc, gt_loc, in_weight.detach(), sigma)

	    # normalise by the number of non-ignored samples (labels >= 0)
	    loc_loss /= ((gt_label >= 0).sum().float())
	    return loc_loss
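
       As a quick sanity check (my addition, not part of the repo): when in_weight is all ones, this formulation coincides with PyTorch's built-in F.smooth_l1_loss using beta = 1 / sigma ** 2 and sum reduction. The beta argument requires PyTorch 1.7 or newer, and smooth_l1_loss here refers to the function defined above:

	import torch
	import torch.nn.functional as F

	sigma = 3.
	x = torch.randn(8, 4)
	t = torch.randn(8, 4)
	ones = torch.ones_like(x)            # mask of all ones: every element contributes

	mine = smooth_l1_loss(x, t, ones, sigma)
	ref = F.smooth_l1_loss(x, t, reduction='sum', beta=1. / sigma ** 2)
	print(torch.allclose(mine, ref))     # True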




II. Faster-RCNN Code

       It may help to revisit the networks from the previous article; this part simply wires those networks together and computes the loss.

	from torch import nn
	import torch.nn.functional as F
	from nets.vgg16 import decom_VGG16
	from nets.rpn import RPN
	from nets.anchor_target_creator import AnchorTargetCreator
	from nets.proposal_target_creator import ProposalTargetCreator
	from nets.fast_rcnn import FastRCNN
	from utils.util import loc_loss
	from collections import namedtuple
	import torch
	from utils.util import loc2box, non_maximum_suppression
	import numpy as np
	from configs.config import class_num, device_name

	LossTuple = namedtuple('LossTuple',
	                       ['rpn_loc_loss',
	                        'rpn_cls_loss',
	                        'roi_loc_loss',
	                        'roi_cls_loss',
	                        'total_loss'
	                        ])

	device = torch.device(device_name)


	class FasterRCNN(nn.Module):
	    def __init__(self, path):
	        super(FasterRCNN, self).__init__()

	        self.extractor, classifier = decom_VGG16(path)
	        self.rpn = RPN()
	        self.anchor_target_creator = AnchorTargetCreator()
	        self.sample_rois = ProposalTargetCreator()

	        self.fast_rcnn = FastRCNN(n_class=class_num, roi_size=7, spatial_scale=1. / 16, classifier=classifier)
	        # coefficients used by l1_smooth_loss
	        self.rpn_sigma = 3.
	        self.roi_sigma = 1.

	    def forward(self, x, gt_boxes, labels):
	        # -----------------part 1: feature extraction----------------------
	        h = self.extractor(x)

	        # -----------------part 2: RPN (output_1)----------------------
	        img_size = (x.size(2), x.size(3))
	        # rpn_locs shape: [batch_size, w*h*k, 4], a pytorch tensor
	        # rpn_scores shape: [batch_size, w*h*k, 2], a pytorch tensor
	        # anchors shape: [w*h*k, 4], a numpy array
	        # rois shape: [n_post_nms, 4], the proposals kept after NMS
	        rpn_locs, rpn_scores, anchors, rois = self.rpn(h, img_size)
	        # gt_anchor_locs shape: [anchors_num, 4], gt_anchor_labels shape: [anchors_num]
	        # gt_anchor_labels is 1 when the anchor contains an object and 0 when it does not
	        gt_anchor_locs, gt_anchor_labels = self.anchor_target_creator(gt_boxes[0].detach().cpu().numpy(),
	                                                                      anchors,
	                                                                      img_size)

	        # ----------------part 3: RoI sampling----------------------------
	        # gt_roi_labels holds the class each sampled roi belongs to
	        sample_rois, gt_roi_labels, gt_roi_locs = self.sample_rois(rois,
	                                                                   gt_boxes[0].detach().cpu().numpy(),
	                                                                   labels[0].detach().cpu().numpy())

	        # ---------------part 4: fast rcnn (roi) head (output_2)------------
	        # roi_locs shape: [sample_rois_num, n_class * 4], roi_scores shape: [sample_rois_num, n_class]
	        roi_locs, roi_scores = self.fast_rcnn(h, sample_rois)

	        # RPN LOSS
	        gt_anchor_locs = torch.from_numpy(gt_anchor_locs).to(device)
	        gt_anchor_labels = torch.from_numpy(gt_anchor_labels).long().to(device)
	        # rpn_scores[0] are the scores of the first (and only) image in the batch;
	        # anchors labelled -1 are excluded from the loss via ignore_index
	        rpn_cls_loss = F.cross_entropy(rpn_scores[0], gt_anchor_labels, ignore_index=-1)
	        rpn_loc_loss = loc_loss(rpn_locs[0], gt_anchor_locs, gt_anchor_labels, self.rpn_sigma)

	        # ROI LOSS
	        gt_roi_labels = torch.from_numpy(gt_roi_labels).long().to(device)
	        gt_roi_locs = torch.from_numpy(gt_roi_locs).float().to(device)
	        roi_cls_loss = F.cross_entropy(roi_scores, gt_roi_labels)
	        n_sample = roi_locs.shape[0]  # number of sampled rois
	        roi_cls_locs = roi_locs.view(n_sample, -1, 4)
	        # pick, for every sampled roi, the 4 regression values of its ground-truth class
	        roi_locs = roi_cls_locs[torch.arange(0, n_sample).long(), gt_roi_labels]
	        roi_loc_loss = loc_loss(roi_locs.contiguous(), gt_roi_locs, gt_roi_labels, self.roi_sigma)

	        losses = [rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss]
	        losses = losses + [sum(losses)]

	        return LossTuple(*losses)

	    @torch.no_grad()
	    def predict(self, x):
	        # switch to eval mode; this changes the rpn's n_post_nms threshold to 300
	        self.eval()

	        # -----------------part 1: feature extraction----------------------
	        h = self.extractor(x)
	        img_size = (x.size(2), x.size(3))

	        # ----------------------part 2: RPN--------------------------
	        rpn_locs, rpn_scores, anchors, rois = self.rpn(h, img_size)

	        # ------------------part 3: fast rcnn (roi) head-------------------
	        # RoI pooling first, then the two fully connected layers
	        roi_locs, roi_scores = self.fast_rcnn(h, np.asarray(rois))
	        n_sample = roi_locs.shape[0]

	        # --------------------part 4: box decoding-----------------------
	        roi_cls_locs = roi_locs.view(n_sample, -1, 4)
	        rois = torch.from_numpy(rois).to(device)
	        rois = rois.view(-1, 1, 4).expand_as(roi_cls_locs)
	        boxes = loc2box(rois.cpu().numpy().reshape((-1, 4)), roi_cls_locs.cpu().numpy().reshape((-1, 4)))
	        boxes = torch.from_numpy(boxes).to(device)
	        # clip the box coordinates so they stay inside the image
	        boxes[:, [0, 2]] = (boxes[:, [0, 2]]).clamp(min=0, max=img_size[0])
	        boxes[:, [1, 3]] = (boxes[:, [1, 3]]).clamp(min=0, max=img_size[1])
	        boxes = boxes.view(n_sample, -1)

	        # turn roi_scores into probabilities; prob shape: [rois_num, class_num] (7 in this KITTI setup)
	        prob = F.softmax(roi_scores, dim=1)

	        # ----------------part 5: filtering------------------------
	        raw_boxes = boxes.cpu().numpy()
	        raw_prob = prob.cpu().numpy()
	        final_boxes, labels, scores = self._suppress(raw_boxes, raw_prob)
	        self.train()
	        return final_boxes, labels, scores

	    def _suppress(self, raw_boxes, raw_prob):
	        score_thresh = 0.7
	        nms_thresh = 0.3
	        n_class = class_num
	        box = list()
	        label = list()
	        score = list()

	        # class 0 is the background, so start from 1
	        for i in range(1, class_num):
	            box_i = raw_boxes.reshape((-1, n_class, 4))
	            box_i = box_i[:, i, :]  # shape: [rois_num, 4]
	            prob_i = raw_prob[:, i]  # shape: [rois_num]
	            mask = prob_i > score_thresh
	            box_i = box_i[mask]
	            prob_i = prob_i[mask]
	            # sort boxes and scores together, from high score to low
	            order = prob_i.argsort()[::-1]
	            box_i = box_i[order]
	            prob_i = prob_i[order]

	            box_i_after_nms, keep = non_maximum_suppression(box_i, nms_thresh)
	            box.append(box_i_after_nms)

	            label_i = (i - 1) * np.ones((len(keep),))
	            label.append(label_i)
	            score.append(prob_i[keep])

	        box = np.concatenate(box, axis=0).astype(np.float32)
	        label = np.concatenate(label, axis=0).astype(np.int32)
	        score = np.concatenate(score, axis=0).astype(np.float32)
	        return box, label, score
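
       The _suppress method depends on the repo's non_maximum_suppression helper. To show the idea without digging into utils.util, here is a minimal standalone NMS sketch in numpy (my own illustration; the repo's helper additionally returns the kept boxes themselves):

	import numpy as np

	def nms_sketch(boxes, scores, iou_thresh=0.3):
	    """Minimal NMS: boxes are [N, 4] in (x1, y1, x2, y2) format, scores are [N]."""
	    order = scores.argsort()[::-1]          # indices sorted by descending score
	    keep = []
	    while order.size > 0:
	        i = order[0]                        # highest-scoring remaining box
	        keep.append(i)
	        # IoU of box i with every other remaining box
	        x1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
	        y1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
	        x2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
	        y2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
	        inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
	        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
	        areas = (boxes[order[1:], 2] - boxes[order[1:], 0]) * (boxes[order[1:], 3] - boxes[order[1:], 1])
	        iou = inter / (area_i + areas - inter)
	        # drop every box that overlaps box i too much, keep the rest
	        order = order[1:][iou <= iou_thresh]
	    return np.array(keep)

       For context, a full training step with this module could look like the following sketch (the module path, weight path and hyper-parameters are hypothetical placeholders, and dataloader is assumed to yield the dicts produced by the ImageDataset shown in the next part):

	import torch
	from nets.faster_rcnn import FasterRCNN  # assumed module path

	model = FasterRCNN(path='pretrained/vgg16.pth').to('cuda')  # hypothetical weight path
	optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)

	for batch in dataloader:
	    img = batch['img_tensor'].to('cuda')
	    gt_boxes = batch['img_gt_boxes'].to('cuda')
	    labels = batch['img_classes'].to('cuda')

	    losses = model(img, gt_boxes, labels)   # returns a LossTuple
	    optimizer.zero_grad()
	    losses.total_loss.backward()            # backpropagate the summed loss
	    optimizer.step()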





III. Dataset

1. Generating the txt files

       I wrote two versions of the code that generates the txt files.


① The dataset ships txt annotations

       The first version assumes the dataset ships its annotations as txt files; the functions below then produce the four txt split files and the corresponding xml annotation files.

	from lxml import etree as ET
	import glob
	import cv2
	import random
	# note: train_label_path is assumed to live in configs.config as well,
	# since split_dataset_byTXT below references it
	from configs.config import classes_for_label, xml_root_dir, img_root_dir, txt_root_dir, pic_format, train_label_path


	def write_xml(filename, saveimg, typename, boxes, xmlpath):
	    """
	    function description: convert a txt annotation into an xml file

	    :param filename: image file name
	    :param saveimg: image tensor read by opencv
	    :param typename: class names
	    :param boxes: top-left and bottom-right coordinates
	    :param xmlpath: name of the xml file to save
	    """
	    # root node
	    root = ET.Element("annotation")

	    # folder node
	    folder_node = ET.SubElement(root, 'folder')
	    folder_node.text = 'kitti'

	    # filename node
	    filename_node = ET.SubElement(root, 'filename')
	    filename_node.text = filename

	    # source node
	    source_node = ET.SubElement(root, 'source')
	    database_node = ET.SubElement(source_node, 'database')
	    database_node.text = 'kitti Database'
	    annotation_node = ET.SubElement(source_node, 'annotation')
	    annotation_node.text = 'kitti'
	    image_node = ET.SubElement(source_node, 'image')
	    image_node.text = 'flickr'
	    flickrid_node = ET.SubElement(source_node, 'flickrid')
	    flickrid_node.text = '-1'

	    # owner node
	    owner_node = ET.SubElement(root, 'owner')
	    flickrid_node = ET.SubElement(owner_node, 'flickrid')
	    flickrid_node.text = 'muke'
	    name_node = ET.SubElement(owner_node, 'name')
	    name_node.text = 'muke'

	    # size node
	    size_node = ET.SubElement(root, 'size')
	    width_node = ET.SubElement(size_node, 'width')
	    width_node.text = str(saveimg.shape[1])
	    height_node = ET.SubElement(size_node, 'height')
	    height_node.text = str(saveimg.shape[0])
	    depth_node = ET.SubElement(size_node, 'depth')
	    depth_node.text = str(saveimg.shape[2])

	    # segmented node (used for image segmentation)
	    segmented_node = ET.SubElement(root, 'segmented')
	    segmented_node.text = '0'

	    # object nodes (one per annotated object)
	    for i in range(len(typename)):
	        object_node = ET.SubElement(root, 'object')
	        name_node = ET.SubElement(object_node, 'name')
	        name_node.text = typename[i]
	        pose_node = ET.SubElement(object_node, 'pose')
	        pose_node.text = 'Unspecified'
	        # whether the object is truncated
	        truncated_node = ET.SubElement(object_node, 'truncated')
	        truncated_node.text = '1'
	        difficult_node = ET.SubElement(object_node, 'difficult')
	        difficult_node.text = '0'
	        bndbox_node = ET.SubElement(object_node, 'bndbox')
	        xmin_node = ET.SubElement(bndbox_node, 'xmin')
	        xmin_node.text = str(boxes[i][0])
	        ymin_node = ET.SubElement(bndbox_node, 'ymin')
	        ymin_node.text = str(boxes[i][1])
	        xmax_node = ET.SubElement(bndbox_node, 'xmax')
	        xmax_node.text = str(boxes[i][2])
	        ymax_node = ET.SubElement(bndbox_node, 'ymax')
	        ymax_node.text = str(boxes[i][3])

	    tree = ET.ElementTree(root)
	    tree.write(xmlpath, pretty_print=True)


	def split_dataset_byTXT():
	    """
	    function description: split the dataset into train, val and test sets according to the
	                          txt annotations of the full training set, writing an xml annotation for each image
	    """
	    trainval = open(txt_root_dir + 'trainval.txt', 'w')
	    train = open(txt_root_dir + 'train.txt', 'w')
	    val = open(txt_root_dir + 'val.txt', 'w')
	    test = open(txt_root_dir + 'train_test.txt', 'w')

	    list_anno_files = glob.glob(train_label_path + "*")
	    random.shuffle(list_anno_files)
	    index = 0
	    for anno_file in list_anno_files:
	        with open(anno_file) as file:
	            boxes = []
	            typename = []

	            anno_infos = file.readlines()
	            for anno_item in anno_infos:
	                anno_new_infos = anno_item.split(" ")
	                # skip the Misc and DontCare classes
	                if anno_new_infos[0] == "Misc" or anno_new_infos[0] == "DontCare":
	                    continue
	                else:
	                    box = (int(float(anno_new_infos[4])), int(float(anno_new_infos[5])),
	                           int(float(anno_new_infos[6])), int(float(anno_new_infos[7])))
	                    boxes.append(box)
	                    typename.append(anno_new_infos[0])

	            filename = anno_file.split("\\")[-1].replace(".txt", pic_format)
	            xmlpath = xml_root_dir + filename.replace(pic_format, ".xml")
	            imgpath = img_root_dir + 'training/' + filename
	            print(imgpath)
	            saveimg = cv2.imread(imgpath)
	            write_xml(filename, saveimg, typename, boxes, xmlpath)

	            index += 1
	            # 90/10 split between trainval and test; trainval is further split 70/20 into train and val
	            if index > len(list_anno_files) * 0.9:
	                test.write(filename.replace(pic_format, "\n"))
	            else:
	                trainval.write(filename.replace(pic_format, "\n"))
	                if index > len(list_anno_files) * 0.7:
	                    val.write(filename.replace(pic_format, "\n"))
	                else:
	                    train.write(filename.replace(pic_format, "\n"))

	    trainval.close()
	    train.close()
	    val.close()
	    test.close()
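
       For reference, write_xml produces VOC-style annotations shaped roughly like the sample below (the file name and coordinates are illustrative, not taken from the dataset; the source/owner/segmented nodes are omitted here):

	<annotation>
	  <folder>kitti</folder>
	  <filename>000123.png</filename>
	  <size>
	    <width>1242</width>
	    <height>375</height>
	    <depth>3</depth>
	  </size>
	  <object>
	    <name>Car</name>
	    <pose>Unspecified</pose>
	    <truncated>1</truncated>
	    <difficult>0</difficult>
	    <bndbox>
	      <xmin>599</xmin>
	      <ymin>156</ymin>
	      <xmax>629</xmax>
	      <ymax>189</ymax>
	    </bndbox>
	  </object>
	</annotation>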

② The annotations are already xml files

       The second version assumes the annotations already come as xml files, in which case we only need to generate the txt split files from the file names.

	def split_dataset_byXML():
	    """
	    function description: split the dataset into train, val and test sets according to the
	                          XML annotation files of the full training set
	    """
	    trainval = open(txt_root_dir + 'trainval.txt', 'w')
	    train = open(txt_root_dir + 'train.txt', 'w')
	    val = open(txt_root_dir + 'val.txt', 'w')
	    train_test = open(txt_root_dir + 'train_test.txt', 'w')

	    list_anno_files = glob.glob(xml_root_dir + "*")
	    random.shuffle(list_anno_files)
	    index = 0
	    for anno_file in list_anno_files:
	        # strip the directory part so only the bare file name is written, matching split_dataset_byTXT
	        filename = anno_file.split("\\")[-1].replace(".xml", pic_format)
	        index += 1
	        if index > len(list_anno_files) * 0.9:
	            train_test.write(filename.replace(pic_format, "\n"))
	        else:
	            trainval.write(filename.replace(pic_format, "\n"))
	            if index > len(list_anno_files) * 0.7:
	                val.write(filename.replace(pic_format, "\n"))
	            else:
	                train.write(filename.replace(pic_format, "\n"))

	    trainval.close()
	    train.close()
	    val.close()
	    train_test.close()



2. Implementing the Dataset class

       Since a real test set has no annotations, __getitem__ cannot return the same things for training and testing. To reuse as much code as possible between the two, the constructor takes a flag that distinguishes train from test. Faster-RCNN also has requirements on the input image size: the shorter side must be at least 600px, otherwise RoI pooling runs into trouble and predictions become inaccurate, which is why the code still rescales every image.

       Note: naively reshaping the training images would be a serious mistake; you must scale the annotated box coordinates by the same ratio as the image! The code below implements this (a minimal sketch of the box-scaling helper follows first). It also pulls all tensors and annotations straight into memory, so the memory footprint is large.
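
       The rescaling itself lives in data.process_data.reshape, which is not listed in this post. A minimal sketch of what such a helper presumably does (scale the shorter side to 600px and multiply the box coordinates by the same factor; the name reshape_sketch is mine):

	from PIL import Image
	import numpy as np

	def reshape_sketch(img, boxes, min_size=600):
	    """Scale the shorter side of `img` to `min_size` and scale `boxes` accordingly."""
	    w, h = img.size
	    scale = min_size / min(w, h)                         # one factor for both axes
	    img = img.resize((int(w * scale), int(h * scale)), Image.BILINEAR)
	    boxes = np.asarray(boxes, dtype=np.float32) * scale  # box coordinates scale linearly
	    return img, boxes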

	from torch.utils.data import Dataset, DataLoader
	from torchvision import transforms
	import os
	from data.process_data import parse_xml, reshape
	import numpy as np
	from PIL import Image
	from configs.config import pic_format


	class ImageDataset(Dataset):
	    def __init__(self, xml_root_dir, img_root_dir, txt_root_dir, txt_file, isTest=False, transform=None):
	        """
	        class description: this class already rescales the shorter side to 600px and rescales
	                           the annotated box positions of the training set by the same ratio

	        :param xml_root_dir: root path of the xml annotation files
	        :param img_root_dir: root path of the images
	        :param txt_root_dir: root path of the txt files
	        :param txt_file: txt file name
	        :param isTest: flag marking whether this is the test set
	        :param transform: transform to apply
	        """
	        super(ImageDataset, self).__init__()

	        self.xml_root_dir = xml_root_dir
	        self.img_root_dir = img_root_dir
	        self.txt_root_dir = txt_root_dir
	        self.txt_file = txt_file
	        self.isTest = isTest
	        if transform is None:
	            self.transform = transforms.Compose([
	                # TODO source of a bug... this resize was meant to fit vgg16's input
	                # transforms.Resize((int(224), int(224))),
	                transforms.ToTensor(),
	                transforms.Normalize(
	                    [0.485, 0.456, 0.406],
	                    [0.229, 0.224, 0.225])
	            ])
	        else:
	            self.transform = transform
	        if not self.isTest:
	            boxes, labels, images = self.load_txt(self.txt_file)
	            self.boxes = boxes
	            self.labels = labels
	            self.images = images
	        else:
	            self.images = self.load_txt(self.txt_file)

	        id_list_files = os.path.join(txt_root_dir, txt_file)
	        self.ids = [id_.strip() for id_ in open(id_list_files)]

	    def load_txt(self, filename):
	        """
	        function description: load the information listed in the txt file into python lists
	                              (boxes, labels and image names)

	        :param filename: txt file name
	        """
	        print('-------------the file name is ', filename)
	        boxes = []
	        labels = []
	        images = []
	        print(os.path.join(self.txt_root_dir, filename))
	        with open(os.path.join(self.txt_root_dir, filename), mode='r') as f:
	            lines = f.readlines()
	            for line in lines:
	                line = line.strip()
	                if not self.isTest:
	                    box, label, image = self.load_xml(line + ".xml")
	                    boxes.append(box)
	                    labels.append(label)
	                else:
	                    image = (line + pic_format)
	                images.append(image)

	        if not self.isTest:
	            print('the length of boxes is ', len(boxes))
	            print('the length of labels is ', len(labels))
	            print('the length of images is ', len(images))
	            return boxes, labels, images
	        return images

	    def load_xml(self, filename):
	        """
	        function description: load the needed attributes from the xml file and rescale the
	                              shorter image side to 600

	        :param filename: xml file name
	        """
	        path = os.path.join(self.xml_root_dir, filename)
	        if not os.path.exists(path):
	            return

	        boxes, labels = parse_xml(path)
	        img_name = filename.replace(".xml", pic_format)
	        images, boxes = reshape(Image.open(self.img_root_dir + img_name), boxes)
	        return np.stack(boxes).astype(np.float32), \
	               np.stack(labels).astype(np.int32), \
	               images

	    def __len__(self):
	        return len(self.images)

	    def __getitem__(self, index):
	        if not self.isTest:
	            id = self.ids[index]
	            box, label, image = self.load_xml('{0}.xml'.format(id))
	            img_tensor = self.transform(image)
	            # [channel, height, width] -> [channel, width, height]
	            img_tensor = img_tensor.permute(0, 2, 1)
	            return {
	                "img_name": id + pic_format,
	                "img_tensor": img_tensor,
	                "img_classes": label,
	                "img_gt_boxes": box
	            }
	        else:
	            img = Image.open(self.img_root_dir + self.images[index])
	            img_tensor = self.transform(img)
	            img_tensor = img_tensor.permute(0, 2, 1)
	            return {
	                "img_name": self.images[index],
	                "img_tensor": img_tensor,
	            }
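
       A minimal sketch of wiring the dataset up (the paths are placeholders for the values in configs.config; batch_size stays at 1 because the forward pass above handles a single image at a time):

	from torch.utils.data import DataLoader

	# hypothetical paths; adjust to your configs.config values
	dataset = ImageDataset(xml_root_dir='data/xml/',
	                       img_root_dir='data/img/',
	                       txt_root_dir='data/txt/',
	                       txt_file='train.txt')
	dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

	batch = next(iter(dataloader))
	print(batch['img_name'], batch['img_tensor'].shape)  # e.g. [1, 3, W, H] after the permute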




IV. Testing

       The images below show prediction results after training this code for one epoch on the KITTI dataset.

[Figure: example detections on KITTI after one epoch of training]

The results are quite decent.
