Source code for medsegpy.modeling.meta_arch.deeplabv3

"""DeeplabV3+ implementation.

This model is based on TF and Keras repos below:
https://github.com/tensorflow/models/tree/master/research/deeplab
https://github.com/bonlime/keras-deeplab-v3-plus
"""

import logging

import numpy as np
from keras import backend as K
from keras import layers
from keras.engine.topology import get_source_inputs
from keras.layers import (
    Activation,
    Add,
    AveragePooling2D,
    BatchNormalization,
    Concatenate,
    Conv2D,
    DepthwiseConv2D,
    Dropout,
    Input,
    ZeroPadding2D,
)

from medsegpy.config import DeeplabV3Config

from .. import model_utils as utils
from ..layers import BilinearUpsampling
from ..model import Model
from .build import META_ARCH_REGISTRY, ModelBuilder

logger = logging.getLogger(__name__)


@META_ARCH_REGISTRY.register()
class DeeplabV3Plus(ModelBuilder):
    def __init__(self, cfg: DeeplabV3Config):
        super().__init__(cfg)
        self._seed = cfg.SEED

    def _get_seed(self):
        # We change the seed on each call to avoid initializing all kernels
        # with the same seed, while keeping the individual seeds
        # deterministic.
        seed = self._seed
        if self._seed is not None:
            self._seed += 1
        return seed

    def _kernel_initializer(self):
        kernel_initializer = utils.build_keras_config(
            self._cfg.KERNEL_INITIALIZER, seed=self._get_seed()
        )
        return kernel_initializer
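    # Illustrative sketch of the seeding scheme above (assumption, not in
    # the original module): for a config with SEED = 42, successive calls
    # yield distinct but reproducible seeds,
    #
    #   builder._get_seed()  # -> 42 (self._seed becomes 43)
    #   builder._get_seed()  # -> 43 (self._seed becomes 44)
    #
    # while SEED = None always returns None, deferring to Keras's default
    # (nondeterministic) kernel initialization.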
    def sep_conv_bn(
        self,
        x,
        filters,
        prefix,
        stride=1,
        kernel_size=3,
        rate=1,
        depth_activation=False,
        epsilon=1e-3,
    ):
        """SepConv with BN between depthwise & pointwise.

        Optionally adds an activation after BN. Implements right "same"
        padding for even kernel sizes.

        Args:
            x: input tensor
            filters: num of filters in pointwise convolution
            prefix: prefix before name
            stride: stride at depthwise conv
            kernel_size: kernel size for depthwise convolution
            rate: atrous rate for depthwise convolution
            depth_activation: flag to use activation between depthwise &
                pointwise convs
            epsilon: epsilon to use in BN layer
        """
        if stride == 1:
            depth_padding = "same"
        else:
            kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
            pad_total = kernel_size_effective - 1
            pad_beg = pad_total // 2
            pad_end = pad_total - pad_beg
            x = ZeroPadding2D((pad_beg, pad_end))(x)
            depth_padding = "valid"

        if not depth_activation:
            x = Activation("relu")(x)
        x = DepthwiseConv2D(
            (kernel_size, kernel_size),
            strides=(stride, stride),
            dilation_rate=(rate, rate),
            padding=depth_padding,
            use_bias=False,
            kernel_initializer=self._kernel_initializer(),
            name=prefix + "_depthwise",
        )(x)
        x = BatchNormalization(name=prefix + "_depthwise_BN", epsilon=epsilon)(x)
        if depth_activation:
            x = Activation("relu")(x)
        x = Conv2D(
            filters,
            (1, 1),
            padding="same",
            use_bias=False,
            kernel_initializer=self._kernel_initializer(),
            name=prefix + "_pointwise",
        )(x)
        x = BatchNormalization(name=prefix + "_pointwise_BN", epsilon=epsilon)(x)
        if depth_activation:
            x = Activation("relu")(x)
        return x
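    # Worked example for the "same"-padding arithmetic above (illustrative,
    # not in the original module): with kernel_size=3, rate=2, stride=2 the
    # effective kernel size is 3 + (3 - 1) * (2 - 1) = 5, so pad_total = 4
    # and pad_beg = pad_end = 2. ZeroPadding2D((2, 2)) followed by a
    # "valid" depthwise conv then behaves like "same" padding without the
    # 1-pixel drift a stride-2 "same" conv would otherwise introduce.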
    def _conv2d_same(self, x, filters, prefix, stride=1, kernel_size=3, rate=1):
        """Implements right 'same' padding for even kernel sizes.

        Without this there is a 1 pixel drift when stride = 2.

        Args:
            x: input tensor
            filters: num of filters in pointwise convolution
            prefix: prefix before name
            stride: stride at depthwise conv
            kernel_size: kernel size for depthwise convolution
            rate: atrous rate for depthwise convolution
        """
        if stride == 1:
            return Conv2D(
                filters,
                (kernel_size, kernel_size),
                strides=(stride, stride),
                padding="same",
                use_bias=False,
                dilation_rate=(rate, rate),
                kernel_initializer=self._kernel_initializer(),
                name=prefix,
            )(x)
        else:
            kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
            pad_total = kernel_size_effective - 1
            pad_beg = pad_total // 2
            pad_end = pad_total - pad_beg
            x = ZeroPadding2D((pad_beg, pad_end))(x)
            return Conv2D(
                filters,
                (kernel_size, kernel_size),
                strides=(stride, stride),
                padding="valid",
                use_bias=False,
                dilation_rate=(rate, rate),
                kernel_initializer=self._kernel_initializer(),
                name=prefix,
            )(x)

    def _xception_block(
        self,
        inputs,
        depth_list,
        prefix,
        skip_connection_type,
        stride,
        rate=1,
        depth_activation=False,
        return_skip=False,
    ):
        """Basic building block of modified Xception network.

        Args:
            inputs: input tensor
            depth_list: number of filters in each SepConv layer.
                len(depth_list) == 3
            prefix: prefix before name
            skip_connection_type: one of {'conv', 'sum', 'none'}
            stride: stride at last depthwise conv
            rate: atrous rate for depthwise convolution
            depth_activation: flag to use activation between depthwise &
                pointwise convs
            return_skip: flag to return additional tensor after
                2 SepConvs for decoder
        """
        residual = inputs
        for i in range(3):
            residual = self.sep_conv_bn(
                residual,
                depth_list[i],
                prefix + "_separable_conv{}".format(i + 1),
                stride=stride if i == 2 else 1,
                rate=rate,
                depth_activation=depth_activation,
            )
            if i == 1:
                skip = residual
        if skip_connection_type == "conv":
            shortcut = self._conv2d_same(
                inputs, depth_list[-1], prefix + "_shortcut", kernel_size=1, stride=stride
            )
            shortcut = BatchNormalization(name=prefix + "_shortcut_BN")(shortcut)
            outputs = layers.add([residual, shortcut])
        elif skip_connection_type == "sum":
            outputs = layers.add([residual, inputs])
        elif skip_connection_type == "none":
            outputs = residual
        if return_skip:
            return outputs, skip
        else:
            return outputs

    def _make_divisible(self, v, divisor, min_value=None):
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        # Make sure that rounding down does not lose more than 10%.
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v
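    # Worked example for _make_divisible above (illustrative, not in the
    # original module): channel counts are rounded to the nearest multiple
    # of `divisor`. With alpha = 1.3 the first MobileNetV2 conv computes
    # _make_divisible(32 * 1.3, 8) = int(41.6 + 4) // 8 * 8 = 40; since
    # 40 >= 0.9 * 41.6, the 10% guard does not fire and 40 is returned.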
    def _inverted_res_block(
        self, inputs, expansion, stride, alpha, filters, block_id, skip_connection, rate=1
    ):
        in_channels = inputs._keras_shape[-1]
        pointwise_conv_filters = int(filters * alpha)
        pointwise_filters = self._make_divisible(pointwise_conv_filters, 8)
        x = inputs
        prefix = "expanded_conv_{}_".format(block_id)
        if block_id:
            # Expand
            x = Conv2D(
                expansion * in_channels,
                kernel_size=1,
                padding="same",
                use_bias=False,
                activation=None,
                kernel_initializer=self._kernel_initializer(),
                name=prefix + "expand",
            )(x)
            x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + "expand_BN")(x)
            x = Activation(self.relu6, name=prefix + "expand_relu")(x)
        else:
            prefix = "expanded_conv_"
        # Depthwise
        x = DepthwiseConv2D(
            kernel_size=3,
            strides=stride,
            activation=None,
            use_bias=False,
            padding="same",
            dilation_rate=(rate, rate),
            kernel_initializer=self._kernel_initializer(),
            name=prefix + "depthwise",
        )(x)
        x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + "depthwise_BN")(x)
        x = Activation(self.relu6, name=prefix + "depthwise_relu")(x)
        # Project
        x = Conv2D(
            pointwise_filters,
            kernel_size=1,
            padding="same",
            use_bias=False,
            activation=None,
            kernel_initializer=self._kernel_initializer(),
            name=prefix + "project",
        )(x)
        x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + "project_BN")(x)
        if skip_connection:
            return Add(name=prefix + "add")([inputs, x])
        # if in_channels == pointwise_filters and stride == 1:
        #     return Add(name='res_connect_' + str(block_id))([inputs, x])
        return x
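    # Illustrative shape walk-through for _inverted_res_block (assumption,
    # not in the original module): with alpha = 1.0, block_id = 1,
    # in_channels = 16, expansion = 6, filters = 24, stride = 2, a
    # (H, W, 16) input is expanded to (H, W, 96) by the 1x1 "expand" conv,
    # downsampled to (H/2, W/2, 96) by the stride-2 depthwise conv, and
    # projected to (H/2, W/2, 24). The caller passes skip_connection=False
    # here because the input and output shapes differ, so no residual Add
    # is applied.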
    def deeplabv3(
        self,
        input_tensor=None,
        input_shape=(512, 512, 3),
        classes=21,
        backbone="mobilenetv2",
        OS=16,
        alpha=1.0,
        dilation_divisor=1,
        dil_rate_input=None,
        dropout_rate=0.1,
    ) -> Model:
        """Instantiates the Deeplabv3+ architecture.

        This model is available for TensorFlow only, and can only be used
        with inputs following the TensorFlow data format
        `(height, width, channels)`.

        Args:
            input_tensor: optional Keras tensor (i.e. output of
                `layers.Input()`) to use as image input for the model.
            input_shape: shape of input image, format HxWxC.
                The PASCAL VOC model was trained on (512, 512, 3) images.
            classes: number of desired classes. If classes != 21, the last
                layer is initialized randomly.
            backbone: backbone to use. one of {'xception', 'mobilenetv2'}.
            OS: determines input_shape/feature_extractor_output ratio.
                One of {8, 16}. Used only for the xception backbone.
            alpha: controls the width of the MobileNetV2 network. This is
                known as the width multiplier in the MobileNetV2 paper.
                - If `alpha` < 1.0, proportionally decreases the number of
                  filters in each layer.
                - If `alpha` > 1.0, proportionally increases the number of
                  filters in each layer.
                - If `alpha` = 1, default number of filters from the paper
                  are used at each layer.
                Used only for the mobilenetv2 backbone.
            dilation_divisor: divisor applied to the default atrous rates
                (6, 12, 18) when OS == 16 with the xception backbone.
            dil_rate_input: optional explicit atrous rates. Overrides the
                computed rates when provided.
            dropout_rate: dropout rate applied after the ASPP projection.

        Returns:
            A Keras model instance.

        Raises:
            RuntimeError: If attempting to run this model with a backend
                that does not support separable convolutions.
            ValueError: in case of invalid argument for `backbone`.
        """
        if K.backend() != "tensorflow":
            raise RuntimeError(
                "The Deeplabv3+ model is only available with "
                "the TensorFlow backend."
            )

        if backbone not in {"xception", "mobilenetv2"}:
            raise ValueError(
                "The `backbone` argument should be either "
                "`xception` or `mobilenetv2`"
            )

        if input_tensor is None:
            img_input = Input(shape=input_shape)
        else:
            if not K.is_keras_tensor(input_tensor):
                img_input = Input(tensor=input_tensor, shape=input_shape)
            else:
                img_input = input_tensor

        if backbone == "xception":
            if OS == 8:
                entry_block3_stride = 1
                middle_block_rate = 2  # ! Not mentioned in paper, but required.
                exit_block_rates = (2, 4)
                atrous_rates = (12, 24, 36)
            else:
                entry_block3_stride = 2
                middle_block_rate = 1
                exit_block_rates = (1, 2)
                atrous_rates = (
                    6 // dilation_divisor,
                    12 // dilation_divisor,
                    18 // dilation_divisor,
                )
            if dil_rate_input is not None:
                atrous_rates = dil_rate_input

            x = Conv2D(
                32,
                (3, 3),
                strides=(2, 2),
                name="entry_flow_conv1_1",
                use_bias=False,
                padding="same",
                kernel_initializer=self._kernel_initializer(),
            )(img_input)
            x = BatchNormalization(name="entry_flow_conv1_1_BN")(x)
            x = Activation("relu")(x)

            x = self._conv2d_same(x, 64, "entry_flow_conv1_2", kernel_size=3, stride=1)
            x = BatchNormalization(name="entry_flow_conv1_2_BN")(x)
            x = Activation("relu")(x)

            x = self._xception_block(
                x,
                [128, 128, 128],
                "entry_flow_block1",
                skip_connection_type="conv",
                stride=2,
                depth_activation=False,
            )
            x, skip1 = self._xception_block(
                x,
                [256, 256, 256],
                "entry_flow_block2",
                skip_connection_type="conv",
                stride=2,
                depth_activation=False,
                return_skip=True,
            )

            x = self._xception_block(
                x,
                [728, 728, 728],
                "entry_flow_block3",
                skip_connection_type="conv",
                stride=entry_block3_stride,
                depth_activation=False,
            )
            for i in range(16):
                x = self._xception_block(
                    x,
                    [728, 728, 728],
                    "middle_flow_unit_{}".format(i + 1),
                    skip_connection_type="sum",
                    stride=1,
                    rate=middle_block_rate,
                    depth_activation=False,
                )

            x = self._xception_block(
                x,
                [728, 1024, 1024],
                "exit_flow_block1",
                skip_connection_type="conv",
                stride=1,
                rate=exit_block_rates[0],
                depth_activation=False,
            )
            x = self._xception_block(
                x,
                [1536, 1536, 2048],
                "exit_flow_block2",
                skip_connection_type="none",
                stride=1,
                rate=exit_block_rates[1],
                depth_activation=True,
            )
        else:
            OS = 8
            first_block_filters = self._make_divisible(32 * alpha, 8)
            x = Conv2D(
                first_block_filters,
                kernel_size=3,
                strides=(2, 2),
                padding="same",
                use_bias=False,
                name="Conv",
                kernel_initializer=self._kernel_initializer(),
            )(img_input)
            x = BatchNormalization(epsilon=1e-3, momentum=0.999, name="Conv_BN")(x)
            x = Activation(self.relu6, name="Conv_Relu6")(x)

            x = self._inverted_res_block(
                x, filters=16, alpha=alpha, stride=1, expansion=1, block_id=0, skip_connection=False
            )

            x = self._inverted_res_block(
                x, filters=24, alpha=alpha, stride=2, expansion=6, block_id=1, skip_connection=False
            )
            x = self._inverted_res_block(
                x, filters=24, alpha=alpha, stride=1, expansion=6, block_id=2, skip_connection=True
            )

            x = self._inverted_res_block(
                x, filters=32, alpha=alpha, stride=2, expansion=6, block_id=3, skip_connection=False
            )
            x = self._inverted_res_block(
                x, filters=32, alpha=alpha, stride=1, expansion=6, block_id=4, skip_connection=True
            )
            x = self._inverted_res_block(
                x, filters=32, alpha=alpha, stride=1, expansion=6, block_id=5, skip_connection=True
            )

            # stride in block 6 changed from 2 -> 1, so we need to use rate = 2
            x = self._inverted_res_block(
                x,
                filters=64,
                alpha=alpha,
                stride=1,  # stride changed from 2 -> 1
                expansion=6,
                block_id=6,
                skip_connection=False,
            )
            x = self._inverted_res_block(
                x,
                filters=64,
                alpha=alpha,
                stride=1,
                rate=2,
                expansion=6,
                block_id=7,
                skip_connection=True,
            )
            x = self._inverted_res_block(
                x,
                filters=64,
                alpha=alpha,
                stride=1,
                rate=2,
                expansion=6,
                block_id=8,
                skip_connection=True,
            )
            x = self._inverted_res_block(
                x,
                filters=64,
                alpha=alpha,
                stride=1,
                rate=2,
                expansion=6,
                block_id=9,
                skip_connection=True,
            )

            x = self._inverted_res_block(
                x,
                filters=96,
                alpha=alpha,
                stride=1,
                rate=2,
                expansion=6,
                block_id=10,
                skip_connection=False,
            )
            x = self._inverted_res_block(
                x,
                filters=96,
                alpha=alpha,
                stride=1,
                rate=2,
                expansion=6,
                block_id=11,
                skip_connection=True,
            )
            x = self._inverted_res_block(
                x,
                filters=96,
                alpha=alpha,
                stride=1,
                rate=2,
                expansion=6,
                block_id=12,
                skip_connection=True,
            )

            x = self._inverted_res_block(
                x,
                filters=160,
                alpha=alpha,
                stride=1,  # stride changed from 2 -> 1
                rate=2,
                expansion=6,
                block_id=13,
                skip_connection=False,
            )
            x = self._inverted_res_block(
                x,
                filters=160,
                alpha=alpha,
                stride=1,
                rate=4,
                expansion=6,
                block_id=14,
                skip_connection=True,
            )
            x = self._inverted_res_block(
                x,
                filters=160,
                alpha=alpha,
                stride=1,
                rate=4,
                expansion=6,
                block_id=15,
                skip_connection=True,
            )

            x = self._inverted_res_block(
                x,
                filters=320,
                alpha=alpha,
                stride=1,
                rate=4,
                expansion=6,
                block_id=16,
                skip_connection=False,
            )

        # end of feature extractor

        # branching for Atrous Spatial Pyramid Pooling
        # Image Feature branch
        # out_shape = int(np.ceil(input_shape[0] / OS))
        b4 = AveragePooling2D(
            pool_size=(int(np.ceil(input_shape[0] / OS)), int(np.ceil(input_shape[1] / OS)))
        )(x)
        b4 = Conv2D(
            256,
            (1, 1),
            padding="same",
            use_bias=False,
            kernel_initializer=self._kernel_initializer(),
            name="image_pooling",
        )(b4)
        b4 = BatchNormalization(name="image_pooling_BN", epsilon=1e-5)(b4)
        b4 = Activation("relu")(b4)
        b4 = BilinearUpsampling(
            (int(np.ceil(input_shape[0] / OS)), int(np.ceil(input_shape[1] / OS)))
        )(b4)

        # simple 1x1
        b0 = Conv2D(
            256,
            (1, 1),
            padding="same",
            use_bias=False,
            kernel_initializer=self._kernel_initializer(),
            name="aspp0",
        )(x)
        b0 = BatchNormalization(name="aspp0_BN", epsilon=1e-5)(b0)
        b0 = Activation("relu", name="aspp0_activation")(b0)

        # There are only 2 ASPP branches in MobileNetV2; not sure why.
        if backbone == "xception":
            # rate = 6 (12)
            b1 = self.sep_conv_bn(
                x, 256, "aspp1", rate=atrous_rates[0], depth_activation=True, epsilon=1e-5
            )
            # rate = 12 (24)
            b2 = self.sep_conv_bn(
                x, 256, "aspp2", rate=atrous_rates[1], depth_activation=True, epsilon=1e-5
            )
            # rate = 18 (36)
            b3 = self.sep_conv_bn(
                x, 256, "aspp3", rate=atrous_rates[2], depth_activation=True, epsilon=1e-5
            )

            # concatenate ASPP branches & project
            x = Concatenate()([b4, b0, b1, b2, b3])
        else:
            x = Concatenate()([b4, b0])

        x = Conv2D(
            256,
            (1, 1),
            padding="same",
            use_bias=False,
            kernel_initializer=self._kernel_initializer(),
            name="concat_projection",
        )(x)
        x = BatchNormalization(name="concat_projection_BN", epsilon=1e-5)(x)
        x = Activation("relu")(x)
        x = Dropout(dropout_rate, seed=self._get_seed())(x)

        # DeepLab v3+ decoder
        if backbone == "xception":
            # Feature projection
            # x4 (x2) block
            x = BilinearUpsampling(
                output_size=(int(np.ceil(input_shape[0] / 4)), int(np.ceil(input_shape[1] / 4)))
            )(x)
            dec_skip1 = Conv2D(
                48,
                (1, 1),
                padding="same",
                use_bias=False,
                kernel_initializer=self._kernel_initializer(),
                name="feature_projection0",
            )(skip1)
            dec_skip1 = BatchNormalization(name="feature_projection0_BN", epsilon=1e-5)(dec_skip1)
            dec_skip1 = Activation("relu")(dec_skip1)
            x = Concatenate()([x, dec_skip1])
            x = self.sep_conv_bn(x, 256, "decoder_conv0", depth_activation=True, epsilon=1e-5)
            x = self.sep_conv_bn(x, 256, "decoder_conv1", depth_activation=True, epsilon=1e-5)

        # You can use it with an arbitrary number of classes.
        if classes == 21:
            last_layer_name = "logits_semantic"
        else:
            last_layer_name = "custom_logits_semantic"

        x = Conv2D(
            classes,
            (1, 1),
            padding="same",
            kernel_initializer=self._kernel_initializer(),
            name=last_layer_name,
        )(x)
        x = BilinearUpsampling(output_size=(input_shape[0], input_shape[1]))(x)

        # 1x1 convolution to get pixel-wise semantic segmentation.
        x = utils.add_sem_seg_activation(
            x,
            classes,
            conv_type=Conv2D,
            activation=self._cfg.LOSS[1],
            kernel_initializer="glorot_uniform",
            seed=self._cfg.SEED,
        )

        # Ensure that the model takes into account
        # any potential predecessors of `input_tensor`.
        if input_tensor is not None:
            inputs = get_source_inputs(input_tensor)
        else:
            inputs = img_input

        model = Model(inputs, x, name="deeplabv3+")
        return model
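    # Illustrative output-stride arithmetic for deeplabv3 above (assumption,
    # not in the original module): with input_shape = (512, 512, 3), OS = 16
    # and the default dilation_divisor = 1, the encoder features are
    # ceil(512 / 16) = 32 pixels on a side, so the image-pooling branch uses
    # pool_size = (32, 32) and the ASPP branches use atrous rates
    # (6, 12, 18). With OS = 8 the rates grow to (12, 24, 36) to preserve
    # the effective receptive field at the finer feature resolution.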
    def relu6(self, x):
        return K.relu(x, max_value=6)
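    # Illustrative note (not in the original module): ReLU6 clips
    # activations to [0, 6], e.g. relu6(-2) == 0, relu6(3) == 3,
    # relu6(8) == 6, which keeps MobileNetV2 activations bounded.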
    def build_model(self, input_tensor=None) -> Model:
        cfg: DeeplabV3Config = self._cfg
        return self.deeplabv3(
            input_tensor=input_tensor,
            input_shape=cfg.IMG_SIZE,
            classes=cfg.get_num_classes(),
            OS=cfg.OS,
            dil_rate_input=cfg.DIL_RATES,
            dropout_rate=cfg.DROPOUT_RATE,
            backbone="xception",
        )
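# Usage sketch (illustrative, not part of the original module). The config
# fields mirror those read in build_model() above; the concrete values are
# assumptions:
#
#   from medsegpy.config import DeeplabV3Config
#
#   cfg = DeeplabV3Config()
#   cfg.IMG_SIZE = (288, 288, 1)  # H x W x C
#   cfg.OS = 16                   # output stride of the xception encoder
#   model = DeeplabV3Plus(cfg).build_model()
#   model.summary()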