可以先定義兩個函數:
import torch.nn.init as init
def xavier(param):
init.xavier_uniform(param)
# init.kaiming_uniform_() # 可以選擇其他的
def weights_init(m):
if isinstance(m, nn.Conv2d):
xavier(m.weight.data)
m.bias.data.zero_()
初始化的時候可以直接調用這兩個函數:
net.loc.apply(weights_init)
net.conf.apply(weights_init)
或者:
net.vgg[3:].apply(weights_init)
也可用其他的方法,看自己習慣