分類器ArcFace、ArcLoss在MNIST數據集上的實現和效果
寫在前面:
前一篇文章(電梯直達)給大家介紹了CenterLoss,本文將帶領大家認識一下ArcFace(ArcLoss、Insightface),並在MNIST數據集上實戰一下看一下效果。
一、原理
CenterLoss是將每個類別的特徵縮減到他的中心位置,從而間接使不同特徵之間界限分明,而ArcLoss則是在原本兩個特徵之間的夾角上再加上一個角m,然後優化這個m,使夾角慢慢變大,兩個特徵慢慢遠離,從而使不同特徵之間界限分明。他的理想狀態是將所有特徵之間的夾角都最大化,也就是每個特徵都被壓迫成一條細線。
如下圖中的α和β,同時加上m,會使得α和β同時增大,從而中間紫色區域將會被壓縮,當m優化到最佳,那麼紫色區域將會成爲一條線。當然,這只是一種理想狀態,是一種期望,實際情況下不太可能甚至根本不可能達到這樣的效果。
二、效果
原理就是那麼簡簡單單,看下效果
三、訓練說明
如果只使用ArcLoss,loss只會下降一點點,但也能訓練出來,要是再結合一個其他的損失函數,模型訓練出來的效果將會非常棒,以上效果圖就是ArcLoss + CrossEntropyLoss訓練的結果。
用ArcLoss代替原來的Softmax作爲輸出函數,初始化ArcLoss時指定輸入特徵維度和分類數,用ArcLoss輸出的數據作爲分類依據,與target計算一次損失,再用模型的輸出與target計算一次損失。
四、實現
""" 網絡輸出層,不加激活 """
fc_feature = nn.Linear(1024, 2, bias=False)
out = nn.Linear(2, 10, bias=False)
feature_out = fc_feature(conv_out)
out = out(feature_out)
return feature_out, out
""" 初始化ArcLoss,這裏是2維特徵,10分類 """
arcface = ArcLoss(2, 10) # 輸入特徵(N, 2), 輸出(N, 10),輸出可直接拿來分類
""" 損失函數 """
loss_f = torch.nn.CrossEntropyLoss()
""" 優化器:聽說SGD效果更好,感興趣的自己去搗鼓搗鼓 """
self.optimer = torch.optim.Adam([
{'params': self.net.parameters()},
{'params': self.arcface.parameters()}
])
""" 訓練器
爲了方便閱讀理解,部分代碼被簡化,
如以下第一句標準語法應爲:for i, (data, target) in enumerate(dataloader)
"""
for data, target in dataloader:
feature, out = net(data)
output = arcface(feature)
arc_loss = loss_f(output, target) # 計算ArcFace輸出的分類損失
cls_loss = loss_f(out, target) # 計算網絡直接輸出的分類損失
loss = 0.9 * arc_loss + 0.1 * cls_loss # 如果只計算arcloss,那麼網絡的分類能力會很差
"""acc_arc:arcface輸出的分類正確率; acc_cls:網絡輸出的out的分類正確率"""
acc_arc = torch.sum(torch.argmax(output, dim=1) == target) / batch_size
acc_cls = torch.sum(torch.argmax(out, dim=1) == target) / batch_size
"""ArcLoss函數實現"""
class ArcLoss4(nn.Module):
def __init__(self, feature_num, class_num, s=10, m=0.1):
"""
:param feature_num: 特徵數
:param class_num: 類別數
:param s:
:param m: 加上去的夾角,初始爲0.1
"""
super().__init__()
self.class_num = class_num
self.feature_num = feature_num
self.s = s
self.m = torch.tensor(m)
self.w = nn.Parameter(torch.rand(feature_num, class_num), requires_grad=True) # 2*10
def forward(self, feature):
feature = nn.functional.normalize(feature, dim=1)
w = nn.functional.normalize(self.w, dim=0)
cos_theat = torch.matmul(feature, w) / 10
sin_theat = torch.sqrt(1.0 - torch.pow(cos_theat, 2))
cos_theat_m = cos_theat * torch.cos(self.m) - sin_theat * torch.sin(self.m)
cos_theat_ = torch.exp(cos_theat * self.s)
sum_cos_theat = torch.sum(torch.exp(cos_theat * self.s), dim=1, keepdim=True) - cos_theat_
top = torch.exp(cos_theat_m * self.s)
div = top / (top + sum_cos_theat)
return div
# 實現方式2
class ArcLoss2(nn.Module):
def __init__(self, feature_dim=2, cls_dim=10):
super().__init__()
self.W = nn.Parameter(torch.randn(feature_dim, cls_dim), requires_grad=True)
def forward(self, feature, m=1, s=10):
x = nn.functional.normalize(feature, dim=1)
w = nn.functional.normalize(self.W, dim=0)
cos = torch.matmul(x, w)/10 # 求兩個向量夾角的餘弦值
a = torch.acos(cos) # 反三角函數求得 α
top = torch.exp(s*torch.cos(a+m)) # e^(s * cos(a + m))
down2 = torch.sum(torch.exp(s*torch.cos(a)), dim=1, keepdim=True)-torch.exp(s*torch.cos(a))
out = torch.log(top/(top+down2))
return out
五、測試
測試也很簡單,直接把提取器提取的特徵放進arcface去得到輸出,拿輸出做分類,也可以直接拿網絡輸出的out做分類,通過arc的輸出會壓縮在0~1之間,而直接的輸出沒有範圍,但滿足最大值最可靠,用softmax對arc的輸出和net的輸出進行比較可以發現net的out比arc的output更精確,這也解釋了爲什麼要加兩個loss才能訓練得更好。並且我們通常也不會拿他倆的輸出來做分類用
feature, out = net(img_data) # 將處理好的圖像數據扔進網絡
output = arc(feature)[0]
res1 = torch.argmax(output) # 通過arc的輸出的分類結果
res2 = torch.argmax(out) # 模型直接的輸出的分類結果
print('Arc_out:', torch.nn.Softmax(dim=0)(output))
print('Net_out:', torch.nn.Softmax(dim=0)(out))
六、使用
使用不用多講,多講都是廢話,就一句話:
將圖片處理成數字信號,丟進網絡,拿到特徵feature去做餘弦相似度對比。
其實不管CenterLoss還是ArcLoss,在做目標相似度對比識別的時候都用不上他,用到的都只是在他前面輸出的那個特徵向量。而如果你是做分類,也用不上他,用到的是網絡最後輸出的分類結果。(當然,如果你非要搞特殊,就要拿arcloss的輸出做分類,那也是ok的)
總結一下:CenterLoss和ArcLoss都只是在訓練時提高提取器的特徵提取能力,在使用時用不上。
寫在最後
print('Thanks! The end!')
# 有錯誤之處,歡迎批評指正