torch.nn.Module源碼學習

nn.Module是使用pytorch進行神經網絡訓練的主要載體,是所有網絡的基類。首先看一下它的構造函數:

    def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._modules = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
  • self.training標誌網絡的狀態,主要影響bn和dropout等在網絡訓練和評估時使用方法不一樣的功能;
  • self._parameters保存當前module的訓練參數;
  • self._buffers保存當前moduile的非訓練參數,如bn的running_mean和running_var;
  • self._modules保存當前module中的子module(不是自定義模型中的所有module);
  • 其餘的屬性均是用戶定義的hook函數。關於hook函數可參數:pytorch的自動求導和hook技術簡介

下面分類來看nn.Module中的其他非輔助方法:

  1. 添加新元素:
  • register_buffer:向self._buffers註冊新元素
  • register_parameter:向self._parameters註冊新元素
  • add_module:向self._modules註冊新元素
  1. 類型轉換:
  • cuda:將所有的parameters和buffers移動到gpu

  • cpu:將所有的parameters和buffers移動到cpu

  • type:將所有的parameters和buffers都轉換爲指定的目標類型

  • float:將所有的parameters和buffers都轉換爲float類型

  • double:將所有的parameters和buffers都轉換爲double類型

  • half:將所有的parameters和buffers都轉換爲float16類型

  • to:該函數有四種用法:

    • to(device=None, dtype=None, non_blocking=False):轉移到指定的device;
    • to(dtype, non_blocking=False):轉換爲指定的dtype;
    • to(tensor, non_blocking=False):將tensor屬性(dtype和device)轉換到與指定tensor相同;
    • to(memory_format=torch.channels_last):改變4d tensor的存儲格式,NCHW或NHWC。
  • 上述所有函數的功能均藉助_apply完成:

    def _apply(self, fn):
            for module in self.children():
                module._apply(fn)
    
            def compute_should_use_set_data(tensor, tensor_applied):
    		# 是否進行in-place操作
    
            for key, param in self._parameters.items():
                if param is not None:
                    # Tensors stored in modules are graph leaves, and we don't want to
                    # track autograd history of `param_applied`, so we have to use
                    # `with torch.no_grad():`
                    with torch.no_grad():
                        param_applied = fn(param)
                    should_use_set_data = compute_should_use_set_data(param, param_applied)
                    if should_use_set_data:
                        # 直接替換舊數據
                        param.data = param_applied
                    else:
                        assert isinstance(param, Parameter)
                        assert param.is_leaf
                        # 註冊新的Parameter
                        self._parameters[key] = Parameter(param_applied, param.requires_grad)
    				
                    # 對param.grad進行相同的操作
                    if param.grad is not None:
                        with torch.no_grad():
                            grad_applied = fn(param.grad)
                        should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                        if should_use_set_data:
                            param.grad.data = grad_applied
                        else:
                            assert param.grad.is_leaf
                            self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)
    
            # 更新_buffers
            for key, buf in self._buffers.items():
                if buf is not None:
                    self._buffers[key] = fn(buf)
    
            return self
    

    _apply函數遍歷了所有的parameters(如果有grad還要遍歷grad)和buffers,對它們應用fn,並將fn作用的結果註冊到相應的容器中。

  1. hook註冊:
  • register_backward_hook:向self._backward_hooks註冊新元素

  • register_forward_pre_hook:向self._forward_pre_hooks註冊新元素

  • register_forward_hook:向self._forward_hooks註冊新元素

    關於hook技術的詳細介紹可參考我的裏一篇文章pytorch的自動求導和hook技術簡介

  1. 模型保存和加載:
  • state_dict:返回一個包含模型所有parameters和buffers的字典。需要注意的是,所有的參數在被添加到字典中之前,都要經過self._state_dict_hooks中的hook函數處理。

    該函數是個遞歸函數,每次遞歸調用了輔助函數_save_to_state_dict來將當前module的子module參數添加到字典中。

  • load_state_dict:將指定的state_dict中的參數加載到當前的module中。需要注意的是,所有的參數在被添加到當前模型的_parameters和 _ buffers之前,都要經過self. _load_state_dict_pre_hooks中hook函數處理。

    該函數中定義了一個遞歸函數load。load每次遞歸時調用輔助函數_load_from_state_dict來添加當前module的子module參數。

    該函數有一個輸入參數strict,默認爲True。strict爲True時,會返回一個namedtuple。該tuple有兩個屬性:missing_keys和unexpected_keys。missing_keys是存儲當前module有而待加載的state_dict中沒有的參數的列表;unexpected_keys是存儲當前module沒有而待加載的state_dict中有的參數的列表。

  1. 信息查詢:
  • named_modules:返回模型中的所有module(包括模型本身)及其名稱。返回值中有兩個值,第一個爲module的名(string格式),第二個爲相應的module。返回順序是自頂向下。

  • modules:返回模型中的所有module,調用named_modules完成。

  • named_children:返回當前模型_modules中的所有元素及其名稱。注意該函數與named_modules的區別:named_modules是遞歸函數,每次遞歸均查詢當前module的 _modules,進而能夠遍歷模型中的所有module;而named_chilldren只查詢整個模型的 _modules

  • children:返回當前模型_modules中的所有元素,調用named_modules來完成。

  • named_parameters:返回模型中所有的可訓練參數及其名稱。

  • parameters:返回模型中所有的可訓練參數。

  • named_buffers:返回模型中所有的非訓練參數及其名稱。

  • buffers:返回模型中所有的非訓練參數。

    上述所有的方法返回的都是由yield定義的生成器。
    前兩個方法藉助named_modules來完成。

    def named_modules(self, memo=None, prefix=''):
          r"""Returns an iterator over all modules in the network, yielding
          both the name of the module as well as the module itself.
          Yields:
              (string, Module): Tuple of name and module
          Note:
              Duplicate modules are returned only once. In the following
              example, ``l`` will be returned only once.
          """
          if memo is None:
              memo = set()
          if self not in memo:
              memo.add(self)  # 包括當前模型自身
              yield prefix, self
            for name, module in self._modules.items():
                  if module is None:
                    continue
                  submodule_prefix = prefix + ('.' if prefix else '') + name
                  # 遞歸返回每一級所有module
                  for m in module.named_modules(memo, submodule_prefix):
                      yield m
    

    named_children和children直接返回_modules中的所有元素。
    最後四個方法調用了輔助函數_named_members來實現:

        def _named_members(self, get_members_fn, prefix='', recurse=True):
            r"""Helper method for yielding various names + members of modules."""
            memo = set()
            # 調用named_modules獲得模型中的所有module
            modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
            for module_prefix, module in modules:
                members = get_members_fn(module) # 獲得當前module的所有參數
                for k, v in members:
                    if v is None or v in memo:
                        continue
                    memo.add(v)
                    name = module_prefix + ('.' if module_prefix else '') + k
                    yield name, v
    

