Source code for medsegpy.modeling.meta_arch.unet

from typing import Dict, Sequence, Union

import numpy as np
import tensorflow as tf
from keras import Input
from keras.engine import get_source_inputs
from keras.layers import BatchNormalization as BN
from keras.layers import (
    Concatenate,
    Conv2D,
    Conv2DTranspose,
    Conv3D,
    Conv3DTranspose,
    Dropout,
    MaxPooling2D,
    MaxPooling3D,
)

from medsegpy.config import UNet3DConfig, UNetConfig
from medsegpy.modeling.layers.attention import (
    CreateGatingSignal2D,
    CreateGatingSignal3D,
    DeepSupervision2D,
    DeepSupervision3D,
    MultiAttentionModule2D,
    MultiAttentionModule3D,
)
from medsegpy.modeling.model_utils import add_sem_seg_activation, zero_pad_like

from ..model import Model
from .build import META_ARCH_REGISTRY, ModelBuilder


[docs]def build_encoder_block( x: tf.Tensor, num_filters: Union[int, Sequence[int]], kernel_size: Union[int, Sequence[int]] = 3, num_conv: int = 2, activation: str = "relu", kernel_initializer: Union[str, Dict] = "he_normal", dropout: float = 0.0, ): """Builds simple FCN encoder block. Structure is below. Where blocks in `[]` are repeated: [Conv -> Activation] -> BN -> Dropout. Args: x (tf.Tensor): Input tensor. num_filters (int or Sequence[int]): Number of filters to use for each conv layer. If a sequence, will override number of conv layers specified by `num_conv`. kernel_size: Kernel size accepted by Keras convolution layers. num_conv (int, optional): Number of convolutional blocks (conv + activation) to use. activation (str, optional): Activation type. kernel_initializer: Kernel initializer accepted by `keras.layers.Conv(...)`. dropout (float, optional): Dropout rate. Returns: tf.Tensor: Encoder block output. """ if len(x.shape) == 4: # 2D conv_type = Conv2D elif len(x.shape) == 5: # 3D conv_type = Conv3D else: raise ValueError("Only 2D or 3D inputs are supported") if isinstance(num_filters, int): num_filters = [num_filters] * num_conv for filters in num_filters: x = conv_type( filters, kernel_size, padding="same", activation=activation, kernel_initializer=kernel_initializer, )(x) x = BN(axis=-1, momentum=0.95, epsilon=0.001)(x) x = Dropout(rate=dropout)(x)
return x def build_decoder_block( x, x_skip, num_filters: Union[int, Sequence[int]], unpool_size: Sequence[int], kernel_size: Union[int, Sequence[int]] = 3, num_conv: int = 2, activation: str = "relu", kernel_initializer: Union[str, Dict] = "he_normal", dropout: float = 0.0, ): if len(x.shape) == 4: # 2D conv_transpose_type = Conv2DTranspose elif len(x.shape) == 5: # 3D conv_transpose_type = Conv3DTranspose else: raise ValueError("Only 2D or 3D inputs are supported") if isinstance(num_filters, int): # conv_transpose + num_conv num_filters = [num_filters] * (num_conv + 1) x_shape = x.shape.as_list()[1:-1] x = conv_transpose_type( num_filters[0], kernel_size, padding="same", strides=unpool_size, kernel_initializer=kernel_initializer, )(x) x_shape = [x * y for x, y in zip(x_shape, unpool_size)] x = Concatenate(axis=-1)([zero_pad_like(x, x_skip, x_shape), x_skip]) if len(num_filters) > 1: x = build_encoder_block( x, num_filters[1:], kernel_size, num_conv, activation, kernel_initializer, dropout ) return x @META_ARCH_REGISTRY.register() class UNet2D(ModelBuilder): def __init__(self, cfg: UNetConfig): super().__init__(cfg) self._pooler_type = MaxPooling2D self._conv_type = Conv2D self._multi_attention_module = MultiAttentionModule2D self._create_gating_signal = CreateGatingSignal2D self._deep_supervision = DeepSupervision2D self._dim = 2 self.kernel_size = (3, 3) self.use_attention = False self.use_deep_supervision = False def build_model(self, input_tensor=None) -> Model: cfg = self._cfg seed = cfg.SEED depth = cfg.DEPTH kernel_size = self.kernel_size self.use_attention = cfg.USE_ATTENTION self.use_deep_supervision = cfg.USE_DEEP_SUPERVISION kernel_initializer = {"class_name": cfg.KERNEL_INITIALIZER, "config": {"seed": seed}} output_mode = cfg.LOSS[1] assert output_mode in ["sigmoid", "softmax"] num_classes = cfg.get_num_classes() num_filters = cfg.NUM_FILTERS if not num_filters: num_filters = [2 ** feat * 32 for feat in range(depth)] else: depth = len(num_filters) # Build inputs. if not input_tensor: inputs = self._build_input(cfg.IMG_SIZE) else: inputs = input_tensor # Encoder. x_skips = [] pool_sizes = [] x = inputs deep_supervision_outputs = [] scale_factors = np.ones(((depth - 1), self._dim), dtype=int) for depth_cnt in range(depth): x = build_encoder_block( x, num_filters[depth_cnt], kernel_size=kernel_size, num_conv=2, activation="relu", kernel_initializer=kernel_initializer, dropout=0.0, ) # Maxpool until penultimate depth. if depth_cnt < depth - 1: x_skips.append(x) pool_size = self._get_pool_size(x) pool_sizes.append(pool_size) if depth_cnt < depth - 2: scale_factors[depth_cnt + 1] = scale_factors[depth_cnt] * pool_size x = self._pooler_type(pool_size=pool_size)(x) # Decoder. for i, (x_skip, unpool_size) in enumerate(zip(x_skips[::-1], pool_sizes[::-1])): depth_cnt = depth - i - 2 skip_connect = x_skip # The first skip connection is not passed through # an attention gate, as mentioned in the paper under # the section "Attention Gates in U-Net Model" if depth_cnt > 0 and self.use_attention: if i == 0: gating_signal = self._create_gating_signal( out_channels=num_filters[depth_cnt + 1] )(x) else: gating_signal = x attn_out, attn_coeffs = self._multi_attention_module( in_channels=num_filters[depth_cnt], intermediate_channels=num_filters[depth_cnt], sub_sample_factor=self._get_pool_size(x_skip), )([x_skip, gating_signal]) skip_connect = attn_out x = build_decoder_block( x, skip_connect, num_filters[depth_cnt], unpool_size, kernel_size=kernel_size, num_conv=2, activation="relu", kernel_initializer=kernel_initializer, dropout=0.0, ) if self.use_deep_supervision: # Determine scale factor for deep supervision deep_supervision_outputs.append( self._deep_supervision( out_channels=num_classes, scale_factor=tuple(scale_factors[depth_cnt]) )(x) ) if self.use_deep_supervision: x = Concatenate(axis=-1)(deep_supervision_outputs) # 1x1 convolution to get pixel-wise semantic segmentation. x = add_sem_seg_activation( x, num_classes, conv_type=self._conv_type, activation=output_mode, kernel_initializer="glorot_uniform", seed=seed, ) if input_tensor is not None: inputs = get_source_inputs(input_tensor) model = Model(inputs=inputs, outputs=[x]) return model def _build_input(self, input_size): if len(input_size) == 2: input_size = input_size + (1,) elif len(input_size) != 3: raise ValueError("2D U-Net must have an input of shape HxWxC") return Input(input_size) def _get_pool_size(self, x: tf.Tensor): return [2 if d % 2 == 0 or d % 3 != 0 else 3 for d in x.shape.as_list()[1:-1]] @META_ARCH_REGISTRY.register() class UNet3D(UNet2D): def __init__(self, cfg: UNet3DConfig): super().__init__(cfg) self._pooler_type = MaxPooling3D self._conv_type = Conv3D self._multi_attention_module = MultiAttentionModule3D self._create_gating_signal = CreateGatingSignal3D self._deep_supervision = DeepSupervision3D self._dim = 3 self.kernel_size = (3, 3, 3) def _build_input(self, input_size): if len(input_size) == 3: input_size = input_size + (1,) elif len(input_size) != 4: raise ValueError("2D U-Net must have an input of shape HxWxDxC") return Input(input_size) def _get_pool_size(self, x: tf.Tensor): return (2, 2, 1) if x.shape.as_list()[-2] // 2 == 0 else (2, 2, 2)