Reading the torch.optim Optimizer source code and using it flexibly

The optimizer is the tool PyTorch uses to update model parameters. PyTorch first defines a base class, Optimizer, that implements the functionality shared by all optimizers, and then implements each optimization algorithm (SGD, Adam, and so on) as a subclass with its own update procedure.

class Optimizer(object):
    r"""Base class for all optimizers.
    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.
    Arguments:
        params (iterable): an iterable of :class:`torch.Tensor` s or :class:`dict` s. Specifies what Tensors should be optimized.
        defaults: (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        torch._C._log_api_usage_once("python.optimizer")
        
        # defaults holds options such as lr and momentum that apply globally to the optimized variables; subclasses pass it in as a dict
        self.defaults = defaults

        # params must be an iterable of Tensors or dicts
        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))
		
        # state is a defaultdict whose default value type is dict; it holds the optimizer's current state
        self.state = defaultdict(dict)
        # self.param_groups holds every parameter to be optimized; each entry is a dict describing one group of parameters and its associated options
        self.param_groups = []

        param_groups = list(params)  # all the variables the optimizer will update; this must not be empty
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        # wrap the variables to be optimized in a dict, stored as a single list entry
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        # add every entry of param_groups to self.param_groups
        for param_group in param_groups:
            self.add_param_group(param_group)
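
As a quick illustration (not from the source) of the two forms that params may take and what they become in self.param_groups:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Form 1: an iterable of Tensors -> wrapped into a single group.
opt1 = torch.optim.SGD(model.parameters(), lr=0.1)
print(len(opt1.param_groups))       # 1
print(opt1.param_groups[0].keys())  # dict_keys(['params', 'lr', 'momentum', ...])

# Form 2: a list of dicts -> one group per dict, each with its own options.
opt2 = torch.optim.SGD(
    [{'params': model.weight}, {'params': model.bias, 'lr': 0.01}], lr=0.1)
print(len(opt2.param_groups))       # 2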

In the constructor, every parameter to be optimized is wrapped in a dict, collected into a list, and then added to self.param_groups. The point of this layout is that during fine-tuning the relevant data can be accessed conveniently by key. Let's now look at self.add_param_group():

def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.
        This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the :class:`Optimizer` as training progresses.
        Arguments:
            param_group (dict): Specifies what Tensors should be optimized along with group specific optimization options.
        """
        assert isinstance(param_group, dict), "param group must be a dict"

        params = param_group['params']
        # wrap the parameters in a list; afterwards the value of 'params' is a list of all the variables to be optimized
        if isinstance(params, torch.Tensor):
            param_group['params'] = [params] 
        elif isinstance(params, set):
            raise TypeError('optimizer parameters need to be organized in ordered collections, but the ordering of tensors in sets will change between runs. Please use a list instead.')
        else:
            param_group['params'] = list(params)

        # every variable to be optimized must be a torch.Tensor and must be a leaf node (a tensor created directly by the user)
        for param in param_group['params']:
            if not isinstance(param, torch.Tensor):
                raise TypeError("optimizer can only optimize Tensors, but one of the params is " + torch.typename(param))
            if not param.is_leaf:
                raise ValueError("can't optimize a non-leaf Tensor")

        # fill in every option from defaults that the group does not already specify
        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError("parameter group didn't specify a value of required optimization parameter " + name)
            else:
                param_group.setdefault(name, default)
		
        # use a set to check whether any tensor in 'params' already exists in self.param_groups
        param_set = set()
        for group in self.param_groups:
            param_set.update(set(group['params']))

        # isdisjoint returns True only if the two sets share no elements
        if not param_set.isdisjoint(set(param_group['params'])):
            raise ValueError("some parameters appear in more than one parameter group")

        # append the fully populated group to self.param_groups
        self.param_groups.append(param_group)
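
As the docstring notes, add_param_group is typically used to unfreeze layers while fine-tuning. A minimal sketch (module names are illustrative):

import torch
import torch.nn as nn

backbone = nn.Linear(8, 8)
head = nn.Linear(8, 2)
for p in backbone.parameters():
    p.requires_grad = False  # frozen at first

# Start by optimizing only the head.
optimizer = torch.optim.SGD(head.parameters(), lr=0.1, momentum=0.9)

# Later in training: unfreeze the backbone and add it as a new group with its own lr.
for p in backbone.parameters():
    p.requires_grad = True
optimizer.add_param_group({'params': backbone.parameters(), 'lr': 0.01})
print(len(optimizer.param_groups))  # 2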

Next, let's look at the two functions for saving and loading the optimizer's state:

Getting the optimizer's state: state_dict

def state_dict(self):
        r"""Returns the state of the optimizer as a :class:`dict`.
        It contains two entries:
        * state - a dict holding current optimization state. Its content
            differs between optimizer classes.
        * param_groups - a dict containing all parameter groups
        """
        # Save ids instead of Tensors
        def pack_group(group):
            packed = {k: v for k, v in group.items() if k != 'params'}
            packed['params'] = [id(p) for p in group['params']]
            return packed
        # Each dict in self.param_groups is re-packed into a new dict and collected into a list, so param_groups has exactly the same layout as self.param_groups, except that 'params' no longer holds the Tensors themselves but their ids (object identities).
        param_groups = [pack_group(g) for g in self.param_groups]
        # replace every Tensor key in self.state with that tensor's id
        packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                        for k, v in self.state.items()}
        # return everything as a single dict
        return {
            'state': packed_state,
            'param_groups': param_groups,
        }

Note that the parameters are returned as object ids (memory addresses in CPython), not as the tensor values themselves.

Loading previously saved state: load_state_dict

def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.
        Arguments:
            state_dict (dict): optimizer state. Should be an object returned from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        
        # check that the current optimizer's parameter groups match the data being loaded
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        # check that each 'params' list has the same number of variables
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group that doesn't match the size of optimizer's group")

        # build a dict mapping the saved (old) ids to the corresponding current parameter tensors
        id_map = {old_id: p for old_id, p in
                  zip(chain(*(g['params'] for g in saved_groups)),
                      chain(*(g['params'] for g in groups)))}
        
        # cast values to the dtype/device of the corresponding parameter
        def cast(param, value):
            r"""Make a deep copy of value, casting all tensors to device of param."""
            if isinstance(value, torch.Tensor):
                # Floating-point types are a bit special here. They are the only ones
                # that are assumed to always match the type of params.
                if param.is_floating_point():
                    value = value.to(param.dtype)
                value = value.to(param.device)
                return value
            elif isinstance(value, dict):
                return {k: cast(param, v) for k, v in value.items()}
            elif isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict['state'].items():
            if k in id_map:  # k is an id that was saved by state_dict
                param = id_map[k]  # the corresponding current parameter tensor
                state[param] = cast(param, v)
            else:
                state[k] = v

        # rebuild param_groups: keep the current 'params', take the saved options
        def update_group(group, new_group):
            new_group['params'] = group['params']
            return new_group
        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({'state': state, 'param_groups': param_groups})
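
Together, state_dict and load_state_dict support the usual checkpoint/resume pattern. A minimal sketch (the file name is illustrative):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# ... train for a while, then save both the model and the optimizer state ...
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict()}, 'checkpoint.pth')

# To resume: rebuild the model and optimizer the same way, then load both states.
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])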

Other functions:

Clearing gradients: zero_grad

# Clear the gradients of all optimized parameters. Since gradients in PyTorch accumulate by default, they must be cleared before each backward pass so that every update uses only the gradients of the current step.
def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
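
A quick demonstration of why this is needed: gradients accumulate across backward() calls until they are cleared (illustrative sketch):

import torch

w = torch.ones(2, requires_grad=True)
w.sum().backward()
print(w.grad)    # tensor([1., 1.])
w.sum().backward()
print(w.grad)    # tensor([2., 2.]) -- accumulated, not overwritten
w.grad.zero_()   # what optimizer.zero_grad() does for every optimized parameter
print(w.grad)    # tensor([0., 0.])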

Performing a single update: step

# implemented by each subclass (e.g. SGD, Adam)
def step(self, closure):
    r"""Performs a single optimization step (parameter update).
        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
        """
    raise NotImplementedError

Let's now use SGD as a concrete example to explain how an optimizer works.

Typical usage of SGD looks like this (code from pytorch/examples/imagenet/main.py):

Defining the optimizer:

optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum,
                            weight_decay=args.weight_decay)

During training:

optimizer.zero_grad()  # clear the gradients accumulated in the previous iteration
output = model(image)  # forward pass
loss = criterion(output, target)  # compute the loss
loss.backward()   # compute the gradients for this iteration
optimizer.step()  # update the parameters

Now let's look at the source of optim.SGD:

class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).
    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        
        # validate lr / momentum / weight_decay
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
		
        # pack everything except the parameters themselves into a dict; this initializes defaults in the base class
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

As this code shows, SGD's constructor mainly validates and packs the hyperparameters; the main initialization work is done by the Optimizer base class.
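
One consequence of lr=required in the version quoted here is that the learning rate must be supplied either globally or inside every parameter group; otherwise add_param_group raises the "required optimization parameter" error. A minimal sketch:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# OK: lr supplied globally.
torch.optim.SGD(model.parameters(), lr=0.1)

# Also OK: lr supplied per group, no global lr needed.
torch.optim.SGD([{'params': model.parameters(), 'lr': 0.1}])

# Error with the version quoted above -- no lr anywhere:
# torch.optim.SGD(model.parameters())
# ValueError: parameter group didn't specify a value of required optimization parameter lr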

The other function that SGD implements itself is step.

step() updates the parameters using the current gradients. The usual SGD update rules are:

Vanilla SGD: $w_{t+1} = w_t - \text{lr} \cdot \nabla_w J(w_t)$. With weight decay $\lambda$ (L2 penalty), the gradient is replaced by $\nabla_w J(w_t) + \lambda w_t$.

Momentum: $v_{t+1} = m \cdot v_t + \text{lr} \cdot \nabla_w J(w_t)$, $w_{t+1} = w_t - v_{t+1}$

Nesterov momentum: $v_{t+1} = m \cdot v_t + \text{lr} \cdot \nabla_w J(w_t - m \cdot v_t)$, $w_{t+1} = w_t - v_{t+1}$

These are the formulations most other frameworks implement. PyTorch's implementation differs slightly in where the learning rate is applied: it is kept out of the momentum buffer and only applied at the parameter update, i.e.

$v_{t+1} = m \cdot v_t + (1 - \text{dampening}) \cdot \nabla_w J(w_t)$

$w_{t+1} = w_t - \text{lr} \cdot v_{t+1}$

and, when Nesterov momentum is enabled, the update direction becomes $\nabla_w J(w_t) + m \cdot v_{t+1}$ instead of $v_{t+1}$ itself.
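
The following sketch (illustrative, using the default dampening=0 and no weight decay) performs two momentum steps by hand with this formulation and checks the result against torch.optim.SGD:

import torch

lr, m = 0.1, 0.9
w = torch.tensor([1.0, -2.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=lr, momentum=m)

def loss_fn(x):
    return (x ** 2).sum()

# Step 1: the buffer is initialized with the raw gradient, then w <- w - lr * buf.
loss_fn(w).backward()
g1 = w.grad.clone()
opt.step()

# Step 2: buf <- m * buf + grad (no lr inside), then w <- w - lr * buf.
opt.zero_grad()
loss_fn(w).backward()
g2 = w.grad.clone()
w_before = w.detach().clone()
opt.step()

buf = m * g1 + g2                                         # hand-computed velocity
print(torch.allclose(w.detach(), w_before - lr * buf))    # True

The source of SGD.step() that implements this behaviour is: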

@torch.no_grad()
def step(self, closure=None):
    """Performs a single optimization step.
    Arguments:
        closure (callable, optional): A closure that reevaluates the model and returns the loss.
    """
    # re-evaluate the loss via the closure, if one is given
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    # update each parameter using its already-computed gradient
    for group in self.param_groups:
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']

        for p in group['params']:
            if p.grad is None:
                continue
            d_p = p.grad
            if weight_decay != 0:
                d_p = d_p.add(p, alpha=weight_decay)  # weight decay (L2 regularization): d_p += weight_decay * p
            if momentum != 0:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    # initialize the momentum buffer (the accumulated update v_t) with the current gradient
                    buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                else:
                    buf = param_state['momentum_buffer']
                    # v_{t+1} = m * v_t + (1 - dampening) * \nabla_w J(w)
                    # dampening scales down the newly added gradient; it must be 0 when nesterov is enabled
                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    # since buf is modified in place, this also updates self.state
                if nesterov:
                    # d_p = \nabla_w J(w) + m * v_{t+1}: PyTorch's variant of the Nesterov update
                    d_p = d_p.add(buf, alpha=momentum)
                else:
                    d_p = buf
            # w_{t+1} = w_t - lr * v_{t+1}
            p.add_(d_p, alpha=-group['lr'])
            # note that the in-place add_ is used here, so the update is reflected directly in the model's parameters
            # each pass of the outer loop uses one group's 'params' together with that group's own options
    return loss

As the code above shows, on every pass of the outer loop the optimizer looks up the value stored under the 'params' key of one group and updates those tensors using that group's own options. This makes it possible to optimize different parts of the model with different settings, for example applying weight decay to only some of the parameters (code from MetaPruning/mobilenetv2/evaluating/evaluate.py):

# split the weight parameter that need weight decay
all_parameters = model.parameters()
weight_parameters = []
for pname, p in model.named_parameters():
    if 'fc' in pname or 'conv1' in pname or 'pwconv' in pname:
        weight_parameters.append(p)
weight_parameters_id = list(map(id, weight_parameters))
other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters))

# define the optimizer
optimizer = torch.optim.SGD(
    [{'params' : other_parameters},
     {'params' : weight_parameters, 'weight_decay' : args.weight_decay}],
    args.learning_rate,
    momentum=args.momentum,
)

In the code above, only the variables in the 'fc', 'conv1' and 'pwconv' layers are regularized and the remaining variables are not, so the two sets of variables have to be separated. When adding them to the optimizer, note that:

  1. Both sets of variables must be contained in a single list; each item of the list is one group of variables to be optimized, and each item must be a dict whose key is 'params';
  2. If a group has an option of its own, such as weight_decay in the code above, that option must live in the same dict as the group's variables, and its key must match the name PyTorch defines for it;
  3. Options passed outside the list, such as learning_rate and momentum here, apply to all of the optimized variables.
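
A quick way to sanity-check such a setup is to inspect optimizer.param_groups after construction. A self-contained sketch (splitting weights and biases instead of the layers above, purely for illustration):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 2))
decay, no_decay = [], []
for name, p in model.named_parameters():
    (no_decay if name.endswith('bias') else decay).append(p)

optimizer = torch.optim.SGD(
    [{'params': no_decay},
     {'params': decay, 'weight_decay': 1e-4}],
    lr=0.1, momentum=0.9)

# Options not set inside a group fall back to the global defaults.
for i, group in enumerate(optimizer.param_groups):
    print(i, len(group['params']), group['lr'], group['momentum'], group['weight_decay'])
# 0 2 0.1 0.9 0
# 1 2 0.1 0.9 0.0001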

Let's go through one more example to get a clearer picture of state_dict:

import torch
import torch.nn as nn

class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1, 1, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
        )
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(-1, 4)
        x = self.fc(x)
        return x

model = net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD([{'params':model.conv.parameters(), 'lr':0.01},
                            {'params':model.fc.parameters(), 'lr':0.02}],
                            momentum=0.9,
                            weight_decay=1e-5)

x = torch.rand(1, 1, 2, 2)
y = torch.tensor([1])

optimizer.zero_grad()
output = model(x)
loss = criterion(output, y)
loss.backward()
optimizer.step()

print(optimizer.state_dict())

The output is:

{
    'state': {
        2387458897384: {'momentum_buffer': tensor([[[[-0.0003]]]])}, # shape:[1, 1, 1, 1]
        2387458897456: {'momentum_buffer': tensor([-0.1267])}, # shape:[1]
        2387458897528: {'momentum_buffer': tensor([-0.2534])}, # shape:[1]
        2387458897816: {'momentum_buffer':tensor([[2.7221e-01, 5.9235e-01, 1.5873e-06, 5.2255e-07],
                                                  [-2.7221e-01, -5.9235e-01, -4.2629e-06, -4.6668e-06]])}, # shape:[2, 4]
        2387458897888: {'momentum_buffer': tensor([ 0.5391, -0.5391])}}, # shape:[2]
    'param_groups': [
        {
            'lr': 0.01, 
            'momentum': 0.9, 
            'dampening': 0, 
            'weight_decay': 1e-05, 
            'nesterov': False, 
            'params': [2387458897384, 2387458897456,2387458897528]
        }, 
        {
            'lr': 0.02, 
            'momentum': 0.9, 
            'dampening': 0, 
            'weight_decay': 1e-05, 
            'nesterov': False, 
            'params': [2387458897816, 2387458897888]
        }
    ]
}

From this output we can see that:

  1. state_dict is a dict with two entries, 'state' and 'param_groups'.
  2. 'state' is a dict holding the most recent cached quantities (here the momentum buffers) computed while updating the variables. Its keys are the ids of the corresponding parameter tensors, and each value is itself a dict mapping a buffer name to a tensor.
  3. 'param_groups' is a list in which every item is a dict describing one group of optimized variables and its update options. Within each item the keys are option names and the values their settings; the variables themselves are stored, as a list of ids, under the key 'params'.
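
Because the keys in 'state' and the entries of 'params' are id() values of the live parameter tensors, they can be mapped back to parameter names while the process is still running (with the id-based version quoted here). Continuing the example above:

id_to_name = {id(p): name for name, p in model.named_parameters()}
for key, buffers in optimizer.state_dict()['state'].items():
    print(id_to_name[key], list(buffers.keys()))
# conv.0.weight ['momentum_buffer']
# conv.1.weight ['momentum_buffer']
# conv.1.bias ['momentum_buffer']
# fc.weight ['momentum_buffer']
# fc.bias ['momentum_buffer']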