用一個簡單的示例來學習他們的區別:

import torch
import torch.nn as nn

class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.fc = nn.Linear(32, 3)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(-1, 32)
        x = self.fc(x)
        return x

model = net()

for k, m in model._modules.items():
    print(k, m)
# conv Sequential(
#   (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#   (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (2): ReLU(inplace=True)
# )
# fc Linear(in_features=32, out_features=3, bias=True)

for k, v in model.named_modules():
	print(k, v)
#  net(
#   (conv): Sequential(
#     (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#     (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     (2): ReLU(inplace=True)
#   )
#   (fc): Linear(in_features=32, out_features=3, bias=True)
# )
# conv Sequential(
#   (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#   (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (2): ReLU(inplace=True)
# )
# conv.0 Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
# conv.1 BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# conv.2 ReLU(inplace=True)
# fc Linear(in_features=32, out_features=3, bias=True)

for k, v in model.named_children():
	print(k, v)
# conv Sequential(
#   (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#   (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (2): ReLU(inplace=True)
# )
# fc Linear(in_features=32, out_features=3, bias=True)
# 注意children()和modules()的區別:children只返回一級子module

for k, _ in model.named_parameters():
	print(k)
# conv.0.weight
# conv.1.weight
# conv.1.bias
# fc.weight
# fc.bias

for k, _ in model.named_buffers():
	print(k)
# conv.1.running_mean
# conv.1.running_var
# conv.1.num_batches_tracked
  1. 狀態設置:
  • train:將模型設爲訓練模式;主要影響BN、Dropout等module。
  • eval:將模型設爲評估模式;主要影響BN、Dropout等module。調用train來實現。
  • require_grad_:設置模型中所有的Parameter的require_grad屬性,即是否要計算梯度。調用parameters()來實現。
  • zero_grad:將模型中所有的Parameter的梯度置爲零,並將其從計算圖中分離。調用parameters()來實現。
  1. 其他:
  • apply:將自定義函數作用於模型的所有子module。

    def apply(self, fn):
        for module in self.children():
            module.apply(fn)
            fn(self)
            return self
    
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章