import logging
import math
import random
import threading
import time
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, List, Sequence
import h5py
import keras.backend as K
import numpy as np
from fvcore.common.registry import Registry
from keras import utils as k_utils
from tqdm import tqdm
from medsegpy.config import Config
from medsegpy.modeling import Model
from medsegpy.utils import env
from .data_utils import add_background_labels, collect_mask, compute_patches
from .transforms import apply_transform_gens, build_preprocessing
# Module-level logger for the data loaders.
# NOTE: a dedicated StreamHandler is attached at import time, so INFO-level
# records from this module are written by that handler regardless of whether
# the application configured logging itself.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
logger.addHandler(sh)
DATA_LOADER_REGISTRY = Registry("DATA_LOADER")
DATA_LOADER_REGISTRY.__doc__ = """
Registry for data loaders, which can be used with `model.fit_generator()` and
`model.predict_generator()`. The data loader type should be registered with
cfg, dataset_dicts, and other extra parameters.

The registered object will be called with
`obj(cfg, dataset_dicts, **kwargs)`.
The call should return a :class:`DataLoader` object.
"""

# Deprecated `cfg.TAG` values, grouped by the registered data loader name that
# replaces them.
_LEGACY_DATA_LOADER_MAP = {("oai_aug", "oai", "oai_2d", "oai_aug_2d"): "DefaultDataLoader"}
# Flattened lookup: deprecated tag -> registered data loader name.
LEGACY_DATA_LOADER_NAMES = {x: v for k, v in _LEGACY_DATA_LOADER_MAP.items() for x in k}
def build_data_loader(cfg: Config, dataset_dicts: List[Dict], **kwargs) -> "DataLoader":
    """Instantiate the data loader selected by ``cfg.TAG``.

    The tag is first looked up in ``DATA_LOADER_REGISTRY``. If that lookup
    raises ``KeyError`` and the tag is a known legacy alias, the alias is
    translated to its current name (a warning is emitted) and the lookup is
    retried. A tag that is neither registered nor a legacy alias is looked up
    a second time unchanged, so the registry's ``KeyError`` propagates.

    Args:
        cfg (Config): A config object. ``cfg.TAG`` names the data loader.
        dataset_dicts (List[Dict]): Dataset dictionaries passed to the loader.
        **kwargs: Extra keyword arguments forwarded to the loader constructor.

    Returns:
        The object produced by calling the registered entry with
        ``(cfg, dataset_dicts, **kwargs)``.
    """
    tag = cfg.TAG
    try:
        loader_cls = DATA_LOADER_REGISTRY.get(tag)
    except KeyError:
        # Fall back to the legacy alias table; unknown tags map to themselves.
        resolved = LEGACY_DATA_LOADER_NAMES.get(tag, tag)
        if resolved != tag:
            warnings.warn("TAG {} is deprecated. Use {} instead".format(tag, resolved))
        loader_cls = DATA_LOADER_REGISTRY.get(resolved)

    return loader_cls(cfg, dataset_dicts, **kwargs)
