在每個批內pad一個tensor

當一個Dataset處理完數據後,需要加載時,希望在一個mini batch內pad數據,把數據pad成這個批內最大的長度,減小不必要的顯存消耗。

torch給提供了這樣的函數,在torch.nn.utils.rnn.pad_sequence函數。

# 函數返回一個T x B x * 或 B x T x *的一個tensor,當batch_first=True時,B在前面。
# 需要pad的tensor的維度放在第一個維度上。
"""
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes
        trailing dimensions and type of all the Tensors in sequences are same.
"""
>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300)
>>> b = torch.ones(22, 300)
>>> c = torch.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300])

一維tensor的pad:

>>> a
tensor([1., 1., 1.])
>>> b
tensor([1., 1., 1., 1., 1., 1.])
>>> c
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
>>> pad_sequence([a,b,c],batch_first=True).size()
torch.Size([3, 8])
>>> pad_sequence([a,b,c],batch_first=True)
tensor([[1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章