分類器ArcFace、ArcLoss在MNIST數據集上的實現和效果

分類器ArcFace、ArcLoss在MNIST數據集上的實現和效果

寫在前面:
  前一篇文章(電梯直達)給大家介紹了CenterLoss,本文將帶領大家認識一下ArcFace(ArcLoss、Insightface),並在MNIST數據集上實戰一下看一下效果。

一、原理

在這裏插入圖片描述
  CenterLoss是將每個類別的特徵縮減到他的中心位置,從而間接使不同特徵之間界限分明,而ArcLoss則是在原本兩個特徵之間的夾角上再加上一個角m,然後優化這個m,使夾角慢慢變大,兩個特徵慢慢遠離,從而使不同特徵之間界限分明。他的理想狀態是將所有特徵之間的夾角都最大化,也就是每個特徵都被壓迫成一條細線。
  如下圖中的α和β,同時加上m,會使得α和β同時增大,從而中間紫色區域將會被壓縮,當m優化到最佳,那麼紫色區域將會成爲一條線。當然,這只是一種理想狀態,是一種期望,實際情況下不太可能甚至根本不可能達到這樣的效果。
圖1

二、效果

  原理就是那麼簡簡單單,看下效果
在這裏插入圖片描述

三、訓練說明

  如果只使用ArcLoss,loss只會下降一點點,但也能訓練出來,要是再結合一個其他的損失函數,模型訓練出來的效果將會非常棒,以上效果圖就是ArcLoss + CrossEntropyLoss訓練的結果。
  用ArcLoss代替原來的Softmax作爲輸出函數,初始化ArcLoss時指定輸入特徵維度和分類數,用ArcLoss輸出的數據作爲分類依據,與target計算一次損失,再用模型的輸出與target計算一次損失。

四、實現

""" 網絡輸出層,不加激活 """
fc_feature = nn.Linear(1024, 2, bias=False)
out = nn.Linear(2, 10, bias=False)
feature_out = fc_feature(conv_out)
out = out(feature_out)
return feature_out, out


""" 初始化ArcLoss,這裏是2維特徵,10分類 """
arcface = ArcLoss(2, 10)		# 輸入特徵(N, 2), 輸出(N, 10),輸出可直接拿來分類


""" 損失函數 """
loss_f = torch.nn.CrossEntropyLoss()

""" 優化器:聽說SGD效果更好,感興趣的自己去搗鼓搗鼓 """
self.optimer = torch.optim.Adam([
    {'params': self.net.parameters()},
     {'params': self.arcface.parameters()}
])


""" 訓練器
爲了方便閱讀理解,部分代碼被簡化,
如以下第一句標準語法應爲:for i, (data, target) in enumerate(dataloader)
"""
for data, target in dataloader:
	feature, out = net(data)
	output = arcface(feature)
	arc_loss = loss_f(output, target)			# 計算ArcFace輸出的分類損失
    cls_loss = loss_f(out, target)				# 計算網絡直接輸出的分類損失
    loss = 0.9 * arc_loss + 0.1 * cls_loss		# 如果只計算arcloss,那麼網絡的分類能力會很差
    """acc_arc:arcface輸出的分類正確率;			acc_cls:網絡輸出的out的分類正確率"""
    acc_arc = torch.sum(torch.argmax(output, dim=1) == target) / batch_size
    acc_cls = torch.sum(torch.argmax(out, dim=1) == target) / batch_size


"""ArcLoss函數實現"""
class ArcLoss4(nn.Module):
    def __init__(self, feature_num, class_num, s=10, m=0.1):
    	"""
        :param feature_num:     特徵數
        :param class_num:       類別數
        :param s: 
        :param m:               加上去的夾角,初始爲0.1
        """
        super().__init__()
        self.class_num = class_num
        self.feature_num = feature_num
        self.s = s
        self.m = torch.tensor(m)
        self.w = nn.Parameter(torch.rand(feature_num, class_num), requires_grad=True)  # 2*10
    def forward(self, feature):
        feature = nn.functional.normalize(feature, dim=1)
        w = nn.functional.normalize(self.w, dim=0)
        cos_theat = torch.matmul(feature, w) / 10
        sin_theat = torch.sqrt(1.0 - torch.pow(cos_theat, 2))
        cos_theat_m = cos_theat * torch.cos(self.m) - sin_theat * torch.sin(self.m)
        cos_theat_ = torch.exp(cos_theat * self.s)
        sum_cos_theat = torch.sum(torch.exp(cos_theat * self.s), dim=1, keepdim=True) - cos_theat_
        top = torch.exp(cos_theat_m * self.s)
        div = top / (top + sum_cos_theat)
        return div

# 實現方式2
class ArcLoss2(nn.Module):
    def __init__(self, feature_dim=2, cls_dim=10):
        super().__init__()
        self.W = nn.Parameter(torch.randn(feature_dim, cls_dim), requires_grad=True)

    def forward(self, feature, m=1, s=10):
        x = nn.functional.normalize(feature, dim=1)
        w = nn.functional.normalize(self.W, dim=0)
        cos = torch.matmul(x, w)/10             # 求兩個向量夾角的餘弦值
        a = torch.acos(cos)                     # 反三角函數求得 α
        top = torch.exp(s*torch.cos(a+m))       # e^(s * cos(a + m))
        down2 = torch.sum(torch.exp(s*torch.cos(a)), dim=1, keepdim=True)-torch.exp(s*torch.cos(a))
        out = torch.log(top/(top+down2))
        return out

五、測試

  測試也很簡單,直接把提取器提取的特徵放進arcface去得到輸出,拿輸出做分類,也可以直接拿網絡輸出的out做分類,通過arc的輸出會壓縮在0~1之間,而直接的輸出沒有範圍,但滿足最大值最可靠,用softmax對arc的輸出和net的輸出進行比較可以發現net的out比arc的output更精確,這也解釋了爲什麼要加兩個loss才能訓練得更好。並且我們通常也不會拿他倆的輸出來做分類用

feature, out = net(img_data)		# 將處理好的圖像數據扔進網絡
output = arc(feature)[0]
res1 = torch.argmax(output)			# 通過arc的輸出的分類結果
res2 = torch.argmax(out)			# 模型直接的輸出的分類結果
print('Arc_out:', torch.nn.Softmax(dim=0)(output))
print('Net_out:', torch.nn.Softmax(dim=0)(out))

六、使用

使用不用多講,多講都是廢話,就一句話:
  將圖片處理成數字信號,丟進網絡,拿到特徵feature去做餘弦相似度對比。
  其實不管CenterLoss還是ArcLoss,在做目標相似度對比識別的時候都用不上他,用到的都只是在他前面輸出的那個特徵向量。而如果你是做分類,也用不上他,用到的是網絡最後輸出的分類結果。(當然,如果你非要搞特殊,就要拿arcloss的輸出做分類,那也是ok的)
  總結一下:CenterLoss和ArcLoss都只是在訓練時提高提取器的特徵提取能力,在使用時用不上。

寫在最後

print('Thanks! The end!')		
# 有錯誤之處,歡迎批評指正
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章