class DataLoader(k_utils.Sequence, ABC):
    """Abstract batch loader built on the :class:`keras.utils.Sequence` API.

    A batch is produced in two steps:

    1. Collate inputs and outputs.
    2. Optionally apply preprocessing.

    The list of dataset dictionaries is never reordered. Instead, a separate
    list of indices is shuffled and elements are fetched through it.

    Besides per-batch loading, medsegpy data loaders can also yield inference
    results one scan at a time (see :meth:`inference`).
    """

    ALIASES = []

    def __init__(
        self,
        cfg: Config,
        dataset_dicts: List[Dict],
        is_test: bool = False,
        shuffle: bool = True,
        drop_last: bool = True,
        batch_size: int = 1,
    ):
        """
        Args:
            cfg (Config): A config object. ``cfg.SEED`` and ``cfg.CATEGORIES``
                are read here.
            dataset_dicts (List[Dict]): Data in the medsegpy dataset format.
            is_test (:obj:`bool`, optional): If `True`, the loader is set up
                for testing/inference, e.g. when running evaluation. In this
                mode :meth:`__len__` never drops the last batch.
            shuffle (bool, optional): If `True`, the index order is shuffled
                on construction and again at the end of every epoch.
            drop_last (:obj:`bool`, optional): If `True`, a trailing batch
                smaller than `batch_size` is dropped. If `False`, the final
                batch may be smaller, which can affect loss calculations.
            batch_size (:obj:`int`, optional): Batch size.
        """
        self._cfg = cfg
        self._dataset_dicts = dataset_dicts
        self.shuffle = shuffle

        rng_seed = cfg.SEED
        # A falsy seed (e.g. ``None`` or ``0``) leaves the generator unseeded.
        if rng_seed:
            self._random = random.Random(rng_seed)
        else:
            self._random = random.Random()

        self.drop_last = drop_last
        self._batch_size = batch_size
        self._category_idxs = cfg.CATEGORIES
        self._is_test = is_test

        # Elements are always accessed through this index list so that the
        # dataset dictionaries themselves keep their order.
        self._idxs = list(range(self._num_elements()))
        if shuffle:
            self._random.shuffle(self._idxs)

    def on_epoch_end(self):
        """Reshuffle the element order if shuffling is enabled."""
        if not self.shuffle:
            return
        self._random.shuffle(self._idxs)

    def __len__(self):
        """Number of batches.

        Each element of the dataset dict list is treated as independent. The
        trailing partial batch is dropped only when the loader is not in test
        mode and ``drop_last`` is set.
        """
        total = self._num_elements()
        keep_partial_batch = self._is_test or not self.drop_last
        if keep_partial_batch:
            return math.ceil(total / self._batch_size)
        return total // self._batch_size

    def _num_elements(self):
        """Number of elements in the data loader."""
        return len(self._dataset_dicts)

    def num_scans(self):
        """Number of distinct ``"scan_id"`` values among the dataset dicts."""
        return len(set(dd["scan_id"] for dd in self._dataset_dicts))

    @abstractmethod
    def inference(self, model, **kwargs):
        """Yield ``(input, output)`` dictionaries, one pair per scan.

        Medical data is commonly evaluated per scan rather than per example.
        A 2D segmentation network, for instance, consumes individual slices,
        but metrics are normally computed over the full scan. Implementations
        are therefore expected to:

        1. Load the dataset dicts belonging to one scan.
        2. Structure the data from these dicts.
        3. Run predictions on the structured data.
        4. Restructure the results: images/volumes to HxWx..., segmentation
           masks and predictions to HxWx...xC.
        5. Yield the input and output dictionaries for the scan, and repeat
           until every scan has been processed.

        At minimum the yielded dictionaries should contain the keys below.

        Input keys:
            * "scan_id" (str): The scan identifier.
            * "x" (ndarray): The raw (unprocessed) input, shape HxWx...
              For multi-input networks, each input should use its own key,
              to be handled by the chosen evaluator.
            * "scan_XXX" (optional): Scan-related parameters that simplify
              evaluation, e.g. "scan_spacing". MedSegPy evaluators prefer
              scan-specific information when it is provided; a given
              "scan_spacing" overrides the dataset's default spacing.
            * "subject_id" (optional): The subject identifier for the scan,
              useful for grouping results by subject.

        Output keys:
            * "time_elapsed" (required): Time taken for inference on the
              scan. This typically includes data loading time.
            * "y_true" (ndarray): Ground truth binary mask, shape HxWx...xC.
              Required for semantic segmentation inference.
            * "y_pred" (ndarray): Prediction probabilities, shape HxWx...xC.
              Required for semantic segmentation inference.

        Apart from "time_elapsed", output keys are optional and task specific.

        Args:
            model: A model to run inference on.
            kwargs: Keyword arguments to `model.predict_generator()`

        Yields:
            dict, dict: Input and output dictionaries for a single scan.
        """
        yield {}, {}
