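PyTorch | YOWO: Principles and Code Walkthrough (Part 1)
Suggested reading beforehand: the YOWO paper translation.
YOWO is an interesting model with a lot of practical value, and I happen to need it right now, so I decided to study it. I have always believed that you only learn the details, and only really understand an algorithm, once you have read its source code. My ability is limited, so if this post contains mistakes, corrections and discussion are welcome.
To make debugging easier, I slightly modified train.py and saved it as myTrain.py. The code analysis starts from that file, but a few things have to be configured first.
1. Preparation before training
1.1 The ucf101-24 dataset
Download the ucf101-24 dataset. The paper uses two datasets; this walkthrough uses only ucf24.
1.2 Pretrained backbone models
There are two. The first is the 2D network, YOLOv2. The second is the 3D network, ResNeXt or ResNet. This walkthrough uses "resnext-101-kinetics.pth".
1.3 Pretrained YOWO model
The author uploaded it to Baidu Cloud; password: 95mm.
1.4 Path configuration
Put the backbone weights in the "weights" directory (the default paths in myTrain.py point there). The ucf24 dataset can live anywhere, but remember to update its path in ucf24.data.
2. Getting ready to train
Here is the complete code of myTrain.py: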
from __future__ import print_function
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms
import dataset
import random
import math
import os
from opts import parse_opts
from utils import *
from cfg import parse_cfg
from region_loss import RegionLoss
from model import YOWO, get_fine_tuning_parameters
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="ucf101-24", help="dataset")
    parser.add_argument("--data_cfg", type=str, default="cfg/ucf24.data", help="data_cfg")
    parser.add_argument("--cfg_file", type=str, default="cfg/ucf24.cfg", help="cfg_file")
    parser.add_argument("--n_classes", type=int, default=24, help="n_classes")
    parser.add_argument("--backbone_3d", type=str, default="resnext101", help="backbone_3d")
    parser.add_argument("--backbone_3d_weights", type=str, default="weights/resnext-101-kinetics.pth", help="backbone_3d_weights")
    parser.add_argument("--backbone_2d", type=str, default="darknet", help="backbone_2d")
    parser.add_argument("--backbone_2d_weights", type=str, default="weights/yolo.weights", help="backbone_2d_weights")
    # NOTE: with type=bool, argparse converts any non-empty string to True,
    # so passing e.g. "--evaluate False" on the command line still yields True.
    parser.add_argument("--freeze_backbone_2d", type=bool, default=True, help="freeze_backbone_2d")
    parser.add_argument("--freeze_backbone_3d", type=bool, default=True, help="freeze_backbone_3d")
    parser.add_argument("--evaluate", type=bool, default=False, help="evaluate")
    parser.add_argument("--begin_epoch", type=int, default=0, help="begin_epoch")
    parser.add_argument("--end_epoch", type=int, default=4, help="end_epoch")
    opt = parser.parse_args()
    # opt = parse_opts()
    # which dataset to use
    dataset_use = opt.dataset
    assert dataset_use == 'ucf101-24' or dataset_use == 'jhmdb-21', 'invalid dataset'
    # path for dataset of training and validation
    datacfg = opt.data_cfg
    # path for cfg file
    cfgfile = opt.cfg_file
    data_options = read_data_cfg(datacfg)
    net_options = parse_cfg(cfgfile)[0]
    # obtain list for training and testing
    basepath = data_options['base']
    trainlist = data_options['train']
    testlist = data_options['valid']
    backupdir = data_options['backup']
    # number of training samples
    nsamples = file_lines(trainlist)
    gpus = data_options['gpus']  # e.g. 0,1,2,3
    ngpus = len(gpus.split(','))
    num_workers = int(data_options['num_workers'])
    batch_size = int(net_options['batch'])
    clip_duration = int(net_options['clip_duration'])
    max_batches = int(net_options['max_batches'])
    learning_rate = float(net_options['learning_rate'])
    momentum = float(net_options['momentum'])
    decay = float(net_options['decay'])
    steps = [float(step) for step in net_options['steps'].split(',')]
    scales = [float(scale) for scale in net_options['scales'].split(',')]
    # loss parameters
    loss_options = parse_cfg(cfgfile)[1]
    region_loss = RegionLoss()
    anchors = loss_options['anchors'].split(',')
    region_loss.anchors = [float(i) for i in anchors]
    region_loss.num_classes = int(loss_options['classes'])
    region_loss.num_anchors = int(loss_options['num'])
    region_loss.anchor_step = len(region_loss.anchors) // region_loss.num_anchors
    region_loss.object_scale = float(loss_options['object_scale'])
    region_loss.noobject_scale = float(loss_options['noobject_scale'])
    region_loss.class_scale = float(loss_options['class_scale'])
    region_loss.coord_scale = float(loss_options['coord_scale'])
    region_loss.batch = batch_size
    # Train parameters
    max_epochs = max_batches * batch_size // nsamples + 1
    use_cuda = True
    seed = int(time.time())
    eps = 1e-5
    best_fscore = 0  # initialize best fscore
    # Test parameters
    nms_thresh = 0.4
    iou_thresh = 0.5
    if not os.path.exists(backupdir):
        os.mkdir(backupdir)
    # set the random seed
    torch.manual_seed(seed)
    if use_cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = gpus
        torch.cuda.manual_seed(seed)
    # Create model
    model = YOWO(opt)
    model = model.cuda()
    model = nn.DataParallel(model, device_ids=None)  # in multi-gpu case
    model.seen = 0
    print(model)
    parameters = get_fine_tuning_parameters(model, opt)
    optimizer = optim.SGD(parameters, lr=learning_rate / batch_size, momentum=momentum, dampening=0,
                          weight_decay=decay * batch_size)
    kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
    # Load resume path if necessary
    # if opt.resume_path:
    #     print("===================================================================")
    #     print('loading checkpoint {}'.format(opt.resume_path))
    #     checkpoint = torch.load(opt.resume_path)
    #     opt.begin_epoch = checkpoint['epoch']
    #     best_fscore = checkpoint['fscore']
    #     model.load_state_dict(checkpoint['state_dict'])
    #     optimizer.load_state_dict(checkpoint['optimizer'])
    #     model.seen = checkpoint['epoch'] * nsamples
    #     print("Loaded model fscore: ", checkpoint['fscore'])
    #     print("===================================================================")
    region_loss.seen = model.seen
    processed_batches = model.seen // batch_size
    init_width = int(net_options['width'])
    init_height = int(net_options['height'])
    init_epoch = model.seen // nsamples

    def adjust_learning_rate(optimizer, batch):
        lr = learning_rate
        for i in range(len(steps)):
            scale = scales[i] if i < len(scales) else 1
            if batch >= steps[i]:
                lr = lr * scale
                if batch == steps[i]:
                    break
            else:
                break
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr / batch_size
        return lr

    def train(epoch):
        global processed_batches
        t0 = time.time()
        cur_model = model.module
        region_loss.l_x.reset()
        region_loss.l_y.reset()
        region_loss.l_w.reset()
        region_loss.l_h.reset()
        region_loss.l_conf.reset()
        region_loss.l_cls.reset()
        region_loss.l_total.reset()
        train_loader = torch.utils.data.DataLoader(
            dataset.listDataset(basepath, trainlist, dataset_use=dataset_use, shape=(init_width, init_height),
                                shuffle=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                ]),
                                train=True,
                                seen=cur_model.seen,
                                batch_size=batch_size,
                                clip_duration=clip_duration,
                                num_workers=num_workers),
            batch_size=batch_size, shuffle=False, **kwargs)
        lr = adjust_learning_rate(optimizer, processed_batches)
        logging('training at epoch %d, lr %f' % (epoch, lr))
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            adjust_learning_rate(optimizer, processed_batches)
            processed_batches = processed_batches + 1
            if use_cuda:
                data = data.cuda()
            optimizer.zero_grad()
            output = model(data)
            region_loss.seen = region_loss.seen + data.data.size(0)
            loss = region_loss(output, target)
            loss.backward()
            optimizer.step()
            # every 500 batches, reset the average meters so that improvements stay visible
            if processed_batches % 500 == 0:
                region_loss.l_x.reset()
                region_loss.l_y.reset()
                region_loss.l_w.reset()
                region_loss.l_h.reset()
                region_loss.l_conf.reset()
                region_loss.l_cls.reset()
                region_loss.l_total.reset()
        t1 = time.time()
        logging('trained with %f samples/s' % (len(train_loader.dataset) / (t1 - t0)))
        print('')

    def test(epoch):
        def truths_length(truths):
            # ground-truth rows are zero-padded; the first row with x == 0 marks the end
            for i in range(50):
                if truths[i][1] == 0:
                    return i

        test_loader = torch.utils.data.DataLoader(
            dataset.listDataset(basepath, testlist, dataset_use=dataset_use, shape=(init_width, init_height),
                                shuffle=False,
                                transform=transforms.Compose([
                                    transforms.ToTensor()
                                ]), train=False),
            batch_size=batch_size, shuffle=False, **kwargs)
        num_classes = region_loss.num_classes
        anchors = region_loss.anchors
        num_anchors = region_loss.num_anchors
        conf_thresh_valid = 0.005
        total = 0.0
        proposals = 0.0
        correct = 0.0
        fscore = 0.0
        correct_classification = 0.0
        total_detected = 0.0
        nbatch = file_lines(testlist) // batch_size
        logging('validation at epoch %d' % (epoch))
        model.eval()
        for batch_idx, (frame_idx, data, target) in enumerate(test_loader):
            if use_cuda:
                data = data.cuda()
            with torch.no_grad():
                output = model(data).data
                all_boxes = get_region_boxes(output, conf_thresh_valid, num_classes, anchors, num_anchors, 0, 1)
                for i in range(output.size(0)):
                    boxes = all_boxes[i]
                    boxes = nms(boxes, nms_thresh)
                    if dataset_use == 'ucf101-24':
                        detection_path = os.path.join('ucf_detections', 'detections_' + str(epoch), frame_idx[i])
                        current_dir = os.path.join('ucf_detections', 'detections_' + str(epoch))
                        # create the parent directory first; the original called
                        # os.mkdir(current_dir) here, which fails while the parent is missing
                        if not os.path.exists('ucf_detections'):
                            os.mkdir('ucf_detections')
                        if not os.path.exists(current_dir):
                            os.mkdir(current_dir)
                    else:
                        detection_path = os.path.join('jhmdb_detections', 'detections_' + str(epoch), frame_idx[i])
                        current_dir = os.path.join('jhmdb_detections', 'detections_' + str(epoch))
                        if not os.path.exists('jhmdb_detections'):
                            os.mkdir('jhmdb_detections')
                        if not os.path.exists(current_dir):
                            os.mkdir(current_dir)
                    with open(detection_path, 'w+') as f_detect:
                        for box in boxes:
                            # boxes are (cx, cy, w, h) in [0, 1]; convert to pixel corners
                            x1 = round(float(box[0] - box[2] / 2.0) * 320.0)
                            y1 = round(float(box[1] - box[3] / 2.0) * 240.0)
                            x2 = round(float(box[0] + box[2] / 2.0) * 320.0)
                            y2 = round(float(box[1] + box[3] / 2.0) * 240.0)
                            det_conf = float(box[4])
                            for j in range((len(box) - 5) // 2):
                                cls_conf = float(box[5 + 2 * j].item())
                                if type(box[6 + 2 * j]) == torch.Tensor:
                                    cls_id = int(box[6 + 2 * j].item())
                                else:
                                    cls_id = int(box[6 + 2 * j])
                                prob = det_conf * cls_conf
                                f_detect.write(
                                    str(int(box[6]) + 1) + ' ' + str(prob) + ' ' + str(x1) + ' ' + str(y1) + ' ' + str(
                                        x2) + ' ' + str(y2) + '\n')
                    truths = target[i].view(-1, 5)
                    num_gts = truths_length(truths)
                    total = total + num_gts
                    for i in range(len(boxes)):
                        if boxes[i][4] > 0.25:
                            proposals = proposals + 1
                    for i in range(num_gts):
                        box_gt = [truths[i][1], truths[i][2], truths[i][3], truths[i][4], 1.0, 1.0, truths[i][0]]
                        best_iou = 0
                        best_j = -1
                        for j in range(len(boxes)):
                            iou = bbox_iou(box_gt, boxes[j], x1y1x2y2=False)
                            if iou > best_iou:
                                best_j = j
                                best_iou = iou
                        if best_iou > iou_thresh:
                            total_detected += 1
                            if int(boxes[best_j][6]) == box_gt[6]:
                                correct_classification += 1
                        if best_iou > iou_thresh and int(boxes[best_j][6]) == box_gt[6]:
                            correct = correct + 1
                precision = 1.0 * correct / (proposals + eps)
                recall = 1.0 * correct / (total + eps)
                fscore = 2.0 * precision * recall / (precision + recall + eps)
                logging(
                    "[%d/%d] precision: %f, recall: %f, fscore: %f" % (batch_idx, nbatch, precision, recall, fscore))
        classification_accuracy = 1.0 * correct_classification / (total_detected + eps)
        locolization_recall = 1.0 * total_detected / (total + eps)
        print("Classification accuracy: %.3f" % classification_accuracy)
        print("Locolization recall: %.3f" % locolization_recall)
        return fscore

    if opt.evaluate:
        logging('evaluating ...')
        test(0)
    else:
        for epoch in range(opt.begin_epoch, opt.end_epoch + 1):
            # Train the model for 1 epoch
            train(epoch)
            # Validate the model
            fscore = test(epoch)
            is_best = fscore > best_fscore
            if is_best:
                print("New best fscore is achieved: ", fscore)
                print("Previous fscore was: ", best_fscore)
                best_fscore = fscore
            # Save the model to backup directory
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'fscore': fscore
            }
            save_checkpoint(state, is_best, backupdir, opt.dataset, clip_duration)
            logging('Weights are saved to backup directory: %s' % (backupdir))
2.1 Basic settings
parser = argparse.ArgumentParser()
......
# path for dataset of training and validation
datacfg = opt.data_cfg
# path for cfg file
cfgfile = opt.cfg_file
This part configures the dataset, the cfg paths, which backbone networks to use, whether to run in evaluation mode, and so on.
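One detail of these settings is worth a small experiment: the boolean flags are declared with type=bool, and argparse simply calls bool() on the command-line string. The sketch below illustrates the effect and one possible workaround. The helper str2bool and the function build_parser are hypothetical names introduced for this example; they are not part of the YOWO repository.

```python
import argparse


def str2bool(value):
    """Convert common textual spellings of a boolean into a real bool (hypothetical helper)."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


def build_parser(bool_type):
    """Build a tiny parser with a single flag, mirroring --evaluate in myTrain.py."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--evaluate", type=bool_type, default=False)
    return parser


# With type=bool, the string "False" is non-empty and therefore truthy.
naive = build_parser(bool).parse_args(["--evaluate", "False"])
# With str2bool, the text is interpreted the way the user most likely intended.
fixed = build_parser(str2bool).parse_args(["--evaluate", "False"])
print(naive.evaluate, fixed.evaluate)  # True False
```

If you only ever change these flags by editing the defaults in myTrain.py, as this walkthrough does, the original declarations behave as expected.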
2.2 Parsing the data/cfg files
data_options = read_data_cfg(datacfg)
net_options = parse_cfg(cfgfile)[0]
First look at read_data_cfg(datacfg).
The complete code is as follows:
def read_data_cfg(datacfg):
    options = dict()
    options['gpus'] = '0'
    options['num_workers'] = '0'
    with open(datacfg, 'r') as fp:
        lines = fp.readlines()
    for line in lines:
        line = line.strip()
        if line == '':
            continue
        key,value = line.split('=')
        key = key.strip()
        value = value.strip()
        options[key] = value
    return options
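This code parses cfg/ucf24.data to obtain the paths of the training and test sets. The result is a dict of string values: besides the keys read later in myTrain.py (base, train, valid, backup), it always contains gpus and num_workers, which default to '0' when the file does not set them.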
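To see what read_data_cfg returns without having the dataset at hand, here is a minimal sketch that runs the same parsing logic on a temporary file. The file content and the paths in it are made up for illustration; the real cfg/ucf24.data will contain your own paths.

```python
import os
import tempfile


def read_data_cfg(datacfg):
    """Same logic as the function above: defaults first, then 'key = value' lines."""
    options = {"gpus": "0", "num_workers": "0"}
    with open(datacfg, "r") as fp:
        for line in fp:
            line = line.strip()
            if line == "":
                continue
            key, value = line.split("=")
            options[key.strip()] = value.strip()
    return options


# Hypothetical file content, used only for this demonstration.
sample = (
    "base = /data/ucf24\n"
    "train = /data/ucf24/trainlist.txt\n"
    "\n"
    "valid = /data/ucf24/testlist.txt\n"
    "backup = backup\n"
)

with tempfile.NamedTemporaryFile("w", suffix=".data", delete=False) as tmp:
    tmp.write(sample)
    path = tmp.name

try:
    data_options = read_data_cfg(path)
finally:
    os.remove(path)

print(data_options)
```

Every value stays a string, which is why myTrain.py wraps num_workers in int() and splits gpus on commas.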
The line net_options = parse_cfg(cfgfile)[0] parses the network configuration file, cfg/ucf24.cfg.
The complete code is as follows:
def parse_cfg(cfgfile):
    blocks = []
    fp = open(cfgfile, 'r')
    block = None
    line = fp.readline()
    while line != '':
        line = line.rstrip()
        if line == '' or line[0] == '#':
            line = fp.readline()
            continue
        elif line[0] == '[':
            if block:
                blocks.append(block)
            block = dict()
            block['type'] = line.lstrip('[').rstrip(']')
            # set default value
            if block['type'] == 'convolutional':
                block['batch_normalize'] = 0
        else:
            key,value = line.split('=')
            key = key.strip()
            if key == 'type':
                key = '_type'
            value = value.strip()
            block[key] = value
        line = fp.readline()
    if block:
        blocks.append(block)
    fp.close()
    return blocks
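The result is a list of blocks (dicts), one per section of the cfg file.
The parsed content falls into two parts.
- The first part is the training configuration of the whole network, such as the input size, the learning rate and the learning-rate decay schedule.
- The second part is mainly the YOLOv2 configuration, since type='region'. The anchors entry holds the preset sizes: 10 values in total, i.e. 5 anchors with 2 values each, which matches num=5. The remaining object_scale, noobject_scale, class_scale and coord_scale are presumably the weighting factors of the loss terms; this will be verified when we get to the loss function.
Because blocks is a list and parse_cfg(cfgfile)[0] takes its first element, net_options receives the training configuration of the whole network.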
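The two-part structure is easy to reproduce on a toy input. The sketch below is a simplified re-implementation of parse_cfg that works on a string (it omits the 'type' key renaming and the convolutional default), and the cfg values are invented for illustration; they are not copied from cfg/ucf24.cfg. The function name parse_cfg_text is hypothetical.

```python
def parse_cfg_text(text):
    """Parse darknet-style cfg text into a list of dicts, one per [section]."""
    blocks = []
    block = None
    for raw in text.splitlines():
        line = raw.strip()
        if line == "" or line.startswith("#"):
            continue  # skip blank lines and comments
        if line.startswith("["):
            if block:
                blocks.append(block)
            block = {"type": line.lstrip("[").rstrip("]")}
        else:
            key, value = line.split("=")
            block[key.strip()] = value.strip()
    if block:
        blocks.append(block)
    return blocks


# Illustrative values only.
sample_cfg = """
[net]
batch=8
learning_rate=0.0001
steps=1000,2000
scales=0.5,0.5

[region]
anchors=1.0,1.5, 2.0,2.5, 3.0,3.5, 4.0,4.5, 5.0,5.5
classes=24
num=5
"""

net_options, loss_options = parse_cfg_text(sample_cfg)
anchors = [float(a) for a in loss_options["anchors"].split(",")]
anchor_step = len(anchors) // int(loss_options["num"])
print(net_options["type"], loss_options["type"], anchor_step)  # net region 2
```

Index 0 is the first section and index 1 the second, which is why myTrain.py reads the training settings from parse_cfg(cfgfile)[0] and the loss settings from parse_cfg(cfgfile)[1].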
2.3 Collecting the training and testing configuration
basepath = data_options['base']
trainlist = data_options['train']
testlist = data_options['valid']
backupdir = data_options['backup']
# number of training samples
nsamples = file_lines(trainlist)
gpus = data_options['gpus'] # e.g. 0,1,2,3
ngpus = len(gpus.split(','))
num_workers = int(data_options['num_workers'])
batch_size = int(net_options['batch'])
clip_duration = int(net_options['clip_duration'])
max_batches = int(net_options['max_batches'])
learning_rate = float(net_options['learning_rate'])
momentum = float(net_options['momentum'])
decay = float(net_options['decay'])
steps = [float(step) for step in net_options['steps'].split(',')]
scales = [float(scale) for scale in net_options['scales'].split(',')]
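This part simply stores the values parsed from the cfg/data files in separate variables for the training and testing that follow.
2.4 Loss function parameters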
# loss parameters
loss_options = parse_cfg(cfgfile)[1]
region_loss = RegionLoss()
...
region_loss.batch = batch_size
Here the line to analyze is region_loss = RegionLoss(). RegionLoss is a class defined in region_loss.py; the complete code is as follows.
class RegionLoss(nn.Module):
    # for our model anchors has 10 values and number of anchors is 5
    # parameters: 24, 10 float values, 24, 5
    def __init__(self, num_classes=0, anchors=[], batch=16, num_anchors=1):
        super(RegionLoss, self).__init__()
        self.num_classes = num_classes
        self.batch = batch
        self.anchors = anchors
        self.num_anchors = num_anchors
        self.anchor_step = len(anchors)//num_anchors  # each anchor has 2 parameters
        self.coord_scale = 1
        self.noobject_scale = 1
        self.object_scale = 5
        self.class_scale = 1
        self.thresh = 0.6
        self.seen = 0
        self.l_x = AverageMeter()
        self.l_y = AverageMeter()
        self.l_w = AverageMeter()
        self.l_h = AverageMeter()
        self.l_conf = AverageMeter()
        self.l_cls = AverageMeter()
        self.l_total = AverageMeter()

    def forward(self, output, target):
        # output : B*A*(4+1+num_classes)*H*W
        ......
        return loss
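Here we only look at the initializer __init__; forward will be covered when the loss is computed. Most of the parameters have already been explained above. I do not yet know what self.seen = 0 is for; it will be explained later. AverageMeter is a class in utils.py that stores the current value and computes the running average. In other words, the x, y, w, h, conf, cls and total terms of the loss are accumulated and averaged as training proceeds.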
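For readers who have not met this kind of helper before, here is a minimal sketch of an AverageMeter in the form commonly seen in PyTorch example code. Only the reset() method is confirmed by myTrain.py; the update() signature and the attribute names are assumptions, and the actual class in utils.py may differ.

```python
class AverageMeter(object):
    """Track the latest value and the running average of a quantity."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted sum of all values seen since the last reset
        self.count = 0   # number of samples accumulated since the last reset
        self.avg = 0.0   # running mean = sum / count

    def update(self, val, n=1):
        # assumed interface: 'val' is a per-batch value, 'n' the number of samples it covers
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeter()
for batch_loss in (4.0, 2.0, 0.0):
    meter.update(batch_loss)
latest, mean = meter.val, meter.avg
print(latest, mean)  # 0.0 2.0

meter.reset()  # what train() does at the start of an epoch and every 500 batches
```

Resetting from time to time keeps the printed averages responsive to recent batches instead of being dominated by the large losses from early training.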
2.5 Training/testing parameter settings
# Train parameters
max_epochs = max_batches * batch_size // nsamples + 1
use_cuda = True
seed = int(time.time())
eps = 1e-5
best_fscore = 0 # initialize best fscore
# Test parameters
nms_thresh = 0.4
iou_thresh = 0.5
if not os.path.exists(backupdir):
    os.mkdir(backupdir)
# set the random seed
torch.manual_seed(seed)
if use_cuda:
    os.environ['CUDA_VISIBLE_DEVICES'] = gpus
    torch.cuda.manual_seed(seed)
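The author does not seem to use max_epochs anywhere. Besides setting the random seed, this part also sets the thresholds used at test time: nms_thresh for non-maximum suppression (NMS) and iou_thresh for matching detections against the ground truth.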
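Both thresholds are compared against an intersection-over-union (IoU) value. As a reminder of what that number means, here is a minimal sketch of IoU for boxes in (cx, cy, w, h) format, the format implied by the x1y1x2y2=False argument in test(). The function iou_center_format is a hypothetical stand-in written for this post, not the repository's bbox_iou.

```python
def iou_center_format(box_a, box_b):
    """IoU of two boxes given as (cx, cy, w, h)."""
    # convert centers and sizes to corner coordinates
    ax1, ay1 = box_a[0] - box_a[2] / 2.0, box_a[1] - box_a[3] / 2.0
    ax2, ay2 = box_a[0] + box_a[2] / 2.0, box_a[1] + box_a[3] / 2.0
    bx1, by1 = box_b[0] - box_b[2] / 2.0, box_b[1] - box_b[3] / 2.0
    bx2, by2 = box_b[0] + box_b[2] / 2.0, box_b[1] + box_b[3] / 2.0
    # overlap is clamped at zero for boxes that do not intersect
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = box_a[2] * box_a[3] + box_b[2] * box_b[3] - inter
    return inter / union if union > 0 else 0.0


nms_thresh = 0.4
iou_thresh = 0.5

# two 0.4 x 0.4 boxes whose centers are 0.1 apart horizontally
iou = iou_center_format((0.5, 0.5, 0.4, 0.4), (0.6, 0.5, 0.4, 0.4))
print(round(iou, 3))  # 0.6
```

With these example boxes the IoU is about 0.6. That is above iou_thresh, so a detection overlapping a ground-truth box this much would count as localized in test(). It is also above nms_thresh, so in a typical NMS implementation the lower-scoring of two such detections would be suppressed.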
2.6 Loading the model and setting up the optimizer
# Create model
model = YOWO(opt)
model = model.cuda()
model = nn.DataParallel(model, device_ids=None) # in multi-gpu case
model.seen = 0
print(model)
parameters = get_fine_tuning_parameters(model, opt)
optimizer = optim.SGD(parameters, lr=learning_rate / batch_size, momentum=momentum, dampening=0,
weight_decay=decay * batch_size)
kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
# Load resume path if necessary
# if opt.resume_path:
# print("===================================================================")
# print('loading checkpoint {}'.format(opt.resume_path))
# checkpoint = torch.load(opt.resume_path)
# opt.begin_epoch = checkpoint['epoch']
# best_fscore = checkpoint['fscore']
# model.load_state_dict(checkpoint['state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer'])
# model.seen = checkpoint['epoch'] * nsamples
# print("Loaded model fscore: ", checkpoint['fscore'])
# print("===================================================================")
region_loss.seen = model.seen
processed_batches = model.seen // batch_size
init_width = int(net_options['width'])
init_height = int(net_options['height'])
init_epoch = model.seen // nsamples
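The key points here are: whether to train on multiple GPUs; the use of stochastic gradient descent (SGD) as the optimizer; and whether to resume training after an interruption (if opt.resume_path:). It now becomes clear that model.seen records how many samples have been trained on so far, so that after an interruption the epoch, region_loss and processed_batches pick up exactly where they left off.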
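The bookkeeping around model.seen is plain integer arithmetic. The sketch below replays it for a made-up checkpoint; the function resume_counters and the numbers are hypothetical, while the three formulas are the ones used in myTrain.py.

```python
def resume_counters(checkpoint_epoch, nsamples, batch_size):
    """Recompute the counters that myTrain.py derives from a checkpoint."""
    seen = checkpoint_epoch * nsamples      # model.seen = checkpoint['epoch'] * nsamples
    processed_batches = seen // batch_size  # feeds adjust_learning_rate
    init_epoch = seen // nsamples           # epoch index implied by 'seen'
    return seen, processed_batches, init_epoch


# Illustrative numbers: 1000 training samples, batch size 8, checkpoint saved at epoch 3.
seen, processed_batches, init_epoch = resume_counters(3, 1000, 8)
print(seen, processed_batches, init_epoch)  # 3000 375 3
```

Because adjust_learning_rate is driven by processed_batches, restoring this counter also restores the correct position in the learning-rate schedule.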
The other key line is model = YOWO(opt), which creates the YOWO model. YOWO is defined in model.py:
class YOWO(nn.Module):
    def __init__(self, opt):
        super(YOWO, self).__init__()
        self.opt = opt

        ##### 2D Backbone #####
        if opt.backbone_2d == "darknet":
            self.backbone_2d = darknet.Darknet("cfg/yolo.cfg")
            num_ch_2d = 425  # Number of output channels for backbone_2d
        else:
            raise ValueError("Wrong backbone_2d model is requested. Please select "
                             "it from [darknet]")
        if opt.backbone_2d_weights:  # load pretrained weights on COCO dataset
            self.backbone_2d.load_weights(opt.backbone_2d_weights)

        ##### 3D Backbone #####
        if opt.backbone_3d == "resnext101":
            self.backbone_3d = resnext.resnext101()
            num_ch_3d = 2048  # Number of output channels for backbone_3d
        elif opt.backbone_3d == "resnet18":
            self.backbone_3d = resnet.resnet18(shortcut_type='A')
            num_ch_3d = 512  # Number of output channels for backbone_3d
        elif opt.backbone_3d == "resnet50":
            # the listing originally called resnet.resnet18 here, which does not match 2048 channels
            self.backbone_3d = resnet.resnet50(shortcut_type='B')
            num_ch_3d = 2048  # Number of output channels for backbone_3d
        elif opt.backbone_3d == "resnet101":
            # the listing originally called resnet.resnet18 here, which does not match 2048 channels
            self.backbone_3d = resnet.resnet101(shortcut_type='B')
            num_ch_3d = 2048  # Number of output channels for backbone_3d
        elif opt.backbone_3d == "mobilenet_2x":
            self.backbone_3d = mobilenet.get_model(width_mult=2.0)
            num_ch_3d = 2048  # Number of output channels for backbone_3d
        elif opt.backbone_3d == "mobilenetv2_1x":
            self.backbone_3d = mobilenetv2.get_model(width_mult=1.0)
            num_ch_3d = 1280  # Number of output channels for backbone_3d
        elif opt.backbone_3d == "shufflenet_2x":
            self.backbone_3d = shufflenet.get_model(groups=3, width_mult=2.0)
            num_ch_3d = 1920  # Number of output channels for backbone_3d
        elif opt.backbone_3d == "shufflenetv2_2x":
            self.backbone_3d = shufflenetv2.get_model(width_mult=2.0)
            num_ch_3d = 2048  # Number of output channels for backbone_3d
        else:
            raise ValueError("Wrong backbone_3d model is requested. Please select it from [resnext101, resnet101, "
                             "resnet50, resnet18, mobilenet_2x, mobilenetv2_1x, shufflenet_2x, shufflenetv2_2x]")
        if opt.backbone_3d_weights:  # load pretrained weights on Kinetics-600 dataset
            self.backbone_3d = self.backbone_3d.cuda()
            self.backbone_3d = nn.DataParallel(self.backbone_3d, device_ids=None)  # Because the pretrained backbone models are saved in Dataparalled mode
            pretrained_3d_backbone = torch.load(opt.backbone_3d_weights)
            backbone_3d_dict = self.backbone_3d.state_dict()
            pretrained_3d_backbone_dict = {k: v for k, v in pretrained_3d_backbone['state_dict'].items() if k in backbone_3d_dict}  # 1. filter out unnecessary keys
            backbone_3d_dict.update(pretrained_3d_backbone_dict)  # 2. overwrite entries in the existing state dict
            self.backbone_3d.load_state_dict(backbone_3d_dict)  # 3. load the new state dict
            self.backbone_3d = self.backbone_3d.module  # remove the dataparallel wrapper

        ##### Attention & Final Conv #####
        self.cfam = CFAMBlock(num_ch_2d+num_ch_3d, 1024)
        self.conv_final = nn.Conv2d(1024, 5*(opt.n_classes+4+1), kernel_size=1, bias=False)
        self.seen = 0

    def forward(self, input):
        x_3d = input  # Input clip
        x_2d = input[:, :, -1, :, :]  # Last frame of the clip that is read
        x_2d = self.backbone_2d(x_2d)
        x_3d = self.backbone_3d(x_3d)
        x_3d = torch.squeeze(x_3d, dim=2)
        x = torch.cat((x_3d, x_2d), dim=1)
        x = self.cfam(x)
        out = self.conv_final(x)
        return out
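YOWO consists of two branches: a 2D network and a 3D network.
The main thing to watch is how many output channels each network has. In this example the 2D network is YOLOv2 and the 3D network is resnext101. YOLOv2 is responsible for detection, and its default output has 425 channels because 5 * (80 + 5) = 425: 5 anchors and 80 classes. The check if opt.backbone_3d_weights loads a pretrained model when one is given; according to the code comment, these 3D pretrained models were trained on Kinetics-600. The 2D and 3D backbones are selected independently (in this code only darknet is available on the 2D side, with eight choices on the 3D side), so I will not analyze every combination. Next comes one of the innovations proposed in the paper: the channel fusion and attention mechanism (CFAM).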
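The channel counts mentioned above can be checked with a few lines of arithmetic. This is only a sketch; the function yolo_channels is a hypothetical helper, while the constants 425 and 2048 come from the YOWO class shown earlier.

```python
def yolo_channels(num_classes, num_anchors=5):
    """Channels of a YOLOv2-style head: per anchor, x, y, w, h, conf plus class scores."""
    return num_anchors * (num_classes + 4 + 1)


num_ch_2d = yolo_channels(80)        # darknet pretrained on 80 classes
num_ch_3d = 2048                     # resnext101, as set in YOWO.__init__
cfam_in_channels = num_ch_2d + num_ch_3d

print(num_ch_2d, cfam_in_channels)   # 425 2473
```

In forward, torch.squeeze(x_3d, dim=2) removes the temporal dimension of the 3D features so that both branches are 4-D tensors and can be concatenated along dim=1, giving the 2473 input channels of CFAMBlock.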
Channel fusion and attention mechanism: self.cfam = CFAMBlock(num_ch_2d+num_ch_3d, 1024). CFAMBlock is defined in cfam.py; the complete code is as follows:
class CAM_Module(nn.Module):
    """ Channel attention module """
    def __init__(self, in_dim):
        super(CAM_Module, self).__init__()
        self.chanel_in = in_dim
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C X H X W )
            returns :
                out : attention value + input feature
                attention: B X C X C
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)
        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma*out + x
        return out


class CFAMBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CFAMBlock, self).__init__()
        inter_channels = 1024
        self.conv_bn_relu1 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, kernel_size=1, bias=False),
                                           nn.BatchNorm2d(inter_channels),
                                           nn.ReLU())
        self.conv_bn_relu2 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
                                           nn.BatchNorm2d(inter_channels),
                                           nn.ReLU())
        self.sc = CAM_Module(inter_channels)
        self.conv_bn_relu3 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
                                           nn.BatchNorm2d(inter_channels),
                                           nn.ReLU())
        self.conv_out = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(inter_channels, out_channels, 1))

    def forward(self, x):
        x = self.conv_bn_relu1(x)
        x = self.conv_bn_relu2(x)
        x = self.sc(x)
        x = self.conv_bn_relu3(x)
        output = self.conv_out(x)
        return output
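For the details of CFAM, see the YOWO paper translation. The analysis below follows the CFAM structure diagram from the paper.
CFAMBlock consists of four convolution layers and one CAM_Module. Its input is the channel-wise concatenation of the 2D and 3D network outputs. Two 2D convolutions first extract features, the result is fed into CAM_Module, and two more 2D convolutions produce the final output. The spatial shape $H \times W$ stays unchanged, while the number of channels goes from num_ch_2d + num_ch_3d (425 + 2048 = 2473 in this example) to 1024.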
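The convolution layers at the front and back of CFAMBlock are easy to understand, so let us look at CAM_Module and match it step by step against the formulas in the paper.
- First, the output $\mathbf{B} \in \mathbb{R}^{C \times H \times W}$ of the first two convolution layers is reshaped into $\mathbf{F} \in \mathbb{R}^{C \times N}$, where $N = H \times W$; that is, the feature map of each channel is flattened into a one-dimensional vector. Corresponding code: proj_query = x.view(m_batchsize, C, -1).
- Multiplying $\mathbf{F}$ by its transpose (proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)) gives the Gram matrix $\mathbf{G} = \mathbf{F}\mathbf{F}^{T} \in \mathbb{R}^{C \times C}$. Corresponding code: energy = torch.bmm(proj_query, proj_key).
- As for energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy, I honestly could not follow it, and I did not find any description of this step in the paper. If anyone knows, please enlighten me.
- Next, a softmax layer is applied to the Gram matrix to generate the channel attention map $\mathbf{M} \in \mathbb{R}^{C \times C}$, i.e. $M_{ij} = \exp(G_{ij}) / \sum_{k=1}^{C} \exp(G_{ik})$. Corresponding code: attention = self.softmax(energy_new).
- To apply the attention map to the original features, $\mathbf{M}$ is multiplied with $\mathbf{F}$, giving $\mathbf{F}' = \mathbf{M}\mathbf{F} \in \mathbb{R}^{C \times N}$, which will then be reshaped to the same three-dimensional shape as the input tensor. Corresponding code:
proj_value = x.view(m_batchsize, C, -1)
out = torch.bmm(attention, proj_value)
- The reshape itself gives $\mathbf{F}'' \in \mathbb{R}^{C \times H \times W}$. Corresponding code: out = out.view(m_batchsize, C, height, width).
- Finally, the output of the channel attention module combines this result with the original input feature map through an element-wise sum, weighted by a trainable scalar $\alpha$ whose value is learned gradually starting from 0: $\mathbf{C} = \alpha \cdot \mathbf{F}'' + \mathbf{B}$. Corresponding code: out = self.gamma*out + x, where self.gamma is initialized to 0.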
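Regarding the energy_new line, one way to read it follows directly from a property of softmax: adding the same constant to every entry of a row does not change the result. Subtracting energy from the row maximum is therefore equivalent to taking the softmax of the negated energy, so within each row the entry with the lowest energy receives the highest attention weight, which is not the same as a plain softmax over the Gram matrix. The pure-Python sketch below checks this on one row with made-up numbers; it is an observation about the arithmetic, not a statement of the authors' intent.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Illustrative energies of one channel against three channels.
energy_row = [9.0, 4.0, 1.0]

# What CAM_Module feeds to softmax: row maximum minus energy.
energy_new = [max(energy_row) - e for e in energy_row]
attention = softmax(energy_new)

# Softmax of the negated energies, for comparison.
negated = softmax([-e for e in energy_row])

print(attention)
print(negated)
```

The two printed rows agree, and the weights grow as the energy shrinks. A side effect of this formulation is that every value passed to the exponential is non-negative and bounded by the spread of the row.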
So this part of the YOWO model:
##### Attention & Final Conv #####
self.cfam = CFAMBlock(num_ch_2d+num_ch_3d, 1024)
self.conv_final = nn.Conv2d(1024, 5*(opt.n_classes+4+1), kernel_size=1, bias=False)
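is not hard to understand either: the outputs of the 2D and 3D networks are concatenated along the channel dimension, and the final output tensor has 5*(opt.n_classes+4+1) channels. In other words, for each of the 5 anchors it holds the scores for the 24 (opt.n_classes) action classes plus the localization results (x, y, w, h, conf).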
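To make the channel layout concrete, the sketch below maps a channel index of the final output to an anchor and a field. It assumes the ordering suggested by the RegionLoss comment B*A*(4+1+num_classes)*H*W, namely that each anchor owns a contiguous group of channels ordered x, y, w, h, conf and then the class scores. The function describe_channel is a hypothetical helper, and the exact ordering should be confirmed against the forward of RegionLoss.

```python
def describe_channel(channel, num_classes=24, num_anchors=5):
    """Map an output channel index to (anchor index, field name) under the assumed layout."""
    per_anchor = 4 + 1 + num_classes          # 29 values per anchor for 24 classes
    if not 0 <= channel < num_anchors * per_anchor:
        raise ValueError("channel index out of range")
    anchor, offset = divmod(channel, per_anchor)
    names = ["x", "y", "w", "h", "conf"]
    field = names[offset] if offset < 5 else "class_%d" % (offset - 5)
    return anchor, field


total_channels = 5 * (24 + 4 + 1)
print(total_channels)            # 145
print(describe_channel(0))       # (0, 'x')
print(describe_channel(4))       # (0, 'conf')
print(describe_channel(29))      # (1, 'x')
print(describe_channel(144))     # (4, 'class_23')
```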
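The line parameters = get_fine_tuning_parameters(model, opt) decides whether to fine-tune; usually only the last few layers are fine-tuned. For more on fine-tuning, see the post on model finetune.
That completes the analysis of the preparatory work. Next comes the actual training; for details see:
PyTorch | YOWO: Principles and Code Walkthrough (Part 2)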