# Source code for mmcv.utils.registry

import inspect
import warnings
from functools import partial

from .misc import is_str


class Registry:
    """A registry to map strings to classes.

    Classes are stored under one or more string names and can be looked up
    again with :meth:`get`; ``build_from_cfg`` uses that lookup to resolve a
    config's ``"type"`` string to a class.

    Args:
        name (str): Registry name.
    """

    def __init__(self, name):
        self._name = name
        # Maps each registered name (str) to its class.
        self._module_dict = dict()

    def __len__(self):
        """Return the number of registered names (not distinct classes)."""
        return len(self._module_dict)

    def __contains__(self, key):
        """Return True if ``key`` is registered (via :meth:`get`)."""
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
            f'(name={self._name}, ' \
            f'items={self._module_dict})'
        return format_str

    @property
    def name(self):
        """str: The registry name given at construction."""
        return self._name

    @property
    def module_dict(self):
        """dict: The underlying name -> class mapping (not a copy)."""
        return self._module_dict

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class, or None if ``key`` is not
            registered.
        """
        return self._module_dict.get(key, None)

    @staticmethod
    def _is_valid_name(name):
        """Return True if ``name`` is None, a str, or a list/tuple of str."""
        if name is None or isinstance(name, str):
            return True
        return isinstance(name, (list, tuple)) and all(
            isinstance(item, str) for item in name)

    def _register_module(self, module_class, module_name=None, force=False):
        """Insert ``module_class`` under one or more names.

        Args:
            module_class (type): The class to register.
            module_name (str | list[str] | tuple[str] | None): Name(s) to
                register under. Defaults to ``module_class.__name__``.
            force (bool): Overwrite existing entries with the same name(s).

        Raises:
            TypeError: If ``module_class`` is not a class, or
                ``module_name`` is not a str / sequence of str.
            KeyError: If ``force`` is False and any of the names is already
                registered. In that case nothing is registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')
        if module_name is None:
            module_name = module_class.__name__
        if not self._is_valid_name(module_name):
            raise TypeError(
                'module_name must be a str or a list/tuple of str, '
                f'but got {type(module_name)}')
        names = [module_name] if isinstance(module_name, str) else list(
            module_name)
        # Check every name before inserting any, so a conflict never leaves
        # the registry partially updated.
        if not force:
            for key in names:
                if key in self._module_dict:
                    raise KeyError(f'{key} is already registered '
                                   f'in {self.name}')
        for key in names:
            self._module_dict[key] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        """Old-style ``register_module(module, force=False)`` API.

        Emits a warning, then registers ``cls`` under its class name. When
        ``cls`` is None, returns a decorator bound to ``force``.
        """
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name(s), and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name=['mnet', 'mobilenet'])
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(module=ResNet)

        Args:
            name (str | list[str] | tuple[str] | None): The module name(s) to
                be registered. If not specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.

        Returns:
            The registered class when ``module`` is given (or when the
            deprecated positional-class form is used); otherwise a decorator
            that registers and returns the decorated class.

        Raises:
            TypeError: If ``force`` is not a bool or ``name`` has an invalid
                type.
            KeyError: If a name is already registered and ``force`` is False.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # Validate the name before either path runs, so the direct-call form
        # (module=...) rejects bad names just like the decorator form.
        if not self._is_valid_name(name):
            raise TypeError(
                'name must be None, a str or a list/tuple of str, '
                f'but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should contain the key "type", unless
            ``default_args`` provides it.
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
            Keys already present in ``cfg`` take precedence; a "type" entry
            here is only used as a fallback when ``cfg`` has none.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: If ``cfg`` is not a dict, ``registry`` is not a
            :obj:`Registry`, ``default_args`` is neither a dict nor None, or
            the resolved "type" is neither a str nor a class.
        KeyError: If neither ``cfg`` nor ``default_args`` contains "type", or
            a str "type" is not registered in ``registry``.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if not isinstance(default_args, dict) or 'type' not in default_args:
            raise KeyError(
                'the cfg dict or default_args must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()
    # Merge defaults *before* popping "type". Otherwise a "type" key in
    # default_args would be re-inserted by setdefault and forwarded to the
    # constructor as an unexpected ``type=`` keyword argument.
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    return obj_cls(**args)