@DATA_LOADER_REGISTRY.register()
class DefaultDataLoader(DataLoader):
    """The default data loader functionality in medsegpy.

    This class takes a dataset dict in the MedSegPy 2D Dataset format and maps
    it to a format that can be used by the model for semantic segmentation.

    This is the default data loader.

    1. Read the input matrix from "file_name"
    2. Read the ground truth mask matrix from "sem_seg_file"
    3. If needed:
        a. Add binary labels for background
    4. Apply :class:`MedTransform` transforms to input and masks.
    5. If training, return input (preprocessed), output.
       If testing, return input (preprocessed), output, input (raw).
       The testing structure is useful for tracking the original input
       without any preprocessing. This return structure does not conflict with
       existing Keras model functionality.
    """

    def __init__(
        self,
        cfg: Config,
        dataset_dicts: List[Dict],
        is_test: bool = False,
        shuffle: bool = True,
        drop_last: bool = False,
        batch_size: int = 1,
    ):
        """See :class:`DataLoader` for the arguments.

        Additionally reads ``cfg.INCLUDE_BACKGROUND`` and
        ``cfg.get_num_classes()`` and builds the preprocessing transforms
        from ``cfg``.
        """
        super().__init__(cfg, dataset_dicts, is_test, shuffle, drop_last, batch_size)

        self._include_background = cfg.INCLUDE_BACKGROUND
        self._num_classes = cfg.get_num_classes()
        self._transform_gen = build_preprocessing(cfg)
        # Optional in-memory cache keyed by (file_name, sem_seg_file).
        # Never populated by this class itself.
        self._cached_data = None

    def _load_input(self, dataset_dict):
        """Load the image and (optional) mask for one dataset dict.

        Args:
            dataset_dict (Dict): Must contain "file_name"; may contain
                "sem_seg_file".

        Returns:
            Tuple of ``(image, mask)``. ``mask`` is ``None`` when the dataset
            dict has no (truthy) "sem_seg_file".
        """
        image_file = dataset_dict["file_name"]
        sem_seg_file = dataset_dict.get("sem_seg_file", None)

        if self._cached_data is not None:
            image, mask = self._cached_data[(image_file, sem_seg_file)]
        else:
            with h5py.File(image_file, "r") as f:
                image = f["data"][:]
            # Ensure a trailing singleton (channel) axis.
            if image.shape[-1] != 1:
                image = image[..., np.newaxis]

            if sem_seg_file:
                with h5py.File(sem_seg_file, "r") as f:
                    mask = f["data"][:]

                mask = collect_mask(mask, index=self._category_idxs)
                if self._include_background:
                    mask = add_background_labels(mask)
            else:
                mask = None

        return image, mask

    def _load_batch(self, idxs: Sequence[int]):
        """Load and stack the elements at ``idxs`` along a new first axis.

        Both stacked arrays are cast to ``K.floatx()``.

        TODO: run test to determine if casting inputs/outputs is required.
        """
        dataset_dicts = self._dataset_dicts

        images = []
        masks = []
        for file_idx in idxs:
            image, mask = self._load_input(dataset_dicts[file_idx])
            images.append(image)
            masks.append(mask)

        return (
            np.stack(images, axis=0).astype(K.floatx()),
            np.stack(masks, axis=0).astype(K.floatx()),
        )

    def _preprocess(self, inputs: np.ndarray, outputs: np.ndarray):
        """Apply the configured transforms to inputs, then the same
        transforms (as segmentation) to outputs."""
        img, transforms = apply_transform_gens(self._transform_gen, inputs)
        outputs = transforms.apply_segmentation(outputs)
        return img, outputs

    def __getitem__(self, idx):
        """
        Args:
            idx: Batch index.

        Returns:
            ndarray, ndarray: images NxHxWx(...)x1, masks NxHxWx(...)x1.
            In test mode the raw (unprocessed) inputs are returned as a third
            element.
        """
        batch_size = self._batch_size
        start = idx * batch_size
        stop = min((idx + 1) * batch_size, self._num_elements())

        inputs, outputs = self._load_batch(self._idxs[start:stop])
        inputs_preprocessed, outputs = self._preprocess(inputs, outputs)

        if self._is_test:
            return inputs_preprocessed, outputs, inputs
        return inputs_preprocessed, outputs

    def _restructure_data(self, vols: Sequence[np.ndarray]):
        """By default the batch dimension is moved to be the third dimension.

        After the transpose every volume is passed through ``np.squeeze``,
        which removes all singleton axes.

        TODO: Change signature to specify if it is a segmentation volume or
        image volume. Downstream data loaders need to distinguish between the
        two (i.e. 2.5D networks).

        Args:
            vols (ndarrays): Shapes of NxHxWx...

        Returns:
            vols (ndarrays): Shapes of HxWxNx...
        """
        restructured = []
        for v in vols:
            # (N, H, W, rest...) -> (H, W, N, rest...); the range is empty
            # when there are no trailing axes.
            axes = (1, 2, 0) + tuple(range(3, v.ndim))
            restructured.append(np.squeeze(v.transpose(axes)))
        return tuple(restructured)

    def inference(self, model: Model, **kwargs):
        """Run the model scan by scan and yield per-scan inputs and outputs.

        While a scan is being processed, ``self._dataset_dicts`` is
        temporarily replaced by the dataset dicts of that scan. The full list
        is restored when the generator finishes, is closed early, or raises.

        Args:
            model: The model to run. If it is not a
                :class:`medsegpy.modeling.Model`, TF2 is required and
                ``Model.inference_generator_static`` is used.
            kwargs: Forwarded to the inference generator. "workers" defaults
                to ``cfg.NUM_WORKERS`` and "use_multiprocessing" defaults to
                ``workers > 1``.

        Yields:
            dict, dict: ``{"x", "scan_id", <"scan*" keys of the first dataset
            dict of the scan>}`` and
            ``{"y_pred", "y_true", "time_elapsed"}``.

        Raises:
            ValueError: If ``model`` is not a medsegpy ``Model`` and TF2 is
                not in use.
        """
        scan_to_dict_mapping = defaultdict(list)
        for d in self._dataset_dicts:
            scan_to_dict_mapping[d["scan_id"]].append(d)
        scan_ids = sorted(scan_to_dict_mapping.keys())

        all_dataset_dicts = self._dataset_dicts

        workers = kwargs.pop("workers", self._cfg.NUM_WORKERS)
        use_multiprocessing = kwargs.pop("use_multiprocessing", workers > 1)

        # NOTE(review): `__getitem__` indexes through `self._idxs`, which is
        # not rebuilt for the per-scan list; this presumably relies on the
        # loader being built with shuffle=False for inference - confirm.
        try:
            for scan_id in scan_ids:
                self._dataset_dicts = scan_to_dict_mapping[scan_id]

                start = time.perf_counter()
                if not isinstance(model, Model):
                    if not env.is_tf2():
                        raise ValueError(
                            "model must be a medsegpy.modeling.model.Model for TF1.0"
                        )
                    x, y, preds = Model.inference_generator_static(
                        model,
                        self,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        **kwargs
                    )
                else:
                    x, y, preds = model.inference_generator(
                        self, workers=workers, use_multiprocessing=use_multiprocessing, **kwargs
                    )
                time_elapsed = time.perf_counter() - start

                x, y, preds = self._restructure_data((x, y, preds))

                scan_inputs = {"x": x, "scan_id": scan_id}
                # Forward every scan-level parameter of the first dataset dict.
                scan_params = {
                    k: v
                    for k, v in self._dataset_dicts[0].items()
                    if isinstance(k, str) and k.startswith("scan")
                }
                scan_inputs.update(scan_params)
                scan_outputs = {"y_pred": preds, "y_true": y, "time_elapsed": time_elapsed}

                yield scan_inputs, scan_outputs
        finally:
            # Always restore the complete list, even if the consumer stops
            # iterating early or an error occurs mid-scan.
            self._dataset_dicts = all_dataset_dicts
