Pytorch控制打印矩陣的格式

原文地址

分類目錄——Pytorch

諸如長序列單行顯示,全部顯示(不縮略顯示),精度(保留小數點後幾位),是否科學計數法顯示等等。

直接用程序來說明

  • 生成測試數據

    import torch
    
    torch.random.manual_seed(0)	# 固定每次生成的數據相同
    
    tensor = torch.rand(100, 9)-0.5
    print(tensor)	# 在默認的顯示設置下進行print
    

    效果(部分)如下圖

    1585033499309

    可以看到默認設定,過長的會換行,默認保留小數點後4位,默認進行科學計數法顯示

  • 通過torch.set_printoptions()控制顯示格式

    torch.set_printoptions(
        precision=2,    # 精度,保留小數點後幾位,默認4
        threshold=1000,
        edgeitems=3,
        linewidth=150,  # 每行最多顯示的字符數,默認80,超過則換行顯示
        profile=None,
        sci_mode=False  # 用科學技術法顯示數據,默認True
    )
    print(tensor)
    

    其中

    • precision=2

      精度,保留小數點後幾位,默認4

    • threshold=1000
      最多可現實的Array元素個數,默認1000;
      限制的是基本元素個數,如3*5的矩陣,限制的是15而非3(行);
      如果超過就採用縮略顯示;
      設置爲inf全部顯示

    • edgeitems=3
      在縮略顯示時在起始和默認顯示的元素個數(對多個維度同時有效)

    • linewidth=150,

      每行最多顯示的字符數,默認80,超過則換行顯示

    • profile=None

      3種預定義的顯示模板,可選’default’、‘short’、‘full’

      	# if profile == "default":
          #     PRINT_OPTS.precision = 4
          #     PRINT_OPTS.threshold = 1000
          #     PRINT_OPTS.edgeitems = 3
          #     PRINT_OPTS.linewidth = 80
          # elif profile == "short":
          #     PRINT_OPTS.precision = 2
          #     PRINT_OPTS.threshold = 1000
          #     PRINT_OPTS.edgeitems = 2
          #     PRINT_OPTS.linewidth = 80
          # elif profile == "full":
          #     PRINT_OPTS.precision = 4
          #     PRINT_OPTS.threshold = inf
          #     PRINT_OPTS.edgeitems = 3
          #     PRINT_OPTS.linewidth = 80
      
    • sci_mode=False

      用科學技術法顯示數據,默認True

    如上設定之後顯示如下

    1585033983750

    應該注意這只是設置print出來的顯示格式,並不會影響到數據在內存中運算的精度

  • 相關文獻

    Numpy控制多維矩陣的顯示格式

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章