pytorch中網絡參數初始化

可以先定義兩個函數:

import torch.nn.init as init

def xavier(param):
    init.xavier_uniform(param)
    # init.kaiming_uniform_()    # 可以選擇其他的


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        xavier(m.weight.data)
        m.bias.data.zero_()

初始化的時候可以直接調用這兩個函數:

net.loc.apply(weights_init)
net.conf.apply(weights_init)

或者:

net.vgg[3:].apply(weights_init)

也可用其他的方法,看自己習慣

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章