Unexpected key(s) in state_dict: “dense_block1.denselayer1.norm.1

 

Unexpected key(s) in state_dict: "dense_block1.denselayer1.norm.1

 

from torchvision.models import densenet121
from collections import OrderedDict

model = densenet121(pretrained=False)

state_dict =torch.load(model_weight_path)
# 初始化一個空 dict
new_state_dict = OrderedDict()
# 修改 key
for k, v in state_dict.items():
    if 'denseblock' in k:
        param = k.split(".")
        k = ".".join(param[:-3] + [param[-3]+param[-2]] + [param[-1]])
    new_state_dict[k] = v
    model.load_state_dict(new_state_dict)

我的解決方法:

# 初始化一個空 dict
new_state_dict = OrderedDict()
# 修改 key
for k, v in state_dict.items():
    k=k.replace('module.', '')
    if 'dense_block' in k:
        if "norm" in k or "conv.1" in k or "conv.2" in k:
            param = k.split(".")
            k = ".".join(param[:-3] + [param[-3]+param[-2]] + [param[-1]])
        new_state_dict[k] = v
    else:
        new_state_dict[k] = v

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章