# Source code for mmcv.runner.hooks.logger.wandb

# Copyright (c) Open-MMLab. All rights reserved.
import numbers

from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class WandbLoggerHook(LoggerHook):
    """Logger hook that sends numeric log values to Weights & Biases (wandb).

    Args:
        init_kwargs (dict, optional): Keyword arguments forwarded to
            ``wandb.init``. When falsy, ``wandb.init()`` is called with no
            arguments. Default: None.
        interval (int): Forwarded unchanged to ``LoggerHook.__init__``.
            Default: 10.
        ignore_last (bool): Forwarded unchanged to ``LoggerHook.__init__``.
            Default: True.
        reset_flag (bool): Forwarded unchanged to ``LoggerHook.__init__``.
            Default: True.

    Raises:
        ImportError: At construction time, if ``wandb`` is not installed.
    """

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=True):
        super(WandbLoggerHook, self).__init__(interval, ignore_last,
                                              reset_flag)
        # Import eagerly so a missing dependency fails at hook construction
        # rather than at the start of training.
        self.import_wandb()
        self.init_kwargs = init_kwargs

    def import_wandb(self):
        """Import ``wandb`` and store the module on ``self.wandb``.

        Raises:
            ImportError: If the ``wandb`` package cannot be imported.
        """
        try:
            import wandb
        except ImportError as err:
            raise ImportError(
                'Please run "pip install wandb" to install wandb') from err
        self.wandb = wandb

    @master_only
    def before_run(self, runner):
        """Start a wandb run, passing ``init_kwargs`` when they are set."""
        # Defensive: re-import if ``self.wandb`` was cleared after __init__.
        if self.wandb is None:
            self.import_wandb()
        if self.init_kwargs:
            self.wandb.init(**self.init_kwargs)
        else:
            self.wandb.init()

    @staticmethod
    def _add_first_values(metrics, tag, values):
        """Store the first element of ``values`` in ``metrics`` under ``tag``.

        ``values`` may be a sequence, in which case ``values[0]`` is stored
        under ``tag``, or a dict of ``name -> sequence``, in which case each
        ``value[0]`` is stored under ``f'{tag}/{name}'``.
        NOTE(review): the dict form is assumed to be what the runner's
        ``current_lr()``/``current_momentum()`` return when several
        optimizers are used; confirm against the runner version in use.
        """
        if isinstance(values, dict):
            for name, value in values.items():
                metrics[f'{tag}/{name}'] = value[0]
        else:
            metrics[tag] = values[0]

    @master_only
    def log(self, runner):
        """Log numeric ``runner.log_buffer.output`` entries plus lr/momentum.

        Each value is tagged ``'<var>/<runner.mode>'`` and logged at step
        ``runner.iter``. The ``time`` and ``data_time`` entries and any
        non-numeric values are skipped.
        """
        metrics = {}
        for var, val in runner.log_buffer.output.items():
            # Timing statistics are not sent to wandb.
            if var in ['time', 'data_time']:
                continue
            # Suffix with the runner mode so that values of the same
            # variable logged in different modes get distinct tags.
            tag = f'{var}/{runner.mode}'
            if isinstance(val, numbers.Number):
                metrics[tag] = val
        self._add_first_values(metrics, 'learning_rate', runner.current_lr())
        self._add_first_values(metrics, 'momentum',
                               runner.current_momentum())
        if metrics:
            self.wandb.log(metrics, step=runner.iter)

    @master_only
    def after_run(self, runner):
        """Call ``wandb.join()`` to end the run started in ``before_run``."""
        self.wandb.join()