Pytorch模型訓練(3) - 模型保存與加載

《模型保存與加載》
  本系列來總結Pytorch訓練中的模型結構一些內容,包括模型的定義,模型參數化初始化方法,模型的保存與加載等

0 博客目錄

Pytorch模型訓練(0) - CPN源碼解析
Pytorch模型訓練(1) - 模型定義
Pytorch模型訓練(2) - 模型初始化
Pytorch模型訓練(3) - 模型保存與加載
Pytorch模型訓練(4) - Loss Function
Pytorch模型訓練(5) - Optimizer
Pytorch模型訓練(6) - 數據加載

1 保存和加載

1.1 Save源碼

  Save使用pickle工具將模型對象序列化爲pickle文件到disk

def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    """Saves an object to a disk file.  保存模型到disk
    See also: :ref:`recommend-saving-models`
    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string
           containing a file name    保存模型的文件對象或文件名
        pickle_module: module used for pickling metadata and objects     使用python的pickle格式序列化模型
        pickle_protocol: can be specified to override the default protocol   pickle協議
    .. warning::
        If you are using Python 2, torch.save does NOT support StringIO.StringIO
        as a valid file-like object. This is because the write method should return
        the number of bytes written; StringIO.write() does not do this.
        Please use something like io.BytesIO instead.
        python2不支持StringIO.StringIO作爲文件對象,因爲其StringIO.write()不能返回write方法需要的寫入字節個數
        但可用io.BytesIO
    Example:
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    調用底層_save方法,略微複雜,不繼續探討
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))

  使用這個save函數可以保存各種對象的模型、張量和字典;一般Pytorch保存模型後綴爲:.pt 或 .pth 或 .pkl

1.2 Load源碼

   Load使用pickle的unpickle工具將pickle的對象文件反序列化爲內存

def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
    """
    User extensions can register their own location tags and tagging and
    deserialization methods using `register_package`.
    Args:
    	文件對象或文件名
        f: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name    
      
        一個函數: 可以是torch.device,字符串,指定的重映射位置 
        可以用來指定加載模型到GPU或CPU等, 默認GPU       
        map_location: a function, torch.device, string or a dict specifying how to remap storage locations 
         
        pickle格式類型:這裏應該時反pickle序列化
        pickle_module: module used for unpickling metadata and objects (has to
            match the pickle_module used to serialize file)
         
        可選字段:比如 ``encoding=...``  在版本切換種,編碼衝突可用
        pickle_load_args: optional keyword arguments passed over to
            ``pickle_module.load`` and ``pickle_module.Unpickler``, e.g.,
            ``encoding=...``.
    .. note::
        When you call :meth:`torch.load()` on a file which contains GPU tensors, those tensors
        will be loaded to GPU by default. You can call `torch.load(.., map_location='cpu')`
        and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.
    .. note::
        In Python 3, when loading files saved by Python 2, you may encounter
        ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``. This is
        caused by the difference of handling in byte strings in Python2 and
        Python 3. You may use extra ``encoding`` keyword argument to specify how
        these objects should be loaded, e.g., ``encoding='latin1'`` decodes them
        to strings using ``latin1`` encoding, and ``encoding='bytes'`` keeps them
        as byte arrays which can be decoded later with ``byte_array.decode(...)``.
    Example:
    	#默認加載到GPU
        >>> torch.load('tensors.pt')
      
        # Load all tensors onto the CPU
        加載到CPU
        >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
        
        # Load all tensors onto the CPU, using a function
        用函數加載到CPU
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        
        # Load all tensors onto GPU 1
        加載到GPU1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
        
        # Map tensors from GPU 1 to GPU 0
        從GPU1映射到GPU0
        >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
        
        # Load tensor from io.BytesIO object
        從 io.BytesIO對象加載
        >>> with open('tensor.pt') as f:
                buffer = io.BytesIO(f.read())
        >>> torch.load(buffer)
    """
    new_fd = False
    if isinstance(f, str) or \
            (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
            (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
        new_fd = True
        f = open(f, 'rb')
    try:
        return _load(f, map_location, pickle_module, **pickle_load_args)
    finally:
        if new_fd:
            f.close()

2 一般形式

  從源碼不難看出pytorch保存模型的方式多樣,保存模型的後綴名也是多樣的,但要注意使用哪種保存,就要使用對應的加載方式
  一般我們常用到Pytorch加載和保存模型方式有以下幾種種:

2.1 保存整個網絡

torch.save(model, PATH) 

model=torch.load(PATH)

  這種方式重新加載的時候不需要自定義網絡結構,保存時已經把網絡結構保存了下來

2.2 保存網絡參數

  這種方式,速度快,佔空間少

torch.save(model.state_dict(),PATH)

model.load_state_dict(torch.load(PATH))

或者
torch.save(model.module.state_dict(), final_model_state_file)

model.module.load_state_dict(torch.load(final_model_state_file))

  僅保存和加載模型參數,這種方式重新加載的時候需要自己定義網絡model,並且其中的參數名稱與結構要與保存的模型中的一致(可以是部分網絡,比如只使用VGG的前幾層),相對靈活,便於對網絡進行修改

2.3 保存更多參數

  在實驗中往往需要保存更多的信息,比如優化器的參數,那麼可以採取下面的方法保存:

torch.save({
	'epoch': epochID + 1, 
	'state_dict': model.state_dict(), 
	'best_loss': lossMIN,
    'optimizer': optimizer.state_dict(),
    'alpha': loss.alpha, 
    'gamma': loss.gamma
    },checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')

  以上包含的信息有,epochID, state_dict, min loss, optimizer, 自定義損失函數的兩個參數;格式以字典的格式存儲。對應加載的方式:

def load_checkpoint(model, checkpoint_PATH, optimizer):
    if checkpoint != None:
        model_CKPT = torch.load(checkpoint_PATH)
        model.load_state_dict(model_CKPT['state_dict'])
        print('loading checkpoint!')
        optimizer.load_state_dict(model_CKPT['optimizer'])
    return model, optimizer

  但是,我們可能修改了一部分網絡,比如加了一些,刪除一些,等等,那麼需要過濾這些參數,加載方式:

def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
    if checkpoint != 'No':
        print("loading checkpoint...")
        model_dict = model.state_dict()
        modelCheckpoint = torch.load(checkpoint)
        pretrained_dict = modelCheckpoint['state_dict']
        # 過濾操作
        new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
        model_dict.update(new_dict)
        # 打印出來,更新了多少的參數
        print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
        model.load_state_dict(model_dict)
        print("loaded finished!")
        # 如果不需要更新優化器那麼設置爲false
        if loadOptimizer == True:
            optimizer.load_state_dict(modelCheckpoint['optimizer'])
            print('loaded! optimizer')
        else:
            print('not loaded optimizer')
    else:
        print('No checkpoint is included')
    return model, optimizer

3 CPN

3.1 CPN模型保存–train

 save_model({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer' : optimizer.state_dict(),
    }, checkpoint=args.checkpoint)

  保存了一些必要訓練參數和模型參數

3.2 CPN模型加載–test

 checkpoint_file = os.path.join(args.checkpoint, args.test+'.pth.tar')
 checkpoint = torch.load(checkpoint_file)
 model.load_state_dict(checkpoint['state_dict'])
 print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_file, checkpoint['epoch']))

  測試模型時,我們只關注模型參數

3.3 CPN模型加載–resume

    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            pretrained_dict = checkpoint['state_dict']
            model.load_state_dict(pretrained_dict)
            args.start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:        
        logger = Logger(join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'LR', 'Train Loss'])

  resume是指接着某一次保存的模型繼續訓練,因爲我們在訓練中,可能中斷或需要調調參數,就可以用這種方式;一般來說,它需要保存模型時保存當時的訓練現場,就像caffe訓練時保存的solverstate文件

3.4 CPN模型加載–finetuning

def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        print('Initialize with pre-trained ResNet')
        from collections import OrderedDict
        state_dict = model.state_dict()
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
        for k, v in pretrained_state_dict.items():
            if k not in state_dict:
                continue
            state_dict[k] = v
        print('successfully load '+str(len(state_dict.keys()))+' keys')
        model.load_state_dict(state_dict)
    return model

  finetuning與resume之間還是有點區別的;我們常常說的finetuning(遷移學習)本質就是加載預訓練,繼續訓練;當然加載時,可能會根據需求選擇參數,也可能會適當凍結部分參數等

4 細節補充

   1)model.state_dict
  pytorch 中的 state_dict 是一個簡單的python的字典對象;在模型中,它將每一層與它的對應參數建立映射關係,如model的每一層的weights及偏置等等
  注意:只有那些參數可以訓練的layer纔會被保存到模型的state_dict中,如卷積層,線性層等等
  優化器對象Optimizer也有一個state_dict,它包含了優化器的狀態以及被使用的超參數,如lr, momentum,weight_decay等

   2)OrderedDict
  collections模塊中的有序字典;模型中,大部分字典對象都是用它,如Sequential:

# Example of using Sequential
model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )
# Example of using Sequential with OrderedDict
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

  在Python中,dict這個數據結構由於hash的特性,是無序的,這在有的時候會給我們帶來一些麻煩, 幸運的是,collections模塊爲我們提供了OrderedDict,當你要獲得一個有序的字典對象時,用它就對了

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章