CVPR 2018 | CPN_COCO2017姿態估計冠軍解決方案

CVPR 2018 | Cascaded Pyramid Network for Multi-Person Pose Estimation
https://github.com/chenyilun95/tf-cpn

1.文章概述

本文提出了一種級聯金字塔網絡CPN,該網絡由全局金字塔網絡(GlobalNet)和利用在線難例挖掘機制的精餾網絡(RefineNet)組成。GlobalNet是一個特徵金字塔網絡,可以成功地定位“簡單”的關鍵點(如眼睛和手),但可能無法準確識別被遮擋或看不見的關鍵點。RefineNet嘗試通過整合來自GlobalNet的所有尺度的特徵,以及在線難例關鍵點挖掘損失來處理“複雜”關鍵點的精確定位。

如下圖所示,Cascaded Pyramid Network主要由兩部分組成:GlobalNet和RefineNet。

2.GlobalNet

如下圖所示,GlobalNet以ResNet爲基礎框架,使用與FPN相似的特徵金字塔結構來估計關鍵點。每一個特徵尺度多會輸出對應的關鍵點信息。作者稱這種結構爲GlobalNet。

基於ResNet主幹網的GlobalNet可以有效地定位眼睛等關鍵點,但可能無法準確定位髖部位置。像髖部這樣的關鍵點定位通常需要更多的上下文信息和處理,而不是附近的外觀特徵。在許多情況下,單憑一個Global網絡很難直接識別這些關鍵點。基於此作者在此後接了一個RefineNet。

3.RefineNet

如下圖所示,在GlobalNet生成的特徵金字塔表示的基礎上,作者附加了一個細化網絡來處理難例關鍵點。爲了提高信息傳輸的效率和保持信息傳輸的完整性,RefineNet將不同的層次的特徵進行上採樣後concat。與堆疊沙漏的細分策略不同,RefineNet將所有金字塔特性串聯起來,而不是簡單地使用沙漏模塊末尾的上採樣特性。

隨着網絡訓練的不斷深入,網絡對大多數簡單關鍵點的關注越來越多,而對被遮擋和硬關鍵點的關注越來越少。我們應該確保這兩類關鍵點之間的迴歸平衡。因此,在RefineNet訓練中,根據訓練損失來明確地在線選擇難例關鍵點,並僅從所選關鍵點反向傳播梯度,該方法被稱爲OHKM。如下代碼所示爲OHKM損失函數,從中可以看出該函數就是對MSE輸出的結果進行了排序,並篩選其中難例部分進行重點回歸。

class JointsOHKMMSELoss(nn.Module):
    def __init__(self, use_target_weight, topk=8):
        super(JointsOHKMMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk

    def ohkm(self, loss):
        ohkm_loss = 0.
        for i in range(loss.size()[0]):
            sub_loss = loss[i]
            topk_val, topk_idx = torch.topk(
                sub_loss, k=self.topk, dim=0, sorted=False
            )
            tmp_loss = torch.gather(sub_loss, 0, topk_idx)
            ohkm_loss += torch.sum(tmp_loss) / self.topk
        ohkm_loss /= loss.size()[0]
        return ohkm_loss

    def forward(self, output, target, target_weight):
        batch_size = output.size(0)
        num_joints = output.size(1)
        heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)

        loss = []
        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze()
            heatmap_gt = heatmaps_gt[idx].squeeze()
            if self.use_target_weight:
                loss.append(0.5 * self.criterion(
                    heatmap_pred.mul(target_weight[:, idx]),
                    heatmap_gt.mul(target_weight[:, idx])
                ))
            else:
                loss.append(
                    0.5 * self.criterion(heatmap_pred, heatmap_gt)
                )

        loss = [l.mean(dim=1).unsqueeze(dim=1) for l in loss]
        loss = torch.cat(loss, dim=1)

        return self.ohkm(loss)
4.結果展示

下圖展示了不同閾值的NMS策略的性能,結果顯示Soft-NMS表現出了最優性能。

下圖結果顯示了OHKM,在線難例挖掘的有效性。

最終的結果也顯示了本文提出的策略的有效性,但總的來說本文提出的OHKM反而被其他SOTA算法廣泛使用。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章