Pytorch 網絡模型創建
本方法總結自《動手學深度學習》(Pytorch版)github項目
- 部分內容延續 Pytorch 學習(五):Pytorch 實現多層感知機(MLP) 實現方法
常用的網絡搭建方法有
- 繼承 Module 方法
- 利用 Sequential, ModuleList 和 ModuleDict 類創建
- 多種方法的同時使用
繼承 Module 方法
在 Pytorch 學習(五)中構建多層感知器網絡時,便使用了繼承 torch.nn.Module 的方法,這是最常用的網絡模型創建方法
import torch
import torch.nn as nn
class MLP(nn.Module):
def __init__(self, n_i, n_h, n_o):
super(MLP, self).__init__()
self.linear1 = nn.Linear(n_i, n_h)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(n_h, n_o)
def forward(self, input):
return self.linear2(self.relu(self.linear1(input)))
利用 Sequential 類
同樣對於 MLP 網絡,可以使用 Sequential 類實現
from collections import OrderedDict
net = nn.Sequential(
OrderedDict([
('linear1', nn.Linear(n_inputs, n_hiddens)),
('relu', nn.ReLU()),
('linear2', nn.Linear(n_hiddens, n_outputs))
])
)
利用 Sequential 網絡的各層是有序的,同時不需要實現 forward 函數,默認按照對應順序進行前向傳播。構造一個 MySequential 類來進一步理解
class MySequential(nn.Module):
def __init__(self, *args):
super(MySequential, self).__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict): # 如果傳入 OrderedDict 參數
for key, module in args[0].items():
self.add_module(key, module)
else: # 傳入的爲 module
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def forward(self, input):
for module in self._modules.values():
input = module(input)
return input
同樣利用 MySequential 創建 MLP 網絡
net = MySequential(
nn.Linear(n_inputs, n_hiddens),
nn.ReLU(),
nn.Linear(n_hiddens, n_outputs)
)
print(net)
net = MySequential(
OrderedDict([
('linear1', nn.Linear(n_inputs, n_hiddens)),
('relu', nn.ReLU()),
('linear2', nn.Linear(n_hiddens, n_outputs))
])
)
print(net)
利用 ModuleList 和 ModuleDict
ModuleList 和 ModuleDict 的區別是輸入分別爲 list 和 dict,訪問某一層或添加更多層的方式略有區別
- ModuleList 操作
net = nn.ModuleList([
nn.Linear(n_inputs, n_hiddens),
nn.ReLU(),
])
net.append(nn.Linear(n_hiddens, n_outputs))
print(net[-1])
- ModuleDict 操作
net = nn.ModuleDict({
'linear1', nn.Linear(n_inputs, n_hiddens),
'relu', nn.ReLU()
})
net['linear2'] = nn.Linear(n_hiddens, n_outputs)
print(net['linear2'])
這兩種方式都需要手動實現 forward 函數,不能直接調用。同時與傳統的 list 存在區別,後者的參數不計入 net.parameters()
self.linear1 = nn.ModuleList([nn.Linear(n_inputs, n_outputs)])
print(net.linear1.parameters()[0].shape) # n_inputs, n_outputs
self.linear1 = [nn.Linear(n_inputs, n_outputs)]
print(net.parameters()) # None
總結
- 四種方法實現模型構建:繼承 Module, 使用 Sequential, 使用 ModuleList 和使用 ModuleDict
- 繼承 Module 靈活性最高
- Sequential 有序、自動計算 forward 過程
- ModuleList 和 ModuleDict 僅作爲容器,需要構建 forward 過程