簡單理解mmdetection中的registry(註冊類)類【轉載整理】

簡單理解mmdetection中的registry(註冊類)類【轉載整理】

原文地址:簡單理解mmdetection中的registry類

         在mmdetection中使用registry類構建九個註冊類實例,其實就是對類做一個劃分管理,其中的每一個實例都用於存放屬於這一簇的類,將來通過get_key方式獲取,key 來自於config文件。

          比如,backbone 作爲一簇其中包括vgg、resnet等。九個註冊類實例如下所示:

# 【1】文件:mmdet\models\registry.py #
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')


# 【2】文件:mmdet\datasets\registry.py #
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')

          Registry 類的代碼如下所示:

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import inspect

class Registry(object):
    # 【1】
    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    # 【2】
    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    # 【3】
    @property
    def name(self):
        return self._name

    # 【4】
    @property
    def module_dict(self):
        return self._module_dict

    # 【5】
    def get(self, key):
        return self._module_dict.get(key, None)

    # 【6】
    def _register_module(self, module_class):
        """
        Register a module.
        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    # 【7】
    def register_module(self, cls):
        self._register_module(cls)
        return cls

          mmdetection在構建模型的過程中一直通過key在註冊類中查找對應的類,找到對應的類後將其實例化,最終將配置描述的模型構建出來,查找類的過程如以下代碼所示:

key = 'vgg'
VGG = BACKBONES.get(key)

key = 'bce'
BCE = LOSSES .get(key)

          在mmdetection的代碼中,將一個類註冊到一個註冊器中,可以直接將裝飾器的語法糖寫在類的聲明上方,關於裝飾器的內容參考上一篇博客。

ANIMAL = Registry('animal')

@ANIMAL.register_module
class Dog(object):
    def __init__(self):
        pass

    def run(self):
        print('running dog')

dog = ANIMAL.get('Dog')

d = dog()
d.run()

          等價寫法爲:

ANIMAL = Registry('animal')

class Dog(object):
    def __init__(self):
        pass

    def run(self):
        print('running dog')

ANIMAL.register_module(Dog)
dog = ANIMAL.get('Dog')

d = dog()
d.run()

          兩者輸出結果皆爲:

running dog

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章