PyTorch 中獲取模型各層的 input/output shape

PyTorch 官方目前無法像 TensorFlow、Caffe 那樣直接給出各層的 shape 信息,詳見:

https://github.com/pytorch/pytorch/pull/3043


以下代碼算是一種 workaround。由於 CNN、RNN 等模塊的實現方式不一樣,若要支持其他模塊,可能需要修改代碼。

例如 RNN 中的 bias 是 bool 類型,其權重也不是存放在 weight 屬性中(因此下方輸出中 LSTM 層的 nb_params 為 0);另外 nb_params 只統計 weight,不包含 bias。不過我們只關注 shape,這樣已經夠用了。

該方法必須先構造一個輸入並調用一次 forward(即 model(x))後,才能獲取 shape。


#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json


def get_output_size(summary_dict, output):
  """Record the shape(s) of a module's output into ``summary_dict``.

  A single tensor-like output is stored as ``summary_dict['output_shape']``
  (a list of ints). A tuple or list output, such as the nested
  ``(output, (h, c))`` returned by an LSTM, is handled recursively: element
  ``i`` is described by its own ``OrderedDict`` stored under integer key ``i``.

  Args:
    summary_dict: dict to fill in; it is mutated in place.
    output: an object with a ``.size()`` method, or a (possibly nested)
      tuple/list of such objects.

  Returns:
    The same ``summary_dict`` object, for convenience.
  """
  if isinstance(output, (tuple, list)):
    # enumerate works on Python 2 and 3 alike (xrange is Python-2 only).
    for i, item in enumerate(output):
      summary_dict[i] = get_output_size(OrderedDict(), item)
  else:
    summary_dict['output_shape'] = list(output.size())
  return summary_dict

def summary(input_size, model):
  """Run one dummy forward pass through ``model`` and collect per-layer info.

  A forward hook is registered on every submodule except ``nn.Sequential`` /
  ``nn.ModuleList`` containers and ``model`` itself. Each time a hooked module
  runs, an entry keyed ``'<ClassName>-<n>'`` (``n`` = 1-based order in which
  hooks fired) is added with:
    * ``input_shape``: shape of the module's first positional input;
    * output shape(s), as filled in by ``get_output_size``;
    * ``trainable``: only present for modules with a non-None ``weight``;
    * ``nb_params``: a plain int (see note below).

  Args:
    input_size: shape of one sample without the batch dimension, e.g.
      ``[1, 32, 128]``; a batch dimension of 1 is prepended. If the first
      element is itself a list/tuple, ``input_size`` is treated as a list of
      shapes and a list of random inputs is built.
    model: the ``nn.Module`` to inspect; it is called once with random data.

  Returns:
    OrderedDict mapping each layer key to its info dict, in call order.

  Note:
    ``nb_params`` counts only the elements of ``module.weight``. Biases and
    weights stored under other attribute names are not counted, so such
    modules report 0.
  """
  layers = OrderedDict()
  hooks = []

  def hook(module, input, output):
    class_name = str(module.__class__).split('.')[-1].split("'")[0]
    m_key = '%s-%i' % (class_name, len(layers) + 1)

    info = OrderedDict()
    info['input_shape'] = list(input[0].size())
    info = get_output_size(info, output)

    nb_params = 0
    # getattr(..., None) also skips modules whose weight attribute exists but
    # is None, where calling .size() would raise.
    weight = getattr(module, 'weight', None)
    if weight is not None:
      # Keep a plain Python int: newer PyTorch returns a 0-d tensor from
      # torch.prod, which json.dumps cannot serialize.
      nb_params = 1
      for dim in weight.size():
        nb_params *= int(dim)
      info['trainable'] = bool(weight.requires_grad)
    info['nb_params'] = nb_params

    layers[m_key] = info

  def register_hook(module):
    # Containers and the root model are skipped; every other module is hooked.
    if isinstance(module, (nn.Sequential, nn.ModuleList)) or module is model:
      return
    hooks.append(module.register_forward_hook(hook))

  # Build a random input with batch size 1 (several inputs if input_size is
  # a list of shapes).
  if isinstance(input_size[0], (list, tuple)):
    x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
  else:
    x = Variable(torch.rand(1, *input_size))

  model.apply(register_hook)
  try:
    # NOTE(review): with several inputs the list is passed as ONE argument;
    # a model whose forward takes separate arguments would need model(*x).
    model(x)
  finally:
    # Always detach the hooks, even if the forward pass raises, so the model
    # is left without stray hooks.
    for h in hooks:
      h.remove()

  return layers

# Example: summarize the CRNN defined in models/crnn.py on one 1x32x128 input.
# Use a separate name so the imported `crnn` module is not shadowed by the
# instance. NOTE(review): the meaning of the CRNN constructor arguments
# (32, 1, 3755, 256, 1) should be confirmed against models/crnn.py.
model = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1, 32, 128], model)
# print(...) with a single argument behaves the same on Python 2 and 3.
print(json.dumps(x))
以 PyTorch 版 CRNN 為例,輸出的 shape 如下:

{
"Conv2d-1": {
"input_shape": [1, 1, 32, 128],
"output_shape": [1, 64, 32, 128],
"trainable": true,
"nb_params": 576
},
"ReLU-2": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 32, 128],
"nb_params": 0
},
"MaxPool2d-3": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 16, 64],
"nb_params": 0
},
"Conv2d-4": {
"input_shape": [1, 64, 16, 64],
"output_shape": [1, 128, 16, 64],
"trainable": true,
"nb_params": 73728
},
"ReLU-5": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 16, 64],
"nb_params": 0
},
"MaxPool2d-6": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 8, 32],
"nb_params": 0
},
"Conv2d-7": {
"input_shape": [1, 128, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 294912
},
"BatchNorm2d-8": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 256
},
"ReLU-9": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"Conv2d-10": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 589824
},
"ReLU-11": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"MaxPool2d-12": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 4, 33],
"nb_params": 0
},
"Conv2d-13": {
"input_shape": [1, 256, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 1179648
},
"BatchNorm2d-14": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-15": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"Conv2d-16": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 2359296
},
"ReLU-17": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"MaxPool2d-18": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 2, 34],
"nb_params": 0
},
"Conv2d-19": {
"input_shape": [1, 512, 2, 34],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 1048576
},
"BatchNorm2d-20": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-21": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"nb_params": 0
},
"LSTM-22": {
"input_shape": [33, 1, 512],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-23": {
"input_shape": [33, 512],
"output_shape": [33, 256],
"trainable": true,
"nb_params": 131072
},
"BidirectionalLSTM-24": {
"input_shape": [33, 1, 512],
"output_shape": [33, 1, 256],
"nb_params": 0
},
"LSTM-25": {
"input_shape": [33, 1, 256],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-26": {
"input_shape": [33, 512],
"output_shape": [33, 3755],
"trainable": true,
"nb_params": 1922560
},
"BidirectionalLSTM-27": {
"input_shape": [33, 1, 256],
"output_shape": [33, 1, 3755],
"nb_params": 0
}
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章