Saving and Loading Models

https://pytorch.org/tutorials/beginner/saving_loading_models.html
https://github.com/pytorch/tutorials/blob/master/beginner_source/saving_loading_models.py
https://pytorch.org/tutorials/index.html

This document provides solutions to a variety of use cases regarding the saving and loading of PyTorch models. Feel free to read the whole document, or just skip to the code you need for a desired use case.

When it comes to saving and loading models, there are three core functions to be familiar with:

  • torch.save: Saves a serialized object to disk. This function uses Python’s pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.

  • torch.load: Uses pickle’s unpickling facilities to deserialize pickled object files to memory. This function also lets you specify the device to load the data onto (see Saving & Loading Model Across Devices).

  • torch.nn.Module.load_state_dict: Loads a model’s parameter dictionary using a deserialized state_dict. For more information on state_dict, see What is a state_dict?.

1. What is a state_dict?

In PyTorch, the learnable parameters (i.e. weights and biases) of a torch.nn.Module model are contained in the model’s parameters (accessed with model.parameters()). A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor. Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm’s running_mean) have entries in the model’s state_dict. Optimizer objects (torch.optim) also have a state_dict, which contains information about the optimizer’s state, as well as the hyperparameters used.

Because state_dict objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.

Let’s take a look at the state_dict from the simple model used in the Training a classifier tutorial.
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Output:

Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [139998270550880, 139998249461728, 139998249461800, 139998249461872, 139998249461944, 139998249462016, 139998249462088, 139998249462160, 139998249462232, 139998249462304]}]


2. Saving & Loading Model for Inference

2.1 Save/Load state_dict (Recommended)

Save:

torch.save(model.state_dict(), PATH)

Load:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

When saving a model for inference, it is only necessary to save the trained model’s learned parameters. Saving the model’s state_dict with the torch.save() function will give you the most flexibility for restoring the model later, which is why it is the recommended method for saving models.

A common PyTorch convention is to save models using either a .pt or .pth file extension.

Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.

Notice that the load_state_dict() function takes a dictionary object, NOT a path to a saved object. This means that you must deserialize the saved state_dict before you pass it to the load_state_dict() function. For example, you CANNOT load using model.load_state_dict(PATH).
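
Putting the pieces together, a minimal round-trip might look like the sketch below (the file name model_state.pt is an arbitrary choice for illustration, and TheModelClass is the model defined in section 1):

import torch

model = TheModelClass()
# ... train the model ...

# Save only the learned parameters.
torch.save(model.state_dict(), 'model_state.pt')

# Later, or in another process: rebuild the architecture,
# then load the deserialized state_dict.
model = TheModelClass()
model.load_state_dict(torch.load('model_state.pt'))
model.eval()  # set dropout and batchnorm layers to evaluation mode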

2.2 Save/Load Entire Model

Save:

torch.save(model, PATH)

Load:

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

This save/load process uses the most intuitive syntax and involves the least amount of code. Saving a model this way serializes the entire module using Python’s pickle module. The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact directory structure used when the model is saved. This is because pickle does not save the model class itself; rather, it saves a path to the file containing the class, which is used at load time. Because of this, your code can break in various ways when used in other projects or after refactors.

A common PyTorch convention is to save models using either a .pt or .pth file extension.

Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.

3. Saving & Loading a General Checkpoint for Inference and/or Resuming Training

Save:

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

Load:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

When saving a general checkpoint, to be used for either inference or resuming training, you must save more than just the model’s state_dict. It is important to also save the optimizer’s state_dict, as this contains buffers and parameters that are updated as the model trains. Other items that you may want to save are the epoch you left off on, the latest recorded training loss, external torch.nn.Embedding layers, etc.

To save multiple components, organize them in a dictionary and use torch.save() to serialize the dictionary. A common PyTorch convention is to save these checkpoints using the .tar file extension.

To load the items, first initialize the model and optimizer, then load the dictionary locally using torch.load(). From here, you can easily access the saved items by simply querying the dictionary as you would expect.

Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results. If you wish to resume training, call model.train() to ensure these layers are in training mode.
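
As a concrete sketch, a training loop that writes such a checkpoint at the end of every epoch might look like this (the loop body, criterion, dataloader, num_epochs, and the file name checkpoint.tar are illustrative assumptions, not part of the tutorial):

for epoch in range(num_epochs):
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

    # One dictionary per checkpoint; overwrite or version the file as needed.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, 'checkpoint.tar')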

4. Saving Multiple Models in One File

Save:

torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)

Load:

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

When saving a model comprised of multiple torch.nn.Modules, such as a GAN, a sequence-to-sequence model, or an ensemble of models, you follow the same approach as when you are saving a general checkpoint. In other words, save a dictionary of each model’s state_dict and corresponding optimizer. As mentioned before, you can save any other items that may aid you in resuming training by simply appending them to the dictionary.

A common PyTorch convention is to save these checkpoints using the .tar file extension.

To load the models, first initialize the models and optimizers, then load the dictionary locally using torch.load(). From here, you can easily access the saved items by simply querying the dictionary as you would expect.

Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results. If you wish to resume training, call model.train() to set these layers to training mode.

5. Warmstarting Model Using Parameters from a Different Model

Save:

torch.save(modelA.state_dict(), PATH)

Load:

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

Partially loading a model or loading a partial model are common scenarios when transfer learning or training a new complex model. Leveraging trained parameters, even if only a few are usable, will help to warmstart the training process and hopefully help your model converge much faster than training from scratch.

Whether you are loading from a partial state_dict, which is missing some keys, or loading a state_dict with more keys than the model that you are loading into, you can set the strict argument to False in the load_state_dict() function to ignore non-matching keys.

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
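
For example, if the checkpoint was saved from a model whose layer names carry a features. prefix but the target model uses backbone. (a made-up mapping, purely for illustration), a sketch of the rename would be:

state_dict = torch.load(PATH)

# Build a new dict whose keys match the target model's layer names.
renamed = {key.replace('features.', 'backbone.'): value
           for key, value in state_dict.items()}

modelB.load_state_dict(renamed, strict=False)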

6. Saving & Loading Model Across Devices

6.1 Save on GPU, Load on CPU

Save:

torch.save(model.state_dict(), PATH)

Load:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

When loading a model on a CPU that was trained with a GPU, pass torch.device('cpu') to the map_location argument in the torch.load() function. In this case, the storages underlying the tensors are dynamically remapped to the CPU device using the map_location argument.

6.2 Save on GPU, Load on GPU

Save:

torch.save(model.state_dict(), PATH)

Load:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

When loading a model on a GPU that was trained and saved on GPU, simply convert the initialized model to a CUDA optimized model using model.to(torch.device('cuda')). Also, be sure to use the .to(torch.device('cuda')) function on all model inputs to prepare the data for the model. Note that calling my_tensor.to(device) returns a new copy of my_tensor on GPU. It does NOT overwrite my_tensor. Therefore, remember to manually overwrite tensors: my_tensor = my_tensor.to(torch.device('cuda')).

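For instance, a GPU inference loop might look like this sketch (dataloader is assumed to yield CPU tensors; everything here beyond the .to(device) calls is illustrative):

device = torch.device('cuda')
model.to(device)
model.eval()

with torch.no_grad():
    for inputs in dataloader:
        inputs = inputs.to(device)  # reassign: .to() returns a copy
        outputs = model(inputs)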

6.3 Save on CPU, Load on GPU

Save:

torch.save(model.state_dict(), PATH)

Load:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

When loading a model on a GPU that was trained and saved on CPU, set the map_location argument in the torch.load() function to cuda:device_id. This loads the model to a given GPU device. Next, be sure to call model.to(torch.device('cuda')) to convert the model’s parameter tensors to CUDA tensors. Finally, be sure to use the .to(torch.device('cuda')) function on all model inputs to prepare the data for the CUDA optimized model. Note that calling my_tensor.to(device) returns a new copy of my_tensor on GPU. It does NOT overwrite my_tensor. Therefore, remember to manually overwrite tensors: my_tensor = my_tensor.to(torch.device('cuda')).

7. Saving torch.nn.DataParallel Models

Save:

torch.save(model.module.state_dict(), PATH)

Load:

# Load to whatever device you want

torch.nn.DataParallel is a model wrapper that enables parallel GPU utilization. To save a DataParallel model generically, save the model.module.state_dict(). This way, you have the flexibility to load the model any way you want to any device you want.
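
For example (a sketch, not from the tutorial), the generically saved weights can be loaded into an unwrapped model on whatever device you like, and re-wrapped for multi-GPU use if desired:

# Rebuild the plain (unwrapped) model and load the saved weights.
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location='cpu'))

# Optionally re-wrap in DataParallel and move to GPU.
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
model.to(torch.device('cuda'))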
