dropout實現過程

1、dropout可以用來防止過擬合
pytorch中實現如下:

m = nn.Dropout(p=0.2)
input = torch.randn(2, 5)
print()
output = m(input)
print(input)
print(output)

輸出如下
在這裏插入圖片描述
實際上,dropout不只mask掉某個位置的數,而且還將保留的數進行縮放,縮放比例爲
p1p{\frac{p}{1-p}}
這裏計算出來縮放比例爲1.25。-0.1673*1.25=0.209125
dropout的作用在於對某個字(seq維度上)一共有384個特徵,隨機剔除幾個特徵,並放大剩餘特徵。

2、Dropout(x)的後續作用
Dropout: [batch_size, seq, hidden_dim]
W: [hidden_dim, 10]
在這裏插入圖片描述
上圖中Dropout(x)維度爲[1, 2, 4],對圖中Dropout(x)的第一行來說,W的第二行數據失效,第二行類似。因爲seq維度不一定是獨立的。
因爲每個batch中的數據兩兩之間可以認爲是獨立的,對Dropout(x)每一行來說,都相當於獨立訓練一個分類器。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章