Source code for mmcv.runner.epoch_based_runner

# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
import warnings

import torch

import mmcv
from .base_runner import BaseRunner
from .checkpoint import save_checkpoint
from .utils import get_host_info


class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner trains models epoch by epoch.
    """

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            if self.batch_processor is None:
                outputs = self.model.train_step(data_batch, self.optimizer,
                                                **kwargs)
            else:
                outputs = self.batch_processor(
                    self.model, data_batch, train_mode=True, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.train_step()"'
                                ' must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            with torch.no_grad():
                if self.batch_processor is None:
                    outputs = self.model.val_step(data_batch, self.optimizer,
                                                  **kwargs)
                else:
                    outputs = self.batch_processor(
                        self.model, data_batch, train_mode=False, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.val_step()"'
                                ' must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')
    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g., [('train', 2), ('val', 1)]
                means running 2 epochs for training and 1 epoch for
                validation, iteratively.
            max_epochs (int): Total training epochs.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        # Derive the total iteration count from the first training flow.
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # e.g. 'train' -> self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        f'mode in workflow must be a str, but got '
                        f'{type(mode)}')

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.
        """
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))
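
A minimal usage sketch (illustrative, not part of this module): `ToyModel` and the toy data below are hypothetical, invented to satisfy the `train_step`/`val_step` contract that `train()` and `val()` expect, i.e. each returns a dict, optionally carrying `log_vars` and `num_samples` for the log buffer. The backward pass is performed by an `OptimizerHook` registered on the runner, not by the runner itself, so one is registered here.

import logging

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from mmcv.runner import EpochBasedRunner, OptimizerHook


class ToyModel(nn.Module):
    """Tiny model satisfying the train_step/val_step contract."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def train_step(self, data_batch, optimizer, **kwargs):
        x, y = data_batch
        loss = nn.functional.mse_loss(self.fc(x), y)
        # 'loss' is consumed by OptimizerHook; 'log_vars' and
        # 'num_samples' feed the runner's log buffer (see train() above).
        return dict(loss=loss,
                    log_vars=dict(loss=loss.item()),
                    num_samples=x.size(0))

    def val_step(self, data_batch, optimizer, **kwargs):
        return self.train_step(data_batch, optimizer, **kwargs)


logging.basicConfig(level=logging.INFO)
model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
dataset = TensorDataset(torch.randn(8, 2), torch.randn(8, 1))
loader = DataLoader(dataset, batch_size=4)

runner = EpochBasedRunner(model,
                          optimizer=optimizer,
                          work_dir='./work_dir',
                          logger=logging.getLogger(__name__))
runner.register_hook(OptimizerHook())  # does loss.backward() + step()
# Alternate 2 training epochs with 1 validation epoch until max_epochs.
runner.run([loader, loader], [('train', 2), ('val', 1)], max_epochs=3)
runner.save_checkpoint('./work_dir')  # epoch_3.pth + latest.pth symlink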
class Runner(EpochBasedRunner):
    """Deprecated name of EpochBasedRunner."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            'Runner is deprecated, please use EpochBasedRunner instead')
        super().__init__(*args, **kwargs)
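
For completeness, the `batch_processor` branch taken in `train()`/`val()` above supports models that do not implement `train_step`/`val_step`. A hypothetical processor obeying the same return-a-dict contract might look like the following sketch (`toy_batch_processor` is an invented name; the data layout matches the `ToyModel` example above):

import logging

import torch
import torch.nn.functional as F

from mmcv.runner import EpochBasedRunner


def toy_batch_processor(model, data_batch, train_mode, **kwargs):
    # Hypothetical example; must return a dict, just like train_step().
    x, y = data_batch
    loss = F.mse_loss(model(x), y)
    return dict(loss=loss,
                log_vars=dict(loss=loss.item()),
                num_samples=x.size(0))


# On this path any plain nn.Module works; the runner calls the
# processor instead of model.train_step() / model.val_step().
runner = EpochBasedRunner(torch.nn.Linear(2, 1),
                          batch_processor=toy_batch_processor,
                          work_dir='./work_dir',
                          logger=logging.getLogger(__name__))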