Source code for medsegpy.data.data_loader

import logging
import math
import random
import threading
import time
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, List, Sequence

import h5py
import keras.backend as K
import numpy as np
from fvcore.common.registry import Registry
from keras import utils as k_utils
from tqdm import tqdm

from medsegpy.config import Config
from medsegpy.modeling import Model
from medsegpy.utils import env

from .data_utils import add_background_labels, collect_mask, compute_patches
from .transforms import apply_transform_gens, build_preprocessing

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
logger.addHandler(sh)

DATA_LOADER_REGISTRY = Registry("DATA_LOADER")
DATA_LOADER_REGISTRY.__doc__ = """
Registry for data loaders, which can be used with `model.fit_generator()` and
`model.predict_generator()`. The data loader type should be registered with
cfg, dataset_dicts, and other extra parameters.

The registered object will be called with
`obj(cfg, dataset_dicts, **kwargs)`.
The call should return a :class:`DataLoader` object.
"""

_LEGACY_DATA_LOADER_MAP = {("oai_aug", "oai", "oai_2d", "oai_aug_2d"): "DefaultDataLoader"}

LEGACY_DATA_LOADER_NAMES = {x: v for k, v in _LEGACY_DATA_LOADER_MAP.items() for x in k}


def build_data_loader(cfg: Config, dataset_dicts: List[Dict], **kwargs) -> "DataLoader":
    """Build the data loader registered under the config ``TAG`` value."""
    name = cfg.TAG
    try:
        data_loader_cls = DATA_LOADER_REGISTRY.get(name)
    except KeyError:
        prev_name = name
        if name in LEGACY_DATA_LOADER_NAMES:
            name = LEGACY_DATA_LOADER_NAMES[name]
        if prev_name != name:
            warnings.warn("TAG {} is deprecated. Use {} instead".format(prev_name, name))
        data_loader_cls = DATA_LOADER_REGISTRY.get(name)

    return data_loader_cls(cfg, dataset_dicts, **kwargs)
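

# Illustrative sketch (not used anywhere in medsegpy): building a loader from
# a config and a list of dataset dicts. The dict keys shown ("file_name",
# "sem_seg_file", "scan_id") are the ones read by `DefaultDataLoader`; the
# file paths are hypothetical.
def _example_build_loader(cfg: Config) -> "DataLoader":
    dataset_dicts = [
        {
            "file_name": "/path/to/slice_0001.h5",  # image stored under the "data" key
            "sem_seg_file": "/path/to/slice_0001_seg.h5",  # mask stored under the "data" key
            "scan_id": "scan_0001",
        }
    ]
    # `cfg.TAG` selects the registered loader, e.g. "DefaultDataLoader".
    return build_data_loader(cfg, dataset_dicts, batch_size=8, shuffle=True, drop_last=False)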


class DataLoader(k_utils.Sequence, ABC):
    """Data loader following :class:`keras.utils.Sequence` API.

    Data loaders load data per batch in the following way:

    1. Collate inputs and outputs
    2. Optionally apply preprocessing

    To avoid changing the order of the base list, we shuffle a list of
    indices and query based on the index.

    Data loaders in medsegpy also have the ability to yield inference results
    per scan (see :meth:`inference`).
    """

    ALIASES = []

    def __init__(
        self,
        cfg: Config,
        dataset_dicts: List[Dict],
        is_test: bool = False,
        shuffle: bool = True,
        drop_last: bool = True,
        batch_size: int = 1,
    ):
        """
        Args:
            cfg (Config): A config object.
            dataset_dicts (List[Dict]): List of data in medsegpy dataset
                format.
            is_test (:obj:`bool`, optional): If `True`, configures loader as a
                testing/inference loader. This is typically used when running
                evaluation.
            shuffle (:obj:`bool`, optional): If `True`, shuffle data every
                epoch.
            drop_last (:obj:`bool`, optional): Drop the last incomplete batch,
                if the dataset size is not divisible by the batch size. If
                `False` and the size of the dataset is not divisible by the
                batch size, then the last batch will be smaller. This can
                affect loss calculations.
            batch_size (:obj:`int`, optional): Batch size.
        """
        self._cfg = cfg
        self._dataset_dicts = dataset_dicts
        self.shuffle = shuffle
        seed = cfg.SEED
        self._random = random.Random(seed) if seed else random.Random()
        self.drop_last = drop_last
        self._batch_size = batch_size
        self._category_idxs = cfg.CATEGORIES
        self._is_test = is_test

        self._idxs = list(range(0, self._num_elements()))
        if shuffle:
            self._random.shuffle(self._idxs)

    def on_epoch_end(self):
        if self.shuffle:
            self._random.shuffle(self._idxs)

    def __len__(self):
        """Number of batches.

        By default, each element in the dataset dict is independent.
        """
        _num_elements = self._num_elements()
        if not self._is_test and self.drop_last:
            return _num_elements // self._batch_size
        else:
            return math.ceil(_num_elements / self._batch_size)

    def _num_elements(self):
        """Number of elements in the data loader."""
        return len(self._dataset_dicts)

    def num_scans(self):
        return len({x["scan_id"] for x in self._dataset_dicts})

    @abstractmethod
    def inference(self, model, **kwargs):
        """Yields dictionaries of inputs and outputs per scan.

        In medical settings, data is often processed per scan, not necessarily
        per example. This distinction is critical. For example, a 2D
        segmentation network may take 2D slices of a scan as input. However,
        during inference, it is standard to compute metrics on the full scan,
        not on individual slices.

        This method does the following:

        1. Loads dataset dicts corresponding to a scan
        2. Structures data from these dicts
        3. Runs predictions on the structured data
        4. Restructures inputs. Images/volumes are restructured to HxWx...
           Segmentation masks and predictions are restructured to HxWx...xC.
        5. Yields input, output dictionaries for the scan.

        Yielding continues until all scans have been processed.

        This method should yield scan-specific inputs and outputs as
        dictionaries. At minimum, the following keys should be in the `input`
        and `output` dictionaries for each scan.

        Input keys:
            * "scan_id" (str): the scan identifier.
            * "x" (ndarray): the raw (unprocessed) input. Shape HxWx...
              If the network takes multiple inputs, each input should
              correspond to a unique key that will be handled by your
              specified evaluator.
            * "scan_XXX" (optional): scan-related parameters that will
              simplify evaluation, e.g. "scan_spacing". MedSegPy evaluators
              will default to scan-specific information, if provided. For
              example, if "scan_spacing" is specified, its value will override
              the default spacing for the dataset.
            * "subject_id" (optional): the subject identifier for the scan.
              Useful for grouping results by subject.

        Output keys:
            * "time_elapsed" (required): Amount of time required for inference
              on the scan. This quantity typically includes data loading time
              as well.
            * "y_true" (ndarray): Ground truth binary mask for semantic
              segmentation. Shape HxWx...xC. Required for semantic
              segmentation inference.
            * "y_pred" (ndarray): Prediction probabilities for semantic
              segmentation. Shape HxWx...xC. Required for semantic
              segmentation inference.

        All output keys except "time_elapsed" are optional and task specific.

        Args:
            model: A model to run inference on.
            kwargs: Keyword arguments to `model.predict_generator()`.

        Yields:
            dict, dict: Dictionaries of inputs and outputs corresponding to a
                single scan.
        """
        yield {}, {}
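

# Illustrative sketch (assumes `model` is a trained medsegpy model and
# `loader` was built with `is_test=True`): consuming the per-scan
# dictionaries yielded by `DataLoader.inference` in a simple evaluation loop.
def _example_consume_inference(loader: "DataLoader", model):
    results = {}
    for inputs, outputs in loader.inference(model):
        scan_id = inputs["scan_id"]
        results[scan_id] = {
            "y_true": outputs.get("y_true"),  # HxWx...xC ground truth, if provided
            "y_pred": outputs.get("y_pred"),  # HxWx...xC probabilities, if provided
            "time_elapsed": outputs["time_elapsed"],
        }
    return results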


@DATA_LOADER_REGISTRY.register()
class DefaultDataLoader(DataLoader):
    """The default data loader in medsegpy.

    This class takes a dataset dict in the MedSegPy 2D dataset format and maps
    it to a format that can be used by the model for semantic segmentation:

    1. Read the input matrix from "file_name"
    2. Read the ground truth mask matrix from "sem_seg_file"
    3. If needed:
        a. Add binary labels for background
    4. Apply :class:`MedTransform` transforms to inputs and masks.
    5. If training, return input (preprocessed), output.
       If testing, return input (preprocessed), output, input (raw).

    The testing structure is useful for tracking the original input without
    any preprocessing. This return structure does not conflict with existing
    Keras model functionality.
    """

    def __init__(
        self,
        cfg: Config,
        dataset_dicts: List[Dict],
        is_test: bool = False,
        shuffle: bool = True,
        drop_last: bool = False,
        batch_size: int = 1,
    ):
        super().__init__(cfg, dataset_dicts, is_test, shuffle, drop_last, batch_size)

        self._include_background = cfg.INCLUDE_BACKGROUND
        self._num_classes = cfg.get_num_classes()
        self._transform_gen = build_preprocessing(cfg)
        self._cached_data = None

    def _load_input(self, dataset_dict):
        image_file = dataset_dict["file_name"]
        sem_seg_file = dataset_dict.get("sem_seg_file", None)

        if self._cached_data is not None:
            image, mask = self._cached_data[(image_file, sem_seg_file)]
        else:
            with h5py.File(image_file, "r") as f:
                image = f["data"][:]
            if image.shape[-1] != 1:
                image = image[..., np.newaxis]

            if sem_seg_file:
                with h5py.File(sem_seg_file, "r") as f:
                    mask = f["data"][:]
                cat_idxs = self._category_idxs
                mask = collect_mask(mask, index=cat_idxs)
                if self._include_background:
                    mask = add_background_labels(mask)
            else:
                mask = None

        return image, mask

    def _load_batch(self, idxs: Sequence[int]):
        """
        TODO: run test to determine if casting inputs/outputs is required.
        """
        dataset_dicts = self._dataset_dicts

        images = []
        masks = []
        for file_idx in idxs:
            dataset_dict = dataset_dicts[file_idx]
            image, mask = self._load_input(dataset_dict)
            images.append(image)
            masks.append(mask)

        return (
            np.stack(images, axis=0).astype(K.floatx()),
            np.stack(masks, axis=0).astype(K.floatx()),
        )

    def _preprocess(self, inputs: np.ndarray, outputs: np.ndarray):
        img, transforms = apply_transform_gens(self._transform_gen, inputs)
        outputs = transforms.apply_segmentation(outputs)
        return img, outputs

    def __getitem__(self, idx):
        """
        Args:
            idx: Batch index.

        Returns:
            ndarray, ndarray: images NxHxWx(...)x1, masks NxHxWx(...)x1
        """
        batch_size = self._batch_size
        start = idx * batch_size
        stop = min((idx + 1) * batch_size, self._num_elements())

        inputs, outputs = self._load_batch(self._idxs[start:stop])
        inputs_preprocessed, outputs = self._preprocess(inputs, outputs)

        if self._is_test:
            return inputs_preprocessed, outputs, inputs
        else:
            return inputs_preprocessed, outputs

    def _restructure_data(self, vols: Sequence[np.ndarray]):
        """By default, the batch dimension is moved to be the third dimension.

        TODO: Change signature to specify if it is a segmentation volume or an
        image volume. Downstream data loaders need to distinguish between the
        two (i.e. 2.5D networks).

        Args:
            vols (ndarrays): Shapes of NxHxWx...

        Returns:
            vols (ndarrays): Shapes of HxWxNx...
        """
        new_vols = []
        for v in vols:
            axes = (1, 2, 0)
            if v.ndim > 3:
                axes = axes + tuple(i for i in range(3, v.ndim))
            new_vols.append(v.transpose(axes))
        vols = (np.squeeze(v) for v in new_vols)
        return tuple(vols)

    def inference(self, model: Model, **kwargs):
        scan_to_dict_mapping = defaultdict(list)
        for d in self._dataset_dicts:
            scan_to_dict_mapping[d["scan_id"]].append(d)

        scan_ids = sorted(scan_to_dict_mapping.keys())
        dataset_dicts = self._dataset_dicts

        workers = kwargs.pop("workers", self._cfg.NUM_WORKERS)
        use_multiprocessing = kwargs.pop("use_multiprocessing", workers > 1)

        for scan_id in scan_ids:
            self._dataset_dicts = scan_to_dict_mapping[scan_id]

            start = time.perf_counter()
            if not isinstance(model, Model):
                if not env.is_tf2():
                    raise ValueError("model must be a medsegpy.modeling.model.Model for TF1.0")
                x, y, preds = Model.inference_generator_static(
                    model, self, workers=workers, use_multiprocessing=use_multiprocessing, **kwargs
                )
            else:
                x, y, preds = model.inference_generator(
                    self, workers=workers, use_multiprocessing=use_multiprocessing, **kwargs
                )
            time_elapsed = time.perf_counter() - start

            x, y, preds = self._restructure_data((x, y, preds))

            input = {"x": x, "scan_id": scan_id}
            scan_params = {
                k: v
                for k, v in self._dataset_dicts[0].items()
                if isinstance(k, str) and k.startswith("scan")
            }
            input.update(scan_params)
            output = {"y_pred": preds, "y_true": y, "time_elapsed": time_elapsed}

            yield input, output

        self._dataset_dicts = dataset_dicts


_SUPPORTED_PADDING_MODES = ("constant", "edge", "reflect", "symmetric", "wrap", "empty")
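

# Illustrative sketch of the padding behavior that `PatchDataLoader` relies
# on: `cfg.IMG_PAD_MODE` must be one of `_SUPPORTED_PADDING_MODES` and is
# passed directly to `np.pad`. The array and pad widths below are arbitrary.
def _example_pad_mode():
    image = np.arange(16, dtype=np.float32).reshape(4, 4)
    pad = ((2, 2), (2, 2))  # (before, after) padding per spatial dimension
    assert "reflect" in _SUPPORTED_PADDING_MODES
    return np.pad(image, pad, "reflect")  # 8x8 padded array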


@DATA_LOADER_REGISTRY.register()
class PatchDataLoader(DefaultDataLoader):
    """Data loader that pre-computes patch locations and padding.

    Patches are computed based on the patch size (`cfg.IMG_SIZE`), pad type
    (`cfg.IMG_PAD_MODE`), pad size (`cfg.IMG_PAD_SIZE`), and stride
    (`cfg.IMG_STRIDE`) parameters specified in the config.

    Assumptions:
        * all dataset dictionaries have the same image dimensions
        * "image_size" is a key in each dataset dict
    """

    def __init__(
        self,
        cfg: Config,
        dataset_dicts: List[Dict],
        is_test: bool = False,
        shuffle: bool = True,
        drop_last: bool = False,
        batch_size: int = 1,
        use_singlefile: bool = False,
    ):
        # Create patch elements from dataset dicts.
        # TODO: change pad/patching based on test/train
        self._use_singlefile = use_singlefile
        expected_img_dim = len(dataset_dicts[0]["image_size"])
        img_dim = len(cfg.IMG_SIZE)
        self._add_dim = False
        if img_dim > expected_img_dim:
            assert img_dim - expected_img_dim == 1
            patch_size = cfg.IMG_SIZE[:-1]
            self._add_dim = True
        elif img_dim == expected_img_dim:
            patch_size = cfg.IMG_SIZE
        else:
            extra_dims = (1,) * (expected_img_dim - img_dim)
            patch_size = tuple(cfg.IMG_SIZE) + extra_dims
        self._patch_size = patch_size

        self._pad_mode = cfg.IMG_PAD_MODE
        if self._pad_mode not in _SUPPORTED_PADDING_MODES:
            raise ValueError("pad mode {} not supported".format(cfg.IMG_PAD_MODE))
        pad_size = cfg.IMG_PAD_SIZE if cfg.IMG_PAD_SIZE else None
        stride = cfg.IMG_STRIDE if cfg.IMG_STRIDE else (1,) * len(patch_size)

        dd_patched = []
        for dd in dataset_dicts:
            patches = compute_patches(dd["image_size"], self._patch_size, pad_size, stride)
            if len(patches) == 0:
                logger.warning(f"Dropping {dd['scan_id']} - no patches found.")
            for patch, pad in patches:
                dataset_dict = dd.copy()
                dataset_dict.update({"_patch": patch, "_pad": pad})
                dd_patched.append(dataset_dict)

        super().__init__(cfg, dd_patched, is_test, shuffle, drop_last, batch_size)

        self._preload_data = cfg.PRELOAD_DATA
        self._cached_data = None
        self._f = None
        if self._use_singlefile:
            self._singlefile_fp = dataset_dicts[0]["singlefile_path"]
        if self._preload_data:
            if threading.current_thread() is not threading.main_thread():
                raise ValueError("Data pre-loading can only be done on the main thread.")
            logger.info("Pre-loading data...")
            self._cached_data = self._load_all_data(dataset_dicts, cfg.NUM_WORKERS)

    def __del__(self):
        if hasattr(self, "_f") and self._f is not None:
            self._f.close()

    def __getitem__(self, idx):
        """
        Args:
            idx: Batch index.

        Returns:
            ndarray, ndarray: images NxHxWx(...)x1, masks NxHxWx(...)x1
        """
        if self._use_singlefile and self._f is None:
            self._f = h5py.File(self._singlefile_fp, "r")

        batch_size = self._batch_size
        start = idx * batch_size
        stop = min((idx + 1) * batch_size, self._num_elements())

        inputs, outputs = self._load_batch(self._idxs[start:stop])
        inputs_preprocessed, outputs = self._preprocess(inputs, outputs)

        if self._is_test:
            return inputs_preprocessed, outputs, inputs
        else:
            return inputs_preprocessed, outputs

    def _load_all_data(self, dataset_dicts, num_workers: int = 1) -> Dict:
        """
        We assume that the tuple `("file_name", "sem_seg_file")` is sufficient
        for determining the uniqueness of each base dataset dictionary.
        """

        def _load(dataset_dict):
            image, mask = self._load_patch(dataset_dict, skip_patch=True)
            if set(np.unique(mask)) == {0, 1}:
                mask = mask.astype(bool)
            return {"image": image, "mask": mask}

        cache = [_load(dd) for dd in tqdm(dataset_dicts)]
        cache = {(dd["file_name"], dd["sem_seg_file"]): x for dd, x in zip(dataset_dicts, cache)}
        return cache

    def _load_patch(self, dataset_dict, skip_patch: bool = False, img_key=None, seg_key=None):
        image_file = dataset_dict["file_name"]
        sem_seg_file = dataset_dict.get("sem_seg_file", None)
        patch = Ellipsis if skip_patch else dataset_dict["_patch"]
        mask = None

        is_img_seg_file_same = image_file == sem_seg_file
        if seg_key is None:
            seg_key = "seg" if is_img_seg_file_same else "data"
        if img_key is None:
            img_key = "volume" if is_img_seg_file_same else "data"

        # Load data from one h5 file if self._use_singlefile.
        if not self._use_singlefile:
            f = h5py.File(image_file, "r")
            image = f[img_key][patch]  # HxWxDx...
            if sem_seg_file and is_img_seg_file_same:
                mask = f[seg_key][patch]  # HxWxDx...xC
            elif sem_seg_file:
                s = h5py.File(sem_seg_file, "r")
                mask = s[seg_key][patch]
                s.close()
            f.close()
        else:
            image = self._f[image_file][img_key][patch]
            if sem_seg_file:
                mask = self._f[image_file][seg_key][patch]

        if mask is not None:
            cat_idxs = self._category_idxs
            mask = collect_mask(mask, index=cat_idxs)
            if self._include_background:
                mask = add_background_labels(mask)

        return image, mask

    def _load_input(self, dataset_dict):
        if self._cached_data is not None:
            patch = dataset_dict["_patch"]
            image_file = dataset_dict["file_name"]
            sem_seg_file = dataset_dict.get("sem_seg_file", None)
            data = self._cached_data[(image_file, sem_seg_file)]
            image, mask = data["image"], data["mask"]
            image, mask = image[patch], mask[patch]
        else:
            image, mask = self._load_patch(dataset_dict)

        pad = dataset_dict["_pad"]
        if pad is not None:
            image = np.pad(image, pad, self._pad_mode)
            if mask is not None:
                mask = np.pad(mask, tuple(pad) + ((0, 0),), self._pad_mode)

        if self._add_dim:
            image = image[..., np.newaxis]
            # mask = mask[..., np.newaxis, :]

        return image, mask

    def _restructure_data(self, vols_patched: Sequence[np.ndarray]):
        """By default, the batch dimension is moved to be the third dimension.

        This method assumes that `self._dataset_dicts` is limited to dataset
        dictionaries for only one scan. It also assumes that the order of the
        patches in `vols_patched` follows the order of the dataset
        dictionaries.

        Args:
            vols_patched (ndarrays): Each has shape of NxP1xP2x...

        Returns:
            vols (ndarrays): Shapes of HxWxDx...
        """
        assert self._is_test
        image_size = self._dataset_dicts[0]["image_size"]
        coords = [dd["_patch"] for dd in self._dataset_dicts]
        num_patches = vols_patched[0].shape[0]
        assert len(coords) == num_patches, "{} patches, {} coords".format(
            num_patches, len(coords)
        )

        # TODO: fix in case that v.shape[-1] is not actually a channel dimension.
        new_vols = [
            np.zeros(tuple(image_size) + (v.shape[-1],)) for v in vols_patched
        ]  # VxNxHxWx...
        for idx, c in enumerate(coords):
            for vol_id in range(len(new_vols)):
                # Hacky solution to handle extra axis dimension, if it exists.
                x = vols_patched[vol_id][idx]
                if x.ndim == new_vols[vol_id][c].ndim - 1:
                    x = x[..., np.newaxis, :]
                new_vols[vol_id][c] = x

        return tuple(new_vols)
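

# Illustrative sketch (assumes `cfg` already defines IMG_SIZE, IMG_PAD_MODE,
# IMG_PAD_SIZE, and IMG_STRIDE, and that every dataset dict carries an
# "image_size" entry): constructing a patch-based loader for testing.
def _example_build_patch_loader(cfg: Config, dataset_dicts: List[Dict]) -> "PatchDataLoader":
    assert cfg.IMG_PAD_MODE in _SUPPORTED_PADDING_MODES
    assert all("image_size" in dd for dd in dataset_dicts)
    return PatchDataLoader(cfg, dataset_dicts, is_test=True, shuffle=False, batch_size=4)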


@DATA_LOADER_REGISTRY.register()
class N5dDataLoader(PatchDataLoader):
    """n.5D data loader. Use this for 2.5D, 3.5D, etc. implementations.

    Currently, only the last dimension is supported as the channel dimension.
    """

    def __init__(
        self,
        cfg: Config,
        dataset_dicts: List[Dict],
        is_test: bool = False,
        shuffle: bool = True,
        drop_last: bool = False,
        batch_size: int = 1,
    ):
        expected_img_dim = len(dataset_dicts[0]["image_size"])
        img_dim = len(cfg.IMG_SIZE)
        if img_dim != expected_img_dim:
            raise ValueError(
                "Data has {} dimensions. cfg.IMG_SIZE is {} dimensions".format(
                    expected_img_dim, img_dim
                )
            )
        if cfg.IMG_SIZE[-1] % 2 != 1:
            raise ValueError("channel dimension must be odd")

        super().__init__(cfg, dataset_dicts, is_test, shuffle, drop_last, batch_size)

    def _load_input(self, dataset_dict):
        image, mask = super()._load_input(dataset_dict)
        # Keep only the segmentation of the center slice.
        dim = mask.shape[-2]
        mask = mask[..., dim // 2, :]
        return image, mask


@DATA_LOADER_REGISTRY.register()
class S25dDataLoader(DefaultDataLoader):
    """Special case of the 2.5D data loader compatible with the 2D MedSegPy
    data format.

    Each dataset dict should represent a slice and must have the additional
    keys:

    - "slice_id" (int): Slice id (1-indexed) that the dataset dict
      corresponds to.
    - "scan_num_slices" (int): Total number of slices in the scan from which
      the dataset dict is derived.

    Padding is automatically applied to ensure all slices are considered.

    This is a temporary solution until the slow loading speeds of
    :class:`N5dDataLoader` are properly debugged.
    """

    def __init__(
        self,
        cfg: Config,
        dataset_dicts: List[Dict],
        is_test: bool = False,
        shuffle: bool = True,
        drop_last: bool = False,
        batch_size: int = 1,
    ):
        self._window = cfg.IMG_SIZE[-1]
        assert len(cfg.IMG_SIZE) == 3
        assert self._window % 2 == 1
        self._pad_mode = cfg.IMG_PAD_MODE

        # Create a mapping from scan_id to list of dataset dicts in order of
        # slice.
        # TODO: remove copying dictionaries if runtime speed issues found
        mapping = defaultdict(list)
        sorted_dataset_dicts = sorted(dataset_dicts, key=lambda d: (d["scan_id"], d["slice_id"]))
        for dd in sorted_dataset_dicts:
            mapping[dd["scan_id"]].append(dd)
        for scan_id, dds in mapping.items():
            slice_order = [dd["slice_id"] for dd in dds]
            assert (
                sorted(slice_order) == slice_order
            ), "Error in sorting dataset dictionaries for scan {} by slice id".format(scan_id)
        self._scan_to_dicts = mapping

        super().__init__(cfg, dataset_dicts, is_test, shuffle, drop_last, batch_size)

    def _load_input(self, dataset_dict):
        """Find and load the dataset dicts corresponding to
        flanking/neighboring slices.
        """
        slice_id = dataset_dict["slice_id"]  # 1-indexed
        scan_id = dataset_dict["scan_id"]
        total_num_slices = dataset_dict["scan_num_slices"]  # 1-indexed
        num_flank_slices = self._window // 2

        l_pad = r_pad = 0
        if total_num_slices - slice_id < num_flank_slices:
            # Right pad the volume.
            r_pad = num_flank_slices - (total_num_slices - slice_id)
        if slice_id - num_flank_slices <= 0:
            # Left pad the volume.
            l_pad = num_flank_slices - slice_id + 1
        pad = ((0, 0), (0, 0), (l_pad, r_pad)) if l_pad or r_pad else None

        # Load images for neighboring slices.
        idx = slice_id - 1
        start = max(0, idx - num_flank_slices)
        end = min(total_num_slices, idx + 1 + num_flank_slices)
        dataset_dicts = self._scan_to_dicts[scan_id][start:end]

        images = []
        for dd in dataset_dicts:
            image_file = dd["file_name"]
            with h5py.File(image_file, "r") as f:
                image = f["data"][:]
            if image.shape[-1] == 1:
                image = np.squeeze(image)
            images.append(image)
        image = np.stack(images, axis=-1)
        if pad is not None:
            image = np.pad(image, pad, self._pad_mode)

        # Load segmentation only for the center slice.
        sem_seg_file = dataset_dict.get("sem_seg_file", None)
        if sem_seg_file:
            with h5py.File(sem_seg_file, "r") as f:
                mask = f["data"][:]
            cat_idxs = self._category_idxs
            mask = collect_mask(mask, index=cat_idxs)
            if self._include_background:
                mask = add_background_labels(mask)
        else:
            mask = None

        return image, mask

    def _restructure_data(self, vols: Sequence[np.ndarray]):
        """By default, the batch dimension is moved to be the third dimension.

        This method assumes that `self._dataset_dicts` is limited to dataset
        dictionaries for only one scan and that the order of the slices in
        `vols` follows the order of the dataset dictionaries.

        Args:
            vols (ndarrays): Each has shape of NxP1xP2x...

        Returns:
            vols (ndarrays): Shapes of HxWxDx...
        """
        x, y, preds = vols
        # Keep only the center slice of each 2.5D input window.
        x = x[..., self._window // 2]
        assert x.ndim == 3, "NxHxW"
        return super()._restructure_data((x, y, preds))
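

# Illustrative sketch (file paths and ids are hypothetical): the per-slice
# dataset dict expected by `S25dDataLoader`, including the extra "slice_id"
# and "scan_num_slices" keys described in the class docstring.
def _example_s25d_dataset_dict() -> Dict:
    return {
        "file_name": "/path/to/scan_0001/slice_004.h5",
        "sem_seg_file": "/path/to/scan_0001/slice_004_seg.h5",
        "scan_id": "scan_0001",
        "slice_id": 4,  # 1-indexed slice position within the scan
        "scan_num_slices": 72,  # total number of slices in the scan
    }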