CVPR 2018 | Cascaded Pyramid Network for Multi-Person Pose Estimation
https://github.com/chenyilun95/tf-cpn
1.文章概述
本文提出了一種級聯金字塔網絡CPN,該網絡由全局金字塔網絡(GlobalNet)和利用在線難例挖掘機制的精餾網絡(RefineNet)組成。GlobalNet是一個特徵金字塔網絡,可以成功地定位“簡單”的關鍵點(如眼睛和手),但可能無法準確識別被遮擋或看不見的關鍵點。RefineNet嘗試通過整合來自GlobalNet的所有尺度的特徵,以及在線難例關鍵點挖掘損失來處理“複雜”關鍵點的精確定位。
如下圖所示,Cascaded Pyramid Network主要由兩部分組成:GlobalNet和RefineNet。
2.GlobalNet
如下圖所示,GlobalNet以ResNet爲基礎框架,使用與FPN相似的特徵金字塔結構來估計關鍵點。每一個特徵尺度多會輸出對應的關鍵點信息。作者稱這種結構爲GlobalNet。
基於ResNet主幹網的GlobalNet可以有效地定位眼睛等關鍵點,但可能無法準確定位髖部位置。像髖部這樣的關鍵點定位通常需要更多的上下文信息和處理,而不是附近的外觀特徵。在許多情況下,單憑一個Global網絡很難直接識別這些關鍵點。基於此作者在此後接了一個RefineNet。
3.RefineNet
如下圖所示,在GlobalNet生成的特徵金字塔表示的基礎上,作者附加了一個細化網絡來處理難例關鍵點。爲了提高信息傳輸的效率和保持信息傳輸的完整性,RefineNet將不同的層次的特徵進行上採樣後concat。與堆疊沙漏的細分策略不同,RefineNet將所有金字塔特性串聯起來,而不是簡單地使用沙漏模塊末尾的上採樣特性。
隨着網絡訓練的不斷深入,網絡對大多數簡單關鍵點的關注越來越多,而對被遮擋和硬關鍵點的關注越來越少。我們應該確保這兩類關鍵點之間的迴歸平衡。因此,在RefineNet訓練中,根據訓練損失來明確地在線選擇難例關鍵點,並僅從所選關鍵點反向傳播梯度,該方法被稱爲OHKM。如下代碼所示爲OHKM損失函數,從中可以看出該函數就是對MSE輸出的結果進行了排序,並篩選其中難例部分進行重點回歸。
class JointsOHKMMSELoss(nn.Module):
def __init__(self, use_target_weight, topk=8):
super(JointsOHKMMSELoss, self).__init__()
self.criterion = nn.MSELoss(reduction='none')
self.use_target_weight = use_target_weight
self.topk = topk
def ohkm(self, loss):
ohkm_loss = 0.
for i in range(loss.size()[0]):
sub_loss = loss[i]
topk_val, topk_idx = torch.topk(
sub_loss, k=self.topk, dim=0, sorted=False
)
tmp_loss = torch.gather(sub_loss, 0, topk_idx)
ohkm_loss += torch.sum(tmp_loss) / self.topk
ohkm_loss /= loss.size()[0]
return ohkm_loss
def forward(self, output, target, target_weight):
batch_size = output.size(0)
num_joints = output.size(1)
heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
loss = []
for idx in range(num_joints):
heatmap_pred = heatmaps_pred[idx].squeeze()
heatmap_gt = heatmaps_gt[idx].squeeze()
if self.use_target_weight:
loss.append(0.5 * self.criterion(
heatmap_pred.mul(target_weight[:, idx]),
heatmap_gt.mul(target_weight[:, idx])
))
else:
loss.append(
0.5 * self.criterion(heatmap_pred, heatmap_gt)
)
loss = [l.mean(dim=1).unsqueeze(dim=1) for l in loss]
loss = torch.cat(loss, dim=1)
return self.ohkm(loss)
4.結果展示
下圖展示了不同閾值的NMS策略的性能,結果顯示Soft-NMS表現出了最優性能。
下圖結果顯示了OHKM,在線難例挖掘的有效性。
最終的結果也顯示了本文提出的策略的有效性,但總的來說本文提出的OHKM反而被其他SOTA算法廣泛使用。