# Source code for mmcv.cnn.bricks.non_local

from abc import ABCMeta

import torch
import torch.nn as nn

from ..utils import constant_init, normal_init
from .conv_module import ConvModule
from .registry import PLUGIN_LAYERS


class _NonLocalNd(nn.Module, metaclass=ABCMeta):
    """Basic Non-local module.

    This module is proposed in
    "Non-local Neural Networks"
    Paper reference: https://arxiv.org/abs/1711.07971

    Args:
        in_channels (int): Channels of the input feature map.
        reduction (int): Channel reduction ratio. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
            Default: True.
        conv_cfg (None | dict): The config dict for convolution layers.
            If not specified, it will use `nn.Conv2d` for convolution layers.
            Default: None.
        norm_cfg (None | dict): The config dict for normalization layers.
            Default: None. (This parameter is only applicable to conv_out.)
        mode (str): Options are `embedded_gaussian` and `dot_product`.
            Default: embedded_gaussian.
    """

    def __init__(self,
                 in_channels,
                 reduction=2,
                 use_scale=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 mode='embedded_gaussian',
                 **kwargs):
        super(_NonLocalNd, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.use_scale = use_scale
        self.inter_channels = in_channels // reduction
        self.mode = mode

        if mode not in ['embedded_gaussian', 'dot_product']:
            raise ValueError(
                "Mode should be in 'embedded_gaussian' or 'dot_product', "
                f'but got {mode} instead.')

        # g, theta, phi are defaulted as `nn.ConvNd`.
        # Here we use ConvModule for potential usage.
        self.g = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)
        self.theta = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)
        self.phi = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)
        self.conv_out = ConvModule(
            self.inter_channels,
            self.in_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        self.init_weights(**kwargs)

    def init_weights(self, std=0.01, zeros_init=True):
        for m in [self.g, self.theta, self.phi]:
            normal_init(m.conv, std=std)
        if zeros_init:
            if self.conv_out.norm_cfg is None:
                constant_init(self.conv_out.conv, 0)
            else:
                constant_init(self.conv_out.norm, 0)
        else:
            if self.conv_out.norm_cfg is None:
                normal_init(self.conv_out.conv, std=std)
            else:
                normal_init(self.conv_out.norm, std=std)

    def embedded_gaussian(self, theta_x, phi_x):
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight /= theta_x.shape[-1]**0.5
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def dot_product(self, theta_x, phi_x):
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight /= pairwise_weight.shape[-1]
        return pairwise_weight

    def forward(self, x):
        # Assume `reduction = 1`, then `inter_channels = C`
        # NonLocal1d x: [N, C, H]
        # NonLocal2d x: [N, C, H, W]
        # NonLocal3d x: [N, C, T, H, W]
        n = x.size(0)

        # NonLocal1d g_x: [N, H, C]
        # NonLocal2d g_x: [N, HxW, C]
        # NonLocal3d g_x: [N, TxHxW, C]
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # NonLocal1d theta_x: [N, H, C]
        # NonLocal2d theta_x: [N, HxW, C]
        # NonLocal3d theta_x: [N, TxHxW, C]
        theta_x = self.theta(x).view(n, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        # NonLocal1d phi_x: [N, C, H]
        # NonLocal2d phi_x: [N, C, HxW]
        # NonLocal3d phi_x: [N, C, TxHxW]
        phi_x = self.phi(x).view(n, self.inter_channels, -1)

        pairwise_func = getattr(self, self.mode)
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # NonLocal1d y: [N, H, C]
        # NonLocal2d y: [N, HxW, C]
        # NonLocal3d y: [N, TxHxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # NonLocal1d y: [N, C, H]
        # NonLocal2d y: [N, C, H, W]
        # NonLocal3d y: [N, C, T, H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])

        output = x + self.conv_out(y)

        return output


class NonLocal1d(_NonLocalNd):
    """Non-local block for 1D inputs.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): If True, an `nn.MaxPool1d(kernel_size=2)` layer is
            appended after both the `g` and `phi` projections.
            Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv1d').
        **kwargs: Remaining keyword arguments, passed to `_NonLocalNd`.
    """

    def __init__(self,
                 in_channels,
                 sub_sample=False,
                 conv_cfg=dict(type='Conv1d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample
        if not sub_sample:
            return

        # One pooling layer instance is shared by both wrapped projections.
        pool = nn.MaxPool1d(kernel_size=2)
        self.g = nn.Sequential(self.g, pool)
        self.phi = nn.Sequential(self.phi, pool)
@PLUGIN_LAYERS.register_module()
class NonLocal2d(_NonLocalNd):
    """Non-local block for 2D inputs.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): If True, an `nn.MaxPool2d(kernel_size=(2, 2))`
            layer is appended after both the `g` and `phi` projections.
            Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv2d').
        **kwargs: Remaining keyword arguments, passed to `_NonLocalNd`.
    """

    _abbr_ = 'nonlocal_block'

    def __init__(self,
                 in_channels,
                 sub_sample=False,
                 conv_cfg=dict(type='Conv2d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample
        if not sub_sample:
            return

        # One pooling layer instance is shared by both wrapped projections.
        pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.g = nn.Sequential(self.g, pool)
        self.phi = nn.Sequential(self.phi, pool)
class NonLocal3d(_NonLocalNd):
    """Non-local block for 3D inputs.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): If True, an `nn.MaxPool3d(kernel_size=(1, 2, 2))`
            layer is appended after both the `g` and `phi` projections.
            Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv3d').
        **kwargs: Remaining keyword arguments, passed to `_NonLocalNd`.
    """

    def __init__(self,
                 in_channels,
                 sub_sample=False,
                 conv_cfg=dict(type='Conv3d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample
        if not sub_sample:
            return

        # One pooling layer instance is shared by both wrapped projections.
        pool = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.g = nn.Sequential(self.g, pool)
        self.phi = nn.Sequential(self.phi, pool)