# Source code for mmcv.utils.registry

import inspect
from functools import partial

from .misc import is_str


class Registry(object):
    """A registry to map strings to classes.

    Registered classes are stored in an internal dict keyed by the class's
    ``__name__``, so they can later be looked up by that string.

    Args:
        name (str): Registry name.
    """

    def __init__(self, name):
        self._name = name
        # Maps class name (str) -> class object.
        self._module_dict = dict()

    def __len__(self):
        """Return the number of registered classes."""
        return len(self._module_dict)

    def __contains__(self, key):
        """Return True if ``key`` names a registered class."""
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property
    def name(self):
        """str: The name given to this registry at construction."""
        return self._name

    @property
    def module_dict(self):
        """dict: The underlying name -> class mapping (not a copy)."""
        return self._module_dict

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class, or None if ``key`` is not
                registered.
        """
        return self._module_dict.get(key, None)

    def _register_module(self, module_class, force=False):
        """Add ``module_class`` to the registry under its ``__name__``.

        Args:
            module_class (type): The class to register.
            force (bool): If True, silently replace an existing entry with
                the same name instead of raising.

        Raises:
            TypeError: If ``module_class`` is not a class.
            KeyError: If the name is already registered and ``force`` is
                False.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls=None, force=False):
        """Register a class.

        A record is added to ``self._module_dict`` whose key is the class
        name and whose value is the class itself. The method can be used
        as a plain decorator, as a decorator factory (to pass ``force``), or
        as a normal function call.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module
            ... class ResNet(object):
            ...     pass

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(force=True)
            ... class ResNet(object):
            ...     pass

        Example:
            >>> backbones = Registry('backbone')
            >>> class ResNet(object):
            ...     pass
            >>> backbones.register_module(ResNet)

        Args:
            cls (type, optional): Class to be registered. When None, a
                decorator that registers its argument with the given
                ``force`` is returned instead.
            force (bool, optional): Whether to override an existing class
                with the same name. Default: False.

        Returns:
            The registered class ``cls`` (so it can be used as a decorator),
            or, when ``cls`` is None, a decorator awaiting the class.

        Raises:
            TypeError: If ``cls`` is given but is not a class.
            KeyError: If a class with the same name is already registered
                and ``force`` is False.
        """
        if cls is None:
            # Called as ``@registry.register_module(...)``: defer until the
            # decorated class is supplied.
            return partial(self.register_module, force=force)
        self._register_module(cls, force=force)
        return cls
def build_from_cfg(cfg, registry, default_args=None):
    """Instantiate an object described by a config dict.

    The ``type`` entry of ``cfg`` selects the class to build. It is either a
    string resolved through ``registry`` or a class object used directly.
    Every other entry of ``cfg`` is passed to the constructor as a keyword
    argument.

    Args:
        cfg (dict): Config dict. It must be a dict containing the key "type".
        registry (:obj:`Registry`): Registry used to resolve a string "type".
        default_args (dict, optional): Fallback keyword arguments. A key from
            here is only used when ``cfg`` does not already provide it.

    Returns:
        obj: The constructed object.

    Raises:
        TypeError: If ``cfg`` is not a dict with a "type" key, ``registry``
            is not a :obj:`Registry`, ``default_args`` is neither a dict nor
            None, or "type" is neither a string nor a class.
        KeyError: If a string "type" is not registered in ``registry``.
    """
    cfg_is_valid = isinstance(cfg, dict) and 'type' in cfg
    if not cfg_is_valid:
        raise TypeError('cfg must be a dict containing the key "type"')
    if not isinstance(registry, Registry):
        raise TypeError(
            'registry must be an mmcv.Registry object, but got {}'.format(
                type(registry)))
    if default_args is not None and not isinstance(default_args, dict):
        raise TypeError(
            'default_args must be a dict or None, but got {}'.format(
                type(default_args)))

    # Work on a copy so popping "type" does not mutate the caller's cfg.
    kwargs = cfg.copy()
    type_spec = kwargs.pop('type')

    if is_str(type_spec):
        target_cls = registry.get(type_spec)
        if target_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                type_spec, registry.name))
    elif inspect.isclass(type_spec):
        target_cls = type_spec
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(type_spec)))

    # setdefault: values explicitly present in cfg take precedence.
    for key, value in (default_args or {}).items():
        kwargs.setdefault(key, value)

    return target_cls(**kwargs)