關於transformer-xl中rel-shift實現的解讀


 

方法

抽象地看,我們要做的事情就是,給定一個矩陣,每行都進行左移,而移動的個數隨行數遞增而遞減。

我目前想到的一種方法是使用gather,將想要的index提前定好,然後使用Pytorch的gather就能夠實現。

而transformer-xl實現了另一種更好的方法:_rel_shift

def _rel_shift(self, x, zero_triu=False):
    # x: q,k,bs,n_head
    zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                           device=x.device, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=1)

    x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

    x = x_padded[1:].view_as(x)

    return x

第一步是,將x的第一列填上padding,此時x.size()=q,k+1,bs,n_head,接下來將其重新reshape,則變成了x.size()=k+1,q,bs,n_head,最後將第一行去掉,變成x.size()=k,q,bs,n_head,再將其reshape回x原來的樣子。

爲什麼這麼做實現了我們想要的左移的功能?我們應該從一維的角度去理解。因爲實際上在內存中所有元素都是按照一維去排列的。

原來的矩陣:

實際上就是有q個key按照一行去排列。

在做完padding之後,則:

實際上就是在每個key前面插入了0。

接下來view,實際上數據的先後順序還是沒有變(因爲不是transpose):

實際上只是強行將該行切成一個一個q而已。

那麼最後一個操作,將第一行丟掉,實際上就是要把原來的x的第一行強行左移q-1個(因爲有padding)。那麼爲什麼後面的行能夠左移的個數依次減少?別忘了padding,第一行左移了q-1個,但第二個key前面也有一個padding,所以相當於將其向右推了一格;第三個又有一個padding,就在原來的基礎上又推了一格,也即推了兩格。因此最後達到了我們想要的目的。

實際上要理解該方法,需要牢牢把握數據存儲的本質是一整行。

該方法沒有數據的拷貝,全部都是view操作,因此更高效。

不得不佩服想到該方法的人的工程能力,同時也感謝戴寧帶我理解該方法的本質,一開始我是死活不理解的。以後或許可以將該思想靈活應用到其他方面。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章