""" Attention Layers
The following layers implement an attention gating module based on the
paper "Attention U-Net: Learning Where to Look For the Pancreas"
(Oktay et al.). The code below is based on a PyTorch implementation of
this technique by the paper's authors:
https://github.com/ozan-oktay/Attention-Gated-Networks/tree/a96edb72622274f6705097d70cfaa7f2bf818a5a
Each layer has 2D and 3D versions. Only the 2D versions have been tested
so far.
"""
from typing import Dict, Sequence, Union
import numpy as np
from keras import backend as K
from keras.layers import BatchNormalization as BN
from keras.layers import Concatenate, Conv2D, Conv3D, Layer, Multiply, UpSampling2D, UpSampling3D


class _CreateGatingSignalNDim(Layer):
def __init__(
self,
dimension: int,
out_channels: int,
kernel_size: Union[int, Sequence[int]],
kernel_initializer: Union[str, Dict],
activation: str,
add_batchnorm: bool,
**kwargs
):
"""
This layer creates the first gating signal for attention based on
a feature map. This feature map should contain contextual information
that will help the network focus on important regions in the image.
        In the paper, the feature map chosen for a U-Net is the coarsest
        feature map, at the end of the encoding path.
The layer performs a simple operation to transform the feature map
into the first gating signal:
Convolution --> Activation --> BatchNorm
Args:
dimension: the dimension of the model's input images
out_channels: the number of channels for the gating signal
kernel_size: the kernel size used in the convolutional layer
kernel_initializer: the method for initializing the weights of the
convolutional layer
activation: the activation function used after the convolutional
layer
add_batchnorm: specifies if batch normalization should be used
"""
super(_CreateGatingSignalNDim, self).__init__(**kwargs)
# Store parameters
self.dimension = dimension
self.out_channels = out_channels
self.kernel_size = kernel_size
self.kernel_initializer = kernel_initializer
self.activation = activation
self.add_batchnorm = add_batchnorm
if self.dimension == 2:
conv_type = Conv2D
elif self.dimension == 3:
conv_type = Conv3D
else:
raise ValueError("Only 2D and 3D are supported")
        if isinstance(self.kernel_size, (tuple, list)):
            assert len(self.kernel_size) == self.dimension, (
                "If list/tuple, kernel_size must have length %d" % self.dimension
            )
self.conv = conv_type(
self.out_channels,
self.kernel_size,
padding="same",
activation=self.activation,
kernel_initializer=self.kernel_initializer,
)
if self.add_batchnorm:
self.bn = BN(axis=-1, momentum=0.95, epsilon=0.001)
def build(self, input_shape):
self.conv.build(input_shape)
self._trainable_weights = self.conv.trainable_weights
conv_output_shape = self.conv.compute_output_shape(input_shape)
if self.add_batchnorm:
self.bn.build(conv_output_shape)
self._trainable_weights += self.bn.trainable_weights
super(_CreateGatingSignalNDim, self).build(input_shape)
def call(self, inputs):
outputs = self.conv(inputs)
if self.add_batchnorm:
outputs = self.bn(outputs)
return outputs
def compute_output_shape(self, input_shape):
conv_output_shape = self.conv.compute_output_shape(input_shape)
final_output_shape = conv_output_shape
if self.add_batchnorm:
bn_output_shape = self.bn.compute_output_shape(conv_output_shape)
final_output_shape = bn_output_shape
return final_output_shape
def get_config(self):
base_cfg = super().get_config()
base_cfg.update(
{
"dimension": self.dimension,
"out_channels": self.out_channels,
"kernel_size": self.kernel_size,
"kernel_initializer": self.kernel_initializer,
"activation": self.activation,
"add_batchnorm": self.add_batchnorm,
}
)
return base_cfg


class CreateGatingSignal2D(_CreateGatingSignalNDim):
def __init__(
self,
out_channels: int,
kernel_size: Union[int, Sequence[int]] = 1,
kernel_initializer: Union[str, Dict] = "he_normal",
activation: str = "relu",
add_batchnorm: bool = True,
**kwargs
):
super(CreateGatingSignal2D, self).__init__(
dimension=2,
out_channels=out_channels,
kernel_size=kernel_size,
kernel_initializer=kernel_initializer,
activation=activation,
add_batchnorm=add_batchnorm,
**kwargs
)


class CreateGatingSignal3D(_CreateGatingSignalNDim):
def __init__(
self,
out_channels: int,
kernel_size: Union[int, Sequence[int]] = 1,
kernel_initializer: Union[str, Dict] = "he_normal",
activation: str = "relu",
add_batchnorm: bool = True,
**kwargs
):
super(CreateGatingSignal3D, self).__init__(
dimension=3,
out_channels=out_channels,
kernel_size=kernel_size,
kernel_initializer=kernel_initializer,
activation=activation,
add_batchnorm=add_batchnorm,
**kwargs
)
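
# Example (illustrative sketch, not from the original authors): creating a
# gating signal from the coarsest encoder feature map of a 2D U-Net. The
# input shape and channel counts below are assumptions for demonstration.
#
#     from keras.layers import Input
#     coarse_features = Input(shape=(16, 16, 256))  # bottleneck feature map
#     gating = CreateGatingSignal2D(out_channels=128)(coarse_features)
#     # gating has shape (None, 16, 16, 128): same resolution, fewer channels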


class _GridAttentionModuleND(Layer):
def __init__(
self,
dimension: int,
in_channels: int,
intermediate_channels: int,
sub_sample_factor: Union[int, Sequence[int]],
kernel_initializer: Union[str, Dict],
**kwargs
):
"""
This layer implements the additive attention gate proposed
in the paper and displayed in Figure 2 of the paper. The gate
        takes in a feature map and a gating signal, calculates the
        attention coefficients (ranging from 0 to 1), and multiplies
        the feature map by the attention coefficients to down-weight
        unimportant feature vectors, based on contextual information
        from the gating signal. The pruned feature map is then
linearly transformed using a convolutional layer, which is
followed by a batch normalization layer.
The formulas for computing the attention coefficients are
Equations 1 and 2 in the paper. The code below uses the variables
"theta_x" and "theta_gating" to represent W_x and W_g (and b_g)
found in Equations 1 and 2. Sigma_1 in Equation 1 is fixed
as the ReLU activation function and Sigma_2 in Equation 2 is
fixed as the sigmoid activation function.
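        In the paper's notation, the gate computes (Equations 1 and 2):
            q_att = psi^T (sigma_1(W_x^T x + W_g^T g + b_g)) + b_psi
            alpha = sigma_2(q_att)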
Args:
dimension: the dimension of the model's input images
in_channels: the number of channels in the input feature map
intermediate_channels: F_int (in Figure 2 of the paper)
sub_sample_factor: the factor by which the input feature map
should be downsampled. This should be chosen
such that the input feature map is downsampled
to the resolution of the gating signal,
as described in the paper under the section
"Attention Gates in U-Net Model".
kernel_initializer: the method used for initializing the weights
of convolutional layers
"""
super(_GridAttentionModuleND, self).__init__(**kwargs)
# Store parameters
self.dimension = dimension
self.in_channels = in_channels
self.intermediate_channels = intermediate_channels
self.sub_sample_factor = sub_sample_factor
self.kernel_initializer = kernel_initializer
if self.dimension == 2:
self.conv_type = Conv2D
self.upsample_type = UpSampling2D
elif self.dimension == 3:
self.conv_type = Conv3D
            # NOTE: The authors of the paper state that they use trilinear
            # interpolation for upsampling 3D images. However, UpSampling3D
            # does not offer a trilinear interpolation mode, so
            # nearest-neighbor upsampling is used instead.
self.upsample_type = UpSampling3D
else:
raise ValueError("Only 2D and 3D are supported")
        if isinstance(self.sub_sample_factor, (tuple, list)):
            assert len(self.sub_sample_factor) == self.dimension, (
                "If list/tuple, sub_sample_factor must have length %d" % self.dimension
            )
self.theta_x = self.conv_type(
self.intermediate_channels,
kernel_size=self.sub_sample_factor,
strides=self.sub_sample_factor,
use_bias=False,
kernel_initializer=self.kernel_initializer,
)
self.theta_gating = self.conv_type(
self.intermediate_channels, kernel_size=1, kernel_initializer=self.kernel_initializer
)
self.psi = self.conv_type(1, kernel_size=1, kernel_initializer=self.kernel_initializer)
self.output_conv = self.conv_type(
self.in_channels, kernel_size=1, kernel_initializer=self.kernel_initializer
)
self.output_bn = BN(axis=-1, momentum=0.95, epsilon=0.001)
def build(self, input_shape):
x_shape, gating_signal_shape = input_shape
# Build theta_x
self.theta_x.build(x_shape)
self._trainable_weights = self.theta_x.trainable_weights
theta_x_output_shape = self.theta_x.compute_output_shape(x_shape)
# Build theta_gating
self.theta_gating.build(gating_signal_shape)
self._trainable_weights += self.theta_gating.trainable_weights
theta_gating_output_shape = self.theta_gating.compute_output_shape(gating_signal_shape)
# Build upsample_gating
up_ratio_gating = np.floor_divide(
theta_x_output_shape[1:-1], theta_gating_output_shape[1:-1]
)
self.upsample_gating = self.upsample_type(size=tuple(up_ratio_gating))
self.upsample_gating.build(theta_gating_output_shape)
self._trainable_weights += self.upsample_gating.trainable_weights
up_gating_output_shape = self.upsample_gating.compute_output_shape(
theta_gating_output_shape
)
assert (
up_gating_output_shape[1:-1] == theta_x_output_shape[1:-1]
), "Cannot upsample output of theta_gating to match size of output of theta_x"
# Build psi
self.psi.build(theta_x_output_shape)
self._trainable_weights += self.psi.trainable_weights
psi_output_shape = self.psi.compute_output_shape(theta_x_output_shape)
# Build upsample_attn_coeff
up_ratio_attn_coeff = np.floor_divide(x_shape[1:-1], psi_output_shape[1:-1])
self.upsample_attn_coeff = self.upsample_type(size=tuple(up_ratio_attn_coeff))
self.upsample_attn_coeff.build(psi_output_shape)
self._trainable_weights += self.upsample_attn_coeff.trainable_weights
up_coeff_output_shape = self.upsample_attn_coeff.compute_output_shape(psi_output_shape)
assert (
up_coeff_output_shape[1:-1] == x_shape[1:-1]
), "Cannot upsample output of psi to match size of input feature map (x)"
# Build output_conv
self.output_conv.build(x_shape)
self._trainable_weights += self.output_conv.trainable_weights
output_conv_output_shape = self.output_conv.compute_output_shape(x_shape)
# Build output_bn
self.output_bn.build(output_conv_output_shape)
self._trainable_weights += self.output_bn.trainable_weights
super(_GridAttentionModuleND, self).build(input_shape)
def call(self, inputs):
x, gating_signal = inputs
theta_x_out = self.theta_x(x)
theta_gating_out = self.theta_gating(gating_signal)
        # If theta_gating_out is spatially smaller than theta_x_out,
        # upsample it using UpSampling2D or UpSampling3D. With a
        # well-chosen sub_sample_factor the two already match and the
        # upsampling ratio is 1, making this a no-op.
up_sampled_gating = self.upsample_gating(theta_gating_out)
psi_out = self.psi(K.relu(theta_x_out + up_sampled_gating))
sigmoid_psi = K.sigmoid(psi_out)
x_size = K.int_shape(x)
        # Upsample back to the size of the input feature map (x) so that
        # the attention coefficients can be multiplied with the input
        # feature map
up_sampled_attn_coeff = self.upsample_attn_coeff(sigmoid_psi)
attn_weighted_output = Multiply()(
[K.repeat_elements(up_sampled_attn_coeff, rep=x_size[-1], axis=-1), x]
)
output = self.output_conv(attn_weighted_output)
output = self.output_bn(output)
return [output, up_sampled_attn_coeff]
def compute_output_shape(self, input_shape):
x_shape, gating_signal_shape = input_shape
theta_gating_output_shape = self.theta_gating.compute_output_shape(gating_signal_shape)
up_gating_output_shape = self.upsample_gating.compute_output_shape(
theta_gating_output_shape
)
psi_output_shape = self.psi.compute_output_shape(up_gating_output_shape)
up_attn_output_shape = self.upsample_attn_coeff.compute_output_shape(psi_output_shape)
output_conv_output_shape = self.output_conv.compute_output_shape(x_shape)
output_bn_output_shape = self.output_bn.compute_output_shape(output_conv_output_shape)
return [output_bn_output_shape, up_attn_output_shape]
def get_config(self):
base_cfg = super().get_config()
base_cfg.update(
{
"dimension": self.dimension,
"in_channels": self.in_channels,
"intermediate_channels": self.intermediate_channels,
"sub_sample_factor": self.sub_sample_factor,
"kernel_initializer": self.kernel_initializer,
}
)
return base_cfg


class GridAttentionModule2D(_GridAttentionModuleND):
def __init__(
self,
in_channels: int,
intermediate_channels: int,
sub_sample_factor: Union[int, Sequence[int]] = 2,
kernel_initializer: Union[str, Dict] = "he_normal",
**kwargs
):
super(GridAttentionModule2D, self).__init__(
dimension=2,
in_channels=in_channels,
intermediate_channels=intermediate_channels,
sub_sample_factor=sub_sample_factor,
kernel_initializer=kernel_initializer,
**kwargs
)


class GridAttentionModule3D(_GridAttentionModuleND):
def __init__(
self,
in_channels: int,
intermediate_channels: int,
sub_sample_factor: Union[int, Sequence[int]] = 2,
kernel_initializer: Union[str, Dict] = "he_normal",
**kwargs
):
super(GridAttentionModule3D, self).__init__(
dimension=3,
in_channels=in_channels,
intermediate_channels=intermediate_channels,
sub_sample_factor=sub_sample_factor,
kernel_initializer=kernel_initializer,
**kwargs
)
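
# Example (illustrative sketch, not from the original authors): gating a skip
# connection with a coarser gating signal. Shapes are assumptions, chosen so
# that sub_sample_factor=2 brings the skip down to the gating resolution.
#
#     from keras.layers import Input
#     skip = Input(shape=(64, 64, 64))     # decoder-level skip connection
#     gating = Input(shape=(32, 32, 128))  # gating signal at half resolution
#     gated, coeffs = GridAttentionModule2D(
#         in_channels=64, intermediate_channels=32, sub_sample_factor=2
#     )([skip, gating])
#     # gated has the same shape as skip; coeffs is the (upsampled)
#     # single-channel attention map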


class _MultiAttentionModuleND(Layer):
def __init__(
self,
dimension: int,
in_channels: int,
intermediate_channels: int,
sub_sample_factor: Union[int, Sequence[int]],
kernel_initializer: Union[str, Dict],
activation: str,
**kwargs
):
"""
This layer combines the outputs of two attention gates that
each receive the same inputs. The outputs are concatenated
along the channel axis and passed through a convolutional layer
and a batch normalization layer. The operation is as follows:
Concatenated Gate Outputs --> Convolution --> Activation --> BatchNorm
        Due to the concatenation, the layer learns multi-dimensional
        attention coefficients, as described in the paper under the
        section "Attention Gates for Image Analysis": the coefficients
        from the two gates are concatenated along the channel dimension,
        so each spatial position gets a two-element coefficient vector
        rather than a single scalar.
        Two attention gates are used because the paper's authors use the
        same number in their PyTorch implementation (GitHub URL included
        above).
Args:
activation: the activation function after the convolutional layer
all other parameters: same as in _GridAttentionModuleND
"""
super(_MultiAttentionModuleND, self).__init__(**kwargs)
# Store parameters
self.dimension = dimension
self.in_channels = in_channels
self.intermediate_channels = intermediate_channels
self.sub_sample_factor = sub_sample_factor
self.kernel_initializer = kernel_initializer
self.activation = activation
if self.dimension == 2:
self.conv_type = Conv2D
self.attn_module_type = GridAttentionModule2D
elif self.dimension == 3:
self.conv_type = Conv3D
self.attn_module_type = GridAttentionModule3D
else:
raise ValueError("Only 2D and 3D are supported")
self.attn_gate_1 = self.attn_module_type(
in_channels=self.in_channels,
intermediate_channels=self.intermediate_channels,
sub_sample_factor=self.sub_sample_factor,
kernel_initializer=self.kernel_initializer,
)
self.attn_gate_2 = self.attn_module_type(
in_channels=self.in_channels,
intermediate_channels=self.intermediate_channels,
sub_sample_factor=self.sub_sample_factor,
kernel_initializer=self.kernel_initializer,
)
self.combine_gates_conv = self.conv_type(
self.in_channels,
kernel_size=1,
activation=self.activation,
kernel_initializer=self.kernel_initializer,
)
self.combine_gates_bn = BN(axis=-1, momentum=0.95, epsilon=0.001)
def build(self, input_shape):
self.attn_gate_1.build(input_shape)
self._trainable_weights = self.attn_gate_1.trainable_weights
self.attn_gate_2.build(input_shape)
self._trainable_weights += self.attn_gate_2.trainable_weights
output_attn_shape, _ = self.attn_gate_2.compute_output_shape(input_shape)
concatenate_gate_shape = list(output_attn_shape)
concatenate_gate_shape[-1] *= 2
self.combine_gates_conv.build(concatenate_gate_shape)
self._trainable_weights += self.combine_gates_conv.trainable_weights
conv_output_shape = self.combine_gates_conv.compute_output_shape(concatenate_gate_shape)
self.combine_gates_bn.build(conv_output_shape)
self._trainable_weights += self.combine_gates_bn.trainable_weights
super(_MultiAttentionModuleND, self).build(input_shape)
def call(self, inputs):
x, gating_signal = inputs
gate_1_out, attn_coeff_1 = self.attn_gate_1([x, gating_signal])
gate_2_out, attn_coeff_2 = self.attn_gate_2([x, gating_signal])
total_gate_outputs = Concatenate(axis=-1)([gate_1_out, gate_2_out])
total_attn_coeffs = Concatenate(axis=-1)([attn_coeff_1, attn_coeff_2])
output = self.combine_gates_conv(total_gate_outputs)
output = self.combine_gates_bn(output)
return [output, total_attn_coeffs]
def compute_output_shape(self, input_shape):
output_attn_shape, coeff_shape = self.attn_gate_2.compute_output_shape(input_shape)
concatenate_gate_shape = list(output_attn_shape)
concatenate_coeff_shape = list(coeff_shape)
concatenate_gate_shape[-1] *= 2
concatenate_coeff_shape[-1] *= 2
conv_output_shape = self.combine_gates_conv.compute_output_shape(concatenate_gate_shape)
bn_output_shape = self.combine_gates_bn.compute_output_shape(conv_output_shape)
return [bn_output_shape, concatenate_coeff_shape]
def get_config(self):
base_cfg = super().get_config()
base_cfg.update(
{
"dimension": self.dimension,
"in_channels": self.in_channels,
"intermediate_channels": self.intermediate_channels,
"sub_sample_factor": self.sub_sample_factor,
"kernel_initializer": self.kernel_initializer,
"activation": self.activation,
}
)
return base_cfg


class MultiAttentionModule2D(_MultiAttentionModuleND):
def __init__(
self,
in_channels: int,
intermediate_channels: int,
sub_sample_factor: Union[int, Sequence[int]] = 2,
kernel_initializer: Union[str, Dict] = "he_normal",
activation: str = "relu",
**kwargs
):
super(MultiAttentionModule2D, self).__init__(
dimension=2,
in_channels=in_channels,
intermediate_channels=intermediate_channels,
sub_sample_factor=sub_sample_factor,
kernel_initializer=kernel_initializer,
activation=activation,
**kwargs
)


class MultiAttentionModule3D(_MultiAttentionModuleND):
def __init__(
self,
in_channels: int,
intermediate_channels: int,
sub_sample_factor: Union[int, Sequence[int]] = 2,
kernel_initializer: Union[str, Dict] = "he_normal",
activation: str = "relu",
**kwargs
):
super(MultiAttentionModule3D, self).__init__(
dimension=3,
in_channels=in_channels,
intermediate_channels=intermediate_channels,
sub_sample_factor=sub_sample_factor,
kernel_initializer=kernel_initializer,
activation=activation,
**kwargs
)
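
# Example (illustrative sketch, not from the original authors): same wiring as
# a single attention gate, but with two gates combined internally. Shapes are
# assumptions for demonstration.
#
#     from keras.layers import Input
#     skip = Input(shape=(64, 64, 64))
#     gating = Input(shape=(32, 32, 128))
#     gated, coeffs = MultiAttentionModule2D(
#         in_channels=64, intermediate_channels=32, sub_sample_factor=2
#     )([skip, gating])
#     # coeffs has two channels: one attention map per internal gate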


class _DeepSupervisionND(Layer):
def __init__(
self,
dimension: int,
out_channels: int,
scale_factor: Union[int, Sequence[int]],
kernel_initializer: Union[str, Dict],
**kwargs
):
"""
        This layer implements deep supervision, as mentioned in the paper
        under the section "Attention Gates in U-Net Model". The
        implementation follows the structure of the authors' PyTorch
        implementation (GitHub URL included above).
        The layer takes a feature map as input, passes it through a
        convolutional layer, and upsamples the result by the
        user-specified scale factor.
        In a U-Net, deep supervision transforms the output of each level
        of the decoding path to have a number of channels equal to the
        number of segmentation classes, and upsamples these outputs to
        the resolution of the input image. The transformed outputs are
        then concatenated along the channel dimension and passed through
        the final step of the U-Net to obtain segmentation probabilities.
Args:
dimension: the dimension of the model's input images
out_channels: the number of channels for the output of the
convolutional layer. For a U-Net, this will
be the number of classes.
scale_factor: the factor by which the input feature map is
upsampled
kernel_initializer: the method for initializing the weights of the
convolutional layer
"""
super(_DeepSupervisionND, self).__init__(**kwargs)
# Store parameters
self.dimension = dimension
self.out_channels = out_channels
self.scale_factor = scale_factor
self.kernel_initializer = kernel_initializer
        if isinstance(self.scale_factor, (list, tuple)):
            assert len(self.scale_factor) == self.dimension, (
                "If list/tuple, scale_factor must have length %d" % self.dimension
            )
            self.scale_factor = tuple(self.scale_factor)
if self.dimension == 2:
self.conv_type = Conv2D
self.upsample = UpSampling2D(size=self.scale_factor)
elif self.dimension == 3:
self.conv_type = Conv3D
self.upsample = UpSampling3D(size=self.scale_factor)
else:
raise ValueError("Only 2D and 3D are supported")
self.conv = self.conv_type(
self.out_channels, kernel_size=1, kernel_initializer=self.kernel_initializer
)
def build(self, input_shape):
self.conv.build(input_shape)
self._trainable_weights = self.conv.trainable_weights
conv_output_shape = self.conv.compute_output_shape(input_shape)
self.upsample.build(conv_output_shape)
self._trainable_weights += self.upsample.trainable_weights
super(_DeepSupervisionND, self).build(input_shape)
def call(self, inputs):
outputs = self.conv(inputs)
outputs = self.upsample(outputs)
return outputs
def compute_output_shape(self, input_shape):
conv_output_shape = self.conv.compute_output_shape(input_shape)
upsample_output_shape = self.upsample.compute_output_shape(conv_output_shape)
return upsample_output_shape
def get_config(self):
base_cfg = super().get_config()
base_cfg.update(
{
"dimension": self.dimension,
"out_channels": self.out_channels,
"scale_factor": self.scale_factor,
"kernel_initializer": self.kernel_initializer,
}
)
return base_cfg


class DeepSupervision2D(_DeepSupervisionND):
def __init__(
self,
out_channels: int,
scale_factor: Union[int, Sequence[int]],
kernel_initializer: Union[str, Dict] = "he_normal",
**kwargs
):
super(DeepSupervision2D, self).__init__(
dimension=2,
out_channels=out_channels,
scale_factor=scale_factor,
kernel_initializer=kernel_initializer,
**kwargs
)


class DeepSupervision3D(_DeepSupervisionND):
def __init__(
self,
out_channels: int,
scale_factor: Union[int, Sequence[int]],
kernel_initializer: Union[str, Dict] = "he_normal",
**kwargs
):
super(DeepSupervision3D, self).__init__(
dimension=3,
out_channels=out_channels,
scale_factor=scale_factor,
kernel_initializer=kernel_initializer,
**kwargs
)
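

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch): wires the layers together the
    # way an attention U-Net decoder stage might. All shapes and channel
    # counts below are assumptions chosen for demonstration; this mirrors the
    # usage described in the docstrings above, not the authors' exact model.
    from keras.layers import Input
    from keras.models import Model

    skip = Input(shape=(64, 64, 64))     # skip connection from the encoder
    coarse = Input(shape=(32, 32, 256))  # coarsest encoder feature map

    # Coarsest feature map -> first gating signal
    gating = CreateGatingSignal2D(out_channels=128)(coarse)
    # Gate the skip connection on contextual information from the signal
    gated_skip, attn_coeffs = MultiAttentionModule2D(
        in_channels=64, intermediate_channels=32, sub_sample_factor=2
    )([skip, gating])
    # Map the gated features to per-class outputs at the input resolution
    # (here, an assumed 3-class problem and 2x upsampling back to 128x128)
    class_maps = DeepSupervision2D(out_channels=3, scale_factor=(2, 2))(gated_skip)

    model = Model(inputs=[skip, coarse], outputs=class_maps)
    model.summary()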