CenterLoss | 減小類間距離

1.centerloss原理

centerloss中心損失它僅僅用來減少類內的差異,而不能有效增大類間的差異性。下圖中,圖(a)表示softmax loss學習到的特徵描述 。圖(b)表示softmax loss + center loss 學習到的特徵描述,他能把同一類的樣本之間的距離拉近一些,使其相似性變大,儘量的往樣本中心靠攏,但可以看出他沒有把不同類樣本之間的樣本距離拉大。

centerloss的主要思路爲:讓每一類特徵儘可能的在輸出特徵空間內聚集在一起。更直白的描述就是每一類的特徵在特徵空間中儘可能的聚集在某一箇中心點附近。正常情況下,如果我們先驗的知道了所有樣本的GT中心點,那這個任務就好解決了,然而事實是我們無法預先獲取類中心特徵空間的分佈。因此我們只能從訓練的過程中動態的獲取類中心特徵,並對整體的訓練過程產生約束。需要注意的是在訓練的過程中,受限於GPU的顯存等問題,我們不可能直接獲取所有樣本的特徵中心,因此整個過程是基於batch進行的,而且當網絡還未收斂的情況下,網絡得到的特徵中心也是不正確的。基於這兩點,特徵中心的確定勢必是一個基於batch的動態過程。

2.中心點是如何維護的

接下來就詳細講一下這個動態過程,首先提出一個問題:中心點明明是不確定的,那如何讓特徵去聚集在這個不確定的特徵中心點呢?

這要從centerloss的更新機制說起,從下面的兩組公式可以看出,center中心點的更新方向是特徵值和中心點的二範數,簡單來說最終通過這種更新方式會使得某一類特徵值對應的中心點被更新成與所有該類樣本特徵值的二範數和最小的位置,而這個位置我們可以廣義的理解爲所以特徵的中心點位置。因此整體的centerloss是在邊學習邊找中心點的,最終中心點的確定和整體分類任務的收斂是同步進行的。

用知乎上比較概括性的話來講就是:
center loss的原理主要是在softmax loss的基礎上,通過對訓練集的每個類別在特徵空間分別維護一個類中心,在訓練過程,增加樣本經過網絡映射後在特徵空間與類中心的距離約束,從而兼顧了類內聚合與類間分離。

最終通過將centerloss和softmaxloss進行加權求和,實現整體的分類任務的學習。

centerloss的計算代碼:

def forward(self, output_features, y_truth):
        """
        損失計算
        :param output_features: conv層輸出的特徵,  [b,c,h,w]
        :param y_truth:  標籤值  [b,]
        :return:
        """
        batch_size = y_truth.size(0)
        output_features = output_features.view(batch_size, -1)
        assert output_features.size(-1) == self.feat_dim
        factor = self.scale / batch_size
        # return self.lamda * factor * self.lossfunc(output_features, y_truth, self.feature_centers))

        centers_batch = self.feature_centers.index_select(0, y_truth.long())  # [b,features_dim]
        diff = output_features - centers_batch
        loss = self.lamda * 0.5 * factor * (diff.pow(2).sum())
        #########
        return loss

center的更新代碼:

# 改段代碼需要注意的是backward返回值需要與對應的forward的輸入參數一一對應。
class CenterlossFunc(Function):
    @staticmethod
    def forward(ctx, feature, label, centers, batch_size):
        ctx.save_for_backward(feature, label, centers, batch_size)
        centers_batch = centers.index_select(0, label.long())
        return (feature - centers_batch).pow(2).sum() / 2.0 / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centers, batch_size = ctx.saved_tensors
        centers_batch = centers.index_select(0, label.long())
        diff = centers_batch - feature
        # init every iteration
        counts = centers.new_ones(centers.size(0))
        ones = centers.new_ones(label.size(0))
        grad_centers = centers.new_zeros(centers.size())

        counts = counts.scatter_add_(0, label.long(), ones)
        grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centers = grad_centers/counts.view(-1, 1)
        return - grad_output * diff / batch_size, None, grad_centers / batch_size, None

pytorch代碼
https://www.cnblogs.com/dxscode/p/12059548.html
https://github.com/jxgu1016/MNIST_center_loss_pytorch/blob/master/CenterLoss.py

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章