"""Meta-architecture registry and model-building entry point (``medsegpy.modeling.meta_arch.build``)."""

import warnings
from abc import ABC, abstractmethod

from fvcore.common.registry import Registry

from medsegpy.config import Config

from ..model import Model

# Registry of meta-architecture builders. ``build_model`` in this module looks
# entries up by ``cfg.MODEL_NAME`` and instantiates them with ``cfg``.
META_ARCH_REGISTRY = Registry(name="META_ARCH")
META_ARCH_REGISTRY.__doc__ = """
Registry for meta-architectures, i.e. the whole model.

The registered object will be called with `obj(cfg)`. The resulting object
should be duck typed with `build_model(input_tensor)`.
"""

# Groups of deprecated ``MODEL_NAME`` values, each keyed to the name that
# ``build_model`` looks up in ``META_ARCH_REGISTRY`` instead.
_MODEL_MAP = {
    ("unet_2d", "unet_2_5d"): "UNet2D",
    ("unet_3d",): "UNet3D",
    ("deeplabv3_2d", "deeplabv3_2_5d", "deeplabv3+"): "DeeplabV3Plus",
}

# Flattened form of ``_MODEL_MAP``: one entry per deprecated alias, mapping it
# directly to its replacement name.
LEGACY_MODEL_NAMES = {
    legacy_name: replacement
    for legacy_names, replacement in _MODEL_MAP.items()
    for legacy_name in legacy_names
}


def build_model(cfg, input_tensor=None) -> Model:
    """Build the whole model architecture, defined by ``cfg.MODEL_NAME``.

    Note that it does not load any weights from ``cfg``.

    If ``cfg.MODEL_NAME`` is not registered in ``META_ARCH_REGISTRY`` but is a
    key of ``LEGACY_MODEL_NAMES``, the mapped (current) name is used instead
    and a deprecation warning is emitted.

    Args:
        cfg (Config): Config whose ``MODEL_NAME`` selects the meta-architecture.
            It is also passed to the registered builder's constructor.
        input_tensor: Optional input forwarded unchanged to the builder's
            ``build_model`` method.

    Returns:
        Model: The model produced by the registered builder.

    Raises:
        KeyError: If ``cfg.MODEL_NAME`` is neither registered nor a known
            legacy name.
        AssertionError: If the builder does not return a ``Model`` instance.
    """
    name = cfg.MODEL_NAME
    try:
        builder_cls = META_ARCH_REGISTRY.get(name)
    except KeyError:
        # Legacy code used different tags for building models.
        if name not in LEGACY_MODEL_NAMES:
            # Unknown name: propagate the registry's own KeyError.
            raise
        new_name = LEGACY_MODEL_NAMES[name]
        # stacklevel=2 attributes the warning to the caller of build_model.
        warnings.warn(
            "MODEL_NAME {} is deprecated. Use {} instead".format(name, new_name),
            stacklevel=2,
        )
        builder_cls = META_ARCH_REGISTRY.get(new_name)

    builder = builder_cls(cfg)
    model = builder.build_model(input_tensor)
    assert isinstance(model, Model), (
        "ModelBuilder.build_model should output model of type "
        "medsegpy.modeling.Model"
    )
    return model


class ModelBuilder(ABC):
    """Base class for objects registered in ``META_ARCH_REGISTRY``.

    ``build_model`` in this module instantiates the registered class with the
    config and then calls its ``build_model`` method.

    Args:
        cfg (Config): The config, stored as ``self._cfg`` for subclasses.
    """

    def __init__(self, cfg: Config):
        self._cfg = cfg

    @abstractmethod
    def build_model(self, input_tensor=None) -> Model:
        """Build model.

        Args:
            input_tensor: Optional input, forwarded from the module-level
                ``build_model`` function.

        Returns:
            Model: The built model. The module-level ``build_model`` function
            asserts that the result is a ``Model`` instance.
        """
        pass