Using the PyTorch Version of the Lookahead Optimizer (with Code)

The Lookahead optimization algorithm is another notable contribution from one of the authors of Adam; the paper is available at https://arxiv.org/abs/1907.08610

This post does not cover the inner workings of Lookahead; instead, it first shows how to integrate Lookahead into existing code.

I have used this optimizer in three projects (covering style transfer and object recognition). My strongest impression is how much it helps models converge: models that previously failed to converge, or converged too slowly, showed clearly better convergence after switching to Lookahead, and the final results met the original design requirements.

As is well known, Adam is widely used to optimize all kinds of models thanks to its good adaptivity; its few, easy-to-tune hyperparameters have long made it a favorite, and it is especially friendly to beginners. Lookahead inherits these strengths of Adam. The PyTorch implementation of Lookahead is shown below (a follow-up will walk through the principle behind the code); the code can also be found on GitHub.

from collections import defaultdict
from torch.optim import Optimizer
import torch


class Lookahead(Optimizer):
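    # Wraps an inner "fast" optimizer and keeps a copy of slow weights.
    # Every k fast steps, the slow weights are moved a fraction alpha toward
    # the fast weights, and the fast weights are reset to the updated slow
    # weights (as described in the Lookahead paper linked above).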
    def __init__(self, optimizer, k=5, alpha=0.5):
        self.optimizer = optimizer

        self.k = k
        self.alpha = alpha
        self.param_groups = self.optimizer.param_groups
        self.state = defaultdict(dict)
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            group["counter"] = 0

    def update(self, group):
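        # Move the slow weights a fraction alpha toward the fast weights
        # (creating the slow copy on first use), then write the result
        # back into the fast weights.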
        for fast in group["params"]:
            param_state = self.state[fast]
            if "slow_param" not in param_state:
                param_state["slow_param"] = torch.zeros_like(fast.data)
                param_state["slow_param"].copy_(fast.data)
            slow = param_state["slow_param"]
            slow += (fast.data - slow) * self.alpha
            fast.data.copy_(slow)

    def update_lookahead(self):
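        # Manually apply the slow-weight update to every parameter group.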
        for group in self.param_groups:
            self.update(group)

    def step(self, closure=None):
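        # Take one step with the inner (fast) optimizer; whenever a group's
        # counter is at 0 (i.e. once every k fast steps), perform the
        # slow-weight update for that group.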
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            if group["counter"] == 0:
                self.update(group)
            group["counter"] += 1
            if group["counter"] >= self.k:
                group["counter"] = 0
        return loss

    def state_dict(self):
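        # Pack the inner optimizer's state ("fast_state"), the slow-weight
        # state ("slow_state", keyed by parameter id), and the shared
        # param_groups into a single dict for checkpointing.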
        fast_state_dict = self.optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict["state"]
        param_groups = fast_state_dict["param_groups"]
        return {
            "fast_state": fast_state,
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict):
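        # Restore the slow-weight state through Optimizer.load_state_dict
        # and the fast state through the inner optimizer, then re-bind
        # fast_state to the inner optimizer's state.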
        slow_state_dict = {
            "state": state_dict["slow_state"],
            "param_groups": state_dict["param_groups"],
        }
        fast_state_dict = {
            "state": state_dict["fast_state"],
            "param_groups": state_dict["param_groups"],
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.optimizer.load_state_dict(fast_state_dict)
        self.fast_state = self.optimizer.state

    def add_param_group(self, param_group):
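        # Register a new parameter group with the inner optimizer and start
        # its lookahead counter at 0.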
        param_group["counter"] = 0
        self.optimizer.add_param_group(param_group)

To integrate Lookahead into existing code, all you need to do is the following:

from torch.optim import Adam

base_optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
opt = Lookahead(base_optimizer, k=5, alpha=0.5)

From this point on, just use opt as a normal optimizer, following exactly the same steps as when using Adam directly.
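For reference, here is a minimal training-loop sketch showing opt driven exactly like a plain Adam instance. The toy model, random batch, and loss function are placeholders I am using only for illustration; they are not part of the original code.

import torch
from torch import nn
from torch.optim import Adam

model = nn.Linear(10, 2)                     # toy model, for illustration only
loss_fn = nn.CrossEntropyLoss()

base_optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
opt = Lookahead(base_optimizer, k=5, alpha=0.5)

for _ in range(100):
    x = torch.randn(32, 10)                  # random stand-in for a real batch
    y = torch.randint(0, 2, (32,))
    # Note: the wrapper never calls Optimizer.__init__, so on newer PyTorch
    # releases you may need base_optimizer.zero_grad() here instead.
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()                               # fast Adam step + periodic slow-weight sync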

 
