1、dropout可以用來防止過擬合
pytorch中實現如下:
m = nn.Dropout(p=0.2)
input = torch.randn(2, 5)
print()
output = m(input)
print(input)
print(output)
輸出如下
實際上,dropout不只mask掉某個位置的數,而且還將保留的數進行縮放,縮放比例爲
這裏計算出來縮放比例爲1.25。-0.1673*1.25=0.209125
dropout的作用在於對某個字(seq維度上)一共有384個特徵,隨機剔除幾個特徵,並放大剩餘特徵。
2、Dropout(x)的後續作用
Dropout: [batch_size, seq, hidden_dim]
W: [hidden_dim, 10]
上圖中Dropout(x)維度爲[1, 2, 4],對圖中Dropout(x)的第一行來說,W的第二行數據失效,第二行類似。因爲seq維度不一定是獨立的。
因爲每個batch中的數據兩兩之間可以認爲是獨立的,對Dropout(x)每一行來說,都相當於獨立訓練一個分類器。