import inspect
from functools import partial
from .misc import is_str
class Registry(object):
    """A registry to map strings to classes.

    Classes are stored under their ``__name__`` so they can later be looked
    up by that string (for example by ``build_from_cfg`` from a config
    dict's ``"type"`` entry).

    Args:
        name (str): Registry name.
    """

    def __init__(self, name):
        self._name = name
        # Maps class name (str) -> registered class.
        self._module_dict = dict()

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        # ``_register_module`` only ever stores classes (never None), so
        # plain dict membership matches ``self.get(key) is not None``.
        return key in self._module_dict

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property
    def name(self):
        """str: The registry name given at construction."""
        return self._name

    @property
    def module_dict(self):
        """dict: The underlying name-to-class mapping (not a copy)."""
        return self._module_dict

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class | None: The corresponding class, or None if ``key`` has
            not been registered.
        """
        return self._module_dict.get(key, None)

    def _register_module(self, module_class, force=False):
        """Store ``module_class`` under its ``__name__``.

        Raises:
            TypeError: If ``module_class`` is not a class.
            KeyError: If the name is already registered and ``force`` is
                False.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls=None, force=False):
        """Register a module.

        A record will be added to ``self._module_dict``, whose key is the
        class name and value is the class itself. It can be used as a
        decorator (with or without parentheses) or as a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module
            ... class ResNet(object):
            ...     pass

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(force=True)
            ... class ResNet(object):
            ...     pass

        Example:
            >>> backbones = Registry('backbone')
            >>> class ResNet(object):
            ...     pass
            >>> _ = backbones.register_module(ResNet)

        Args:
            cls (type, optional): The class to be registered. If None, a
                decorator that registers its argument with the given
                ``force`` flag is returned instead.
            force (bool, optional): Whether to override an existing class
                with the same name. Default: False.

        Returns:
            type | functools.partial: ``cls`` itself (so a decorated class
            is left unchanged), or a partial awaiting the class when ``cls``
            is None.

        Raises:
            TypeError: If ``cls`` is not a class.
            KeyError: If a class with the same name is already registered
                and ``force`` is False.
        """
        if cls is None:
            # Called as ``@registry.register_module(...)``: defer until the
            # decorated class is supplied.
            return partial(self.register_module, force=force)
        self._register_module(cls, force=force)
        return cls
def build_from_cfg(cfg, registry, default_args=None):
    """Instantiate an object described by a config dict.

    The ``"type"`` entry of ``cfg`` selects the class: a string is looked
    up in ``registry``, while a class object is used as-is. All remaining
    entries of ``cfg`` are passed as keyword arguments, and entries of
    ``default_args`` fill in any keyword that ``cfg`` does not supply.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.

    Raises:
        TypeError: If ``cfg``, ``registry`` or ``default_args`` has the
            wrong type, or if ``cfg["type"]`` is neither a string nor a
            class.
        KeyError: If ``cfg["type"]`` is a string not found in ``registry``.
    """
    cfg_is_valid = isinstance(cfg, dict) and 'type' in cfg
    if not cfg_is_valid:
        raise TypeError('cfg must be a dict containing the key "type"')
    if not isinstance(registry, Registry):
        raise TypeError(
            'registry must be an mmcv.Registry object, but got {}'.format(
                type(registry)))
    if default_args is not None and not isinstance(default_args, dict):
        raise TypeError(
            'default_args must be a dict or None, but got {}'.format(
                type(default_args)))

    # Work on a copy so the caller's config is not mutated by ``pop``.
    kwargs = cfg.copy()
    type_spec = kwargs.pop('type')

    if inspect.isclass(type_spec):
        obj_cls = type_spec
    elif is_str(type_spec):
        obj_cls = registry.get(type_spec)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                type_spec, registry.name))
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(type_spec)))

    # Explicit cfg entries win over defaults.
    for key, fallback in (default_args or {}).items():
        kwargs.setdefault(key, fallback)
    return obj_cls(**kwargs)