方法
抽象地看,我們要做的事情就是,給定一個矩陣,每行都進行左移,而移動的個數隨行數遞增而遞減。
我目前想到的一種方法是使用gather,將想要的index提前定好,然後使用Pytorch的gather就能夠實現。
而transformer-xl實現了另一種更好的方法:_rel_shift
。
def _rel_shift(self, x, zero_triu=False): # x: q,k,bs,n_head zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), device=x.device, dtype=x.dtype) x_padded = torch.cat([zero_pad, x], dim=1) x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:]) x = x_padded[1:].view_as(x) return x |
第一步是,將x的第一列填上padding,此時x.size()=q,k+1,bs,n_head
,接下來將其重新reshape,則變成了x.size()=k+1,q,bs,n_head
,最後將第一行去掉,變成x.size()=k,q,bs,n_head
,再將其reshape回x原來的樣子。
爲什麼這麼做實現了我們想要的左移的功能?我們應該從一維的角度去理解。因爲實際上在內存中所有元素都是按照一維去排列的。
實際上就是有q個key按照一行去排列。
實際上就是在每個key前面插入了0。
接下來view,實際上數據的先後順序還是沒有變(因爲不是transpose):
實際上只是強行將該行切成一個一個q而已。
那麼最後一個操作,將第一行丟掉,實際上就是要把原來的x的第一行強行左移q-1個(因爲有padding)。那麼爲什麼後面的行能夠左移的個數依次減少?別忘了padding,第一行左移了q-1個,但第二個key前面也有一個padding,所以相當於將其向右推了一格;第三個又有一個padding,就在原來的基礎上又推了一格,也即推了兩格。因此最後達到了我們想要的目的。
實際上要理解該方法,需要牢牢把握數據存儲的本質是一整行。
該方法沒有數據的拷貝,全部都是view操作,因此更高效。
不得不佩服想到該方法的人的工程能力,同時也感謝戴寧帶我理解該方法的本質,一開始我是死活不理解的。以後或許可以將該思想靈活應用到其他方面。
- 本文作者: 林澤輝
- 本文鏈接: http://www.linzehui.me/2019/05/07/代碼相關/關於transformer-xl中rel-shift實現的解讀/
- 版權聲明: 本博客所有文章除特別聲明外,均採用 CC BY-NC-SA 3.0 許可協議。轉載請註明出處!