_SUPPORTED_PADDING_MODES = ("constant", "edge", "reflect", "symmetric", "warp", "empty")
@DATA_LOADER_REGISTRY.register()
class PatchDataLoader(DefaultDataLoader):
    """
    This data loader pre-computes patch locations and padding based on
    patch size (`cfg.IMG_SIZE`), pad type (`cfg.IMG_PAD_MODE`), pad size
    (`cfg.IMG_PAD_SIZE`), and stride (`cfg.IMG_STRIDE`) parameters specified
    in the config.

    Assumptions:
        * all dataset dictionaries have the same image dimensions
        * "image_size" in dataset dict
    """

    def __init__(
        self,
        cfg: Config,
        dataset_dicts: List[Dict],
        is_test: bool = False,
        shuffle: bool = True,
        drop_last: bool = False,
        batch_size: int = 1,
        use_singlefile: bool = False,
    ):
        """
        Args:
            cfg (Config): A config object.
            dataset_dicts (List[Dict]): Base (un-patched) dataset dicts. Each
                is expanded into one dict per patch carrying the extra keys
                "_patch" and "_pad".
            is_test, shuffle, drop_last, batch_size: See :class:`DataLoader`.
            use_singlefile (bool, optional): If `True`, all data is read from
                the single HDF5 file at ``dataset_dicts[0]["singlefile_path"]``,
                indexed by each dict's "file_name".

        Raises:
            ValueError: If ``cfg.IMG_PAD_MODE`` is not supported, or if
                ``cfg.PRELOAD_DATA`` is set and this is not the main thread.
        """
        # Create patch elements from dataset dict.
        # TODO: change pad/patching based on test/train
        self._use_singlefile = use_singlefile

        expected_img_dim = len(dataset_dicts[0]["image_size"])
        img_dim = len(cfg.IMG_SIZE)
        self._add_dim = False
        if img_dim > expected_img_dim:
            # cfg.IMG_SIZE carries one extra trailing axis that the data
            # lacks; it is appended to each image in `_load_input`.
            assert img_dim - expected_img_dim == 1
            patch_size = cfg.IMG_SIZE[:-1]
            self._add_dim = True
        elif img_dim == expected_img_dim:
            patch_size = cfg.IMG_SIZE
        else:
            # Fewer configured dims than data dims: use size 1 for the rest.
            extra_dims = (1,) * (expected_img_dim - img_dim)
            patch_size = tuple(cfg.IMG_SIZE) + extra_dims
        self._patch_size = patch_size

        self._pad_mode = cfg.IMG_PAD_MODE
        if self._pad_mode not in _SUPPORTED_PADDING_MODES:
            raise ValueError("pad mode {} not supported".format(cfg.IMG_PAD_MODE))
        pad_size = cfg.IMG_PAD_SIZE if cfg.IMG_PAD_SIZE else None
        stride = cfg.IMG_STRIDE if cfg.IMG_STRIDE else (1,) * len(patch_size)

        dd_patched = []
        for dd in dataset_dicts:
            patches = compute_patches(dd["image_size"], self._patch_size, pad_size, stride)
            if len(patches) == 0:
                logger.warning(f"Dropping {dd['scan_id']} - no patches found.")
            for patch, pad in patches:
                dataset_dict = dd.copy()
                dataset_dict.update({"_patch": patch, "_pad": pad})
                dd_patched.append(dataset_dict)

        super().__init__(cfg, dd_patched, is_test, shuffle, drop_last, batch_size)

        self._preload_data = cfg.PRELOAD_DATA
        self._cached_data = None
        # Handle to the single HDF5 file; opened lazily in `__getitem__`.
        self._f = None
        if self._use_singlefile:
            self._singlefile_fp = dataset_dicts[0]["singlefile_path"]

        if self._preload_data:
            if threading.current_thread() is not threading.main_thread():
                raise ValueError("Data pre-loading can only be done on the main thread.")
            logger.info("Pre-loading data...")
            self._cached_data = self._load_all_data(dataset_dicts, cfg.NUM_WORKERS)

    def __del__(self):
        # `_f` may not exist if `__init__` failed before assigning it.
        if hasattr(self, "_f") and self._f is not None:
            self._f.close()

    def __getitem__(self, idx):
        """
        Args:
            idx: Batch index.

        Returns:
            ndarray, ndarray: images NxHxWx(...)x1, masks NxHxWx(...)x1.
            In test mode the raw (unprocessed) inputs are returned as a third
            element.
        """
        # Open the shared file on first access rather than in `__init__`.
        if self._use_singlefile and self._f is None:
            self._f = h5py.File(self._singlefile_fp, "r")

        return super().__getitem__(idx)

    def _load_all_data(self, dataset_dicts, num_workers: int = 1) -> Dict:
        """Load the full image/mask of every base dataset dict into memory.

        We assume that the tuple `("file_name", "sem_seg_file")`
        is sufficient for determining the uniqueness of each base dataset
        dictionary.

        Args:
            dataset_dicts: Base (un-patched) dataset dicts.
            num_workers (int, optional): Currently unused; loading is serial.

        Returns:
            Dict mapping ``(file_name, sem_seg_file)`` to
            ``{"image": ..., "mask": ...}``. Masks whose values are exactly
            {0, 1} are stored as booleans.
        """

        def _load(dataset_dict):
            image, mask = self._load_patch(dataset_dict, skip_patch=True)
            if mask is not None and set(np.unique(mask)) == {0, 1}:
                mask = mask.astype(np.bool_)
            return {"image": image, "mask": mask}

        # In single-file mode the handle is normally opened lazily by
        # `__getitem__`; open a temporary one here so preloading can read.
        opened_here = self._use_singlefile and self._f is None
        if opened_here:
            self._f = h5py.File(self._singlefile_fp, "r")
        try:
            cache = {
                (dd["file_name"], dd.get("sem_seg_file", None)): _load(dd)
                for dd in tqdm(dataset_dicts)
            }
        finally:
            if opened_here:
                self._f.close()
                self._f = None
        return cache

    def _load_patch(self, dataset_dict, skip_patch: bool = False, img_key=None, seg_key=None):
        """Read one image patch (and mask patch, if available) from disk.

        Args:
            dataset_dict (Dict): Must contain "file_name" and, unless
                ``skip_patch`` is set, "_patch". May contain "sem_seg_file".
            skip_patch (bool, optional): If `True`, read the full arrays.
            img_key (str, optional): HDF5 key of the image. Defaults to
                "volume" if image and segmentation share a file, else "data".
            seg_key (str, optional): HDF5 key of the segmentation. Defaults
                to "seg" if image and segmentation share a file, else "data".

        Returns:
            Tuple of ``(image, mask)``; ``mask`` is ``None`` when there is no
            segmentation file.
        """
        image_file = dataset_dict["file_name"]
        sem_seg_file = dataset_dict.get("sem_seg_file", None)
        patch = Ellipsis if skip_patch else dataset_dict["_patch"]

        mask = None
        is_img_seg_file_same = image_file == sem_seg_file
        if seg_key is None:
            seg_key = "seg" if is_img_seg_file_same else "data"
        if img_key is None:
            img_key = "volume" if is_img_seg_file_same else "data"

        if self._use_singlefile:
            # Everything lives in one HDF5 file, indexed by "file_name".
            group = self._f[image_file]
            image = group[img_key][patch]
            if sem_seg_file:
                mask = group[seg_key][patch]
        else:
            with h5py.File(image_file, "r") as f:
                image = f[img_key][patch]  # HxWxDx...
                if sem_seg_file and is_img_seg_file_same:
                    mask = f[seg_key][patch]  # HxWxDx...xC
            if sem_seg_file and not is_img_seg_file_same:
                with h5py.File(sem_seg_file, "r") as s:
                    mask = s[seg_key][patch]

        if mask is not None:
            mask = collect_mask(mask, index=self._category_idxs)
            if self._include_background:
                mask = add_background_labels(mask)

        return image, mask

    def _load_input(self, dataset_dict):
        """Load, pad and (optionally) expand one patch.

        The patch comes from the preloaded cache when available, otherwise
        from disk. The mask is padded with an extra ``(0, 0)`` entry so its
        last axis is left untouched.
        """
        if self._cached_data is not None:
            patch = dataset_dict["_patch"]
            image_file = dataset_dict["file_name"]
            sem_seg_file = dataset_dict.get("sem_seg_file", None)
            data = self._cached_data[(image_file, sem_seg_file)]
            image = data["image"][patch]
            mask = data["mask"]
            if mask is not None:
                mask = mask[patch]
        else:
            image, mask = self._load_patch(dataset_dict)

        pad = dataset_dict["_pad"]
        if pad is not None:
            image = np.pad(image, pad, self._pad_mode)
            if mask is not None:
                mask = np.pad(mask, tuple(pad) + ((0, 0),), self._pad_mode)

        if self._add_dim:
            image = image[..., np.newaxis]

        return image, mask

    def _restructure_data(self, vols_patched: Sequence[np.ndarray]):
        """Reassemble per-patch arrays into full-size volumes.

        This method assumes that `self._dataset_dicts` is limited to dataset
        dictionaries for only one scan. It also assumes that the patches in
        `vols_patched` are ordered like the dataset dictionaries.

        Args:
            vols_patched (ndarrays): Each has shape of NxP1xP2x...

        Returns:
            vols (ndarrays): Shapes of HxWxDx...
        """
        assert self._is_test

        image_size = self._dataset_dicts[0]["image_size"]
        coords = [dd["_patch"] for dd in self._dataset_dicts]
        num_patches = vols_patched[0].shape[0]
        assert len(coords) == num_patches, "{} patches, {} coords".format(num_patches, len(coords))

        # TODO: fix in case that v.shape[-1] is not actually a channel dimension
        new_vols = [np.zeros(tuple(image_size) + (v.shape[-1],)) for v in vols_patched]
        for idx, c in enumerate(coords):
            for vol, patched in zip(new_vols, vols_patched):
                # Hacky solution to handle extra axis dimension, if exists.
                x = patched[idx]
                if x.ndim == vol[c].ndim - 1:
                    x = x[..., np.newaxis, :]
                vol[c] = x

        return tuple(new_vols)
@DATA_LOADER_REGISTRY.register()
class N5dDataLoader(PatchDataLoader):
    """n.5D data loader.

    Use this for 2.5D, 3.5D, etc. implementations.

    Currently only last dimension is supported as the channel dimension.
    """

    def __init__(
        self,
        cfg: Config,
        dataset_dicts: List[Dict],
        is_test: bool = False,
        shuffle: bool = True,
        drop_last: bool = False,
        batch_size: int = 1,
    ):
        """See :class:`DataLoader` for the arguments.

        Raises:
            ValueError: If ``cfg.IMG_SIZE`` does not have as many dimensions
                as ``dataset_dicts[0]["image_size"]``, or if the last entry
                of ``cfg.IMG_SIZE`` is not odd.
        """
        expected_img_dim = len(dataset_dicts[0]["image_size"])
        img_dim = len(cfg.IMG_SIZE)
        if img_dim != expected_img_dim:
            raise ValueError(
                "Data has {} dimensions. cfg.IMG_SIZE is {} dimensions".format(
                    expected_img_dim, img_dim
                )
            )

        # An odd size is required so that a single center index exists.
        if cfg.IMG_SIZE[-1] % 2 != 1:
            raise ValueError("channel dimension must be odd")

        super().__init__(cfg, dataset_dicts, is_test, shuffle, drop_last, batch_size)

    def _load_input(self, dataset_dict):
        """Load a patch and keep only the center index of the mask's
        second-to-last axis.

        The image is returned unchanged. A missing mask (``None``) is passed
        through as-is.
        """
        image, mask = super()._load_input(dataset_dict)
        if mask is not None:
            dim = mask.shape[-2]
            mask = mask[..., dim // 2, :]
        return image, mask
@DATA_LOADER_REGISTRY.register()
class S25dDataLoader(DefaultDataLoader):
    """Special case of 2.5D data loader compatible with 2D MedSegPy data format.

    Every dataset dict describes one slice and must provide two extra keys:

    - "slice_id" (int): 1-indexed id of the slice the dict corresponds to.
    - "scan_num_slices" (int): Total number of slices in the scan the dict
      is derived from.

    Padding is applied automatically so that slices near either end of a scan
    still receive a full window of neighbors.

    This is a temporary solution until the slow loading speeds of the
    :class:`N5dDataLoader` are properly debugged.
    """

    def __init__(
        self,
        cfg: Config,
        dataset_dicts: List[Dict],
        is_test: bool = False,
        shuffle: bool = True,
        drop_last: bool = False,
        batch_size: int = 1,
    ):
        """See :class:`DataLoader` for the arguments.

        ``cfg.IMG_SIZE`` must have three entries, the last of which (the
        window size) must be odd.
        """
        img_size = cfg.IMG_SIZE
        self._window = img_size[-1]
        assert len(img_size) == 3
        assert self._window % 2 == 1
        self._pad_mode = cfg.IMG_PAD_MODE

        # Group the dataset dicts by scan, ordered by slice within each scan.
        # TODO: remove copying dictionaries if runtime speed issues found
        by_scan = defaultdict(list)
        for dd in sorted(dataset_dicts, key=lambda d: (d["scan_id"], d["slice_id"])):
            by_scan[dd["scan_id"]].append(dd)

        for scan_id, scan_dicts in by_scan.items():
            slice_ids = [dd["slice_id"] for dd in scan_dicts]
            assert slice_ids == sorted(
                slice_ids
            ), "Error in sorting dataset dictionaries for scan {} by slice id".format(scan_id)
        self._scan_to_dicts = by_scan

        super().__init__(cfg, dataset_dicts, is_test, shuffle, drop_last, batch_size)

    def _load_input(self, dataset_dict):
        """Load the window of slices around ``dataset_dict`` plus its mask.

        Neighboring slices are stacked along a new last axis and padded where
        the window extends past either end of the scan. The mask is loaded
        for the center slice only and is ``None`` if the dict has no
        "sem_seg_file".
        """
        center_id = dataset_dict["slice_id"]  # 1-indexed
        scan_id = dataset_dict["scan_id"]
        num_slices = dataset_dict["scan_num_slices"]
        half_window = self._window // 2

        # Number of slices missing on each side of the window.
        slices_after = num_slices - center_id
        right_pad = half_window - slices_after if slices_after < half_window else 0
        left_pad = half_window - center_id + 1 if center_id - half_window <= 0 else 0
        if left_pad or right_pad:
            pad = ((0, 0), (0, 0), (left_pad, right_pad))
        else:
            pad = None

        # Zero-indexed, clipped range of the slices that actually exist.
        center_idx = center_id - 1
        first = max(0, center_idx - half_window)
        last = min(num_slices, center_idx + 1 + half_window)
        neighbors = self._scan_to_dicts[scan_id][first:last]

        def _read_slice(path):
            with h5py.File(path, "r") as f:
                arr = f["data"][:]
            return np.squeeze(arr) if arr.shape[-1] == 1 else arr

        image = np.stack([_read_slice(dd["file_name"]) for dd in neighbors], axis=-1)
        if pad is not None:
            image = np.pad(image, pad, self._pad_mode)

        # Segmentation is only needed for the center slice.
        sem_seg_file = dataset_dict.get("sem_seg_file", None)
        if not sem_seg_file:
            return image, None

        with h5py.File(sem_seg_file, "r") as f:
            mask = f["data"][:]
        mask = collect_mask(mask, index=self._category_idxs)
        if self._include_background:
            mask = add_background_labels(mask)

        return image, mask

    def _restructure_data(self, vols: Sequence[np.ndarray]):
        """Keep the center slice of the inputs, then defer to the parent.

        Only the first volume (the inputs) is reduced to its center channel;
        it is then expected to be three-dimensional (NxHxW). All three
        volumes are handed to ``DefaultDataLoader._restructure_data``.

        Args:
            vols: Triple ``(x, y, preds)``.

        Returns:
            The restructured ``(x, y, preds)`` tuple.
        """
        x, y, preds = vols
        center_channel = self._window // 2
        x = x[..., center_channel]
        assert x.ndim == 3, "NxHxW"
        return super()._restructure_data((x, y, preds))