Source code for mmcv.runner.hooks.checkpoint

# Copyright (c) Open-MMLab. All rights reserved.
import os

from ..dist_utils import master_only
from .hook import HOOKS, Hook


@HOOKS.register_module()
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Default: -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Default: True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Default: True.
        out_dir (str, optional): The directory to save checkpoints. If not
            specified, ``runner.work_dir`` will be used by default.
        max_keep_ckpts (int, optional): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space.
            Default: -1, which means unlimited.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``runner.save_checkpoint``. The optional ``filename_tmpl`` entry
            is also used to locate old checkpoints when pruning.
    """

    def __init__(self,
                 interval=-1,
                 by_epoch=True,
                 save_optimizer=True,
                 out_dir=None,
                 max_keep_ckpts=-1,
                 **kwargs):
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.max_keep_ckpts = max_keep_ckpts
        self.args = kwargs

    @master_only
    def after_train_epoch(self, runner):
        """Save a checkpoint after every ``interval``-th epoch.

        Does nothing when the hook is iteration-based or when the current
        epoch is not one selected by ``every_n_epochs``.
        """
        if not self.by_epoch or not self.every_n_epochs(
                runner, self.interval):
            return

        runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
        self._save_checkpoint(runner)
        self._prune_old_checkpoints(runner.epoch + 1, 'epoch_{}.pth')

    @master_only
    def after_train_iter(self, runner):
        """Save a checkpoint after every ``interval``-th iteration.

        Does nothing when the hook is epoch-based or when the current
        iteration is not one selected by ``every_n_iters``.
        """
        if self.by_epoch or not self.every_n_iters(runner, self.interval):
            return

        runner.logger.info(
            f'Saving checkpoint at {runner.iter + 1} iterations')
        self._save_checkpoint(runner)
        self._prune_old_checkpoints(runner.iter + 1, 'iter_{}.pth')

    def _save_checkpoint(self, runner):
        """Resolve the output directory and write the checkpoint."""
        # Fall back to the runner's work dir lazily: it is only known once a
        # runner is available. The resolved value is cached on the hook.
        if not self.out_dir:
            self.out_dir = runner.work_dir
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)

    def _prune_old_checkpoints(self, current_step, default_tmpl):
        """Delete checkpoints older than the newest ``max_keep_ckpts``.

        Args:
            current_step (int): 1-based epoch or iteration that was just
                saved.
            default_tmpl (str): Filename template used when the user did not
                pass ``filename_tmpl``; formatted with the step number.
        """
        # A non-positive interval cannot describe a saving schedule (and a
        # zero step would make ``range`` raise), so there is nothing to prune.
        if self.max_keep_ckpts <= 0 or self.interval <= 0:
            return

        filename_tmpl = self.args.get('filename_tmpl', default_tmpl)
        # Checkpoints are written every ``interval`` steps, so walk backwards
        # in strides of ``interval`` starting just past the newest
        # ``max_keep_ckpts`` files. Stepping by 1 would probe steps at which
        # no checkpoint was ever written.
        first_stale = current_step - self.max_keep_ckpts * self.interval
        for step in range(first_stale, 0, -self.interval):
            ckpt_path = os.path.join(self.out_dir, filename_tmpl.format(step))
            if not os.path.exists(ckpt_path):
                # Anything older was already removed by a previous call.
                break
            os.remove(ckpt_path)