Pytorch中的view()函數

原文地址

Pytorch系列目錄

view()函數有些像numpy中的reshape函數,是用來的tensor(張量)形式的數據進行圍堵重構的,直接用程序來說明用法

  • 生成測試數據

    import torch
    
    torch.manual_seed(0)	# 用來控制內部的隨機機制使每次得到的隨機數一樣
    
    tt = torch.rand(3,4)
    # tensor([[0.4963, 0.7682, 0.0885, 0.1320],
    #         [0.3074, 0.6341, 0.4901, 0.8964],
    #         [0.4556, 0.6323, 0.3489, 0.4017]])
    
  • 實現方法

    print(tt.view((2,-1)))
    # tensor([[0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341],
    #         [0.4901, 0.8964, 0.4556, 0.6323, 0.3489, 0.4017]])
    print(tt.view(2,-1))
    # tensor([[0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341],
    #         [0.4901, 0.8964, 0.4556, 0.6323, 0.3489, 0.4017]])
    

    其中-1表示不對這一維度的數量做具體限定,算出來是多少就是多少,注意在所有維度中只能有一個維度指定爲-1

    view()函數可以接收兩種形式的輸入,一種是給出一個‘形狀’ (2,-1),一種是一次列舉各維度的維度值2,-1

  • 可以用reshape()函數實現

    pytorch提供了很好的numpy兼容性,很多numpy下的方法在pytorch中也可以使用,用reshape()函數實現方式和實現結果與view相同

    print(tt.reshape((2,-1)))
    # tensor([[0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341],
    #         [0.4901, 0.8964, 0.4556, 0.6323, 0.3489, 0.4017]])
    print(tt.reshape(2,-1))
    # tensor([[0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341],
    #         [0.4901, 0.8964, 0.4556, 0.6323, 0.3489, 0.4017]])
    
  • 再多做一點兒,三維(2,2,-1)

    print(tt.reshape(2,2,-1))
    # tensor([[[0.4963, 0.7682, 0.0885],
    #          [0.1320, 0.3074, 0.6341]],
    # 
    #         [[0.4901, 0.8964, 0.4556],
    #          [0.6323, 0.3489, 0.4017]]])
    

    要在2維空間裏print3維的數據,大概就是這樣了吧

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章