Source code for medsegpy.data.build

"""Build dataset dictionaries."""
import itertools
import logging
from typing import Dict, Hashable, List, Sequence, Union

from medsegpy.config import Config

from .catalog import DatasetCatalog
from .data_loader import build_data_loader


def filter_dataset(
    dataset_dicts: List[Dict],
    by: Hashable,
    accepted_elements,
    include_missing: bool = False,
):
    """Filter by common dataset fields.

    Args:
        dataset_dicts (List[Dict]): Data in MedSegPy Dataset format.
        by (Hashable): Field to filter by.
        accepted_elements (Sequence): Acceptable elements.
        include_missing (bool, optional): If `True`, include elements without
            the `by` field in their dictionary representation.

    Returns:
        List[Dict]: Filtered dataset dictionaries.
    """
    num_before = len(dataset_dicts)
    dataset_dicts = [
        x
        for x in dataset_dicts
        if (by in x and x[by] in accepted_elements)
        or (include_missing and by not in x)
    ]
    num_after = len(dataset_dicts)
    logger = logging.getLogger(__name__)
    logger.info(
        "Removed {} elements with filter '{}'. {} elements left.".format(
            num_before - num_after, by, num_after
        )
    )
    return dataset_dicts
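

# A minimal usage sketch (not part of the library): restrict dataset dicts to a
# handful of scans. The dataset name "oai_2d_train" and the field "scan_id" are
# illustrative placeholders; use names actually registered in DatasetCatalog
# and fields actually present in your dataset dicts.
#
#   dicts = DatasetCatalog.get("oai_2d_train")
#   dicts = filter_dataset(
#       dicts, by="scan_id", accepted_elements={"9001104", "9002316"}
#   )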

def get_sem_seg_dataset_dicts(dataset_names: Sequence[str], filter_empty: bool = True):
    """Load and prepare dataset dicts for semantic segmentation.

    Args:
        dataset_names (Sequence[str]): A list of dataset names.
        filter_empty (bool, optional): If `True`, filter out elements without
            the `'sem_seg_file'` field.
    """
    assert len(dataset_names)
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
    for dataset_name, dicts in zip(dataset_names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    logger = logging.getLogger(__name__)
    if filter_empty:
        num_before = len(dataset_dicts)
        dataset_dicts = [x for x in dataset_dicts if "sem_seg_file" in x]
        num_after = len(dataset_dicts)
        logger.info(
            "Removed {} elements without annotations. {} elements left.".format(
                num_before - num_after, num_after
            )
        )
    return dataset_dicts


def build_loader(
    cfg: Config,
    dataset_names: Union[str, Sequence[str]],
    batch_size: int,
    is_test: bool,
    shuffle: bool,
    drop_last: bool,
    **kwargs,
):
    """Build a data loader from the given dataset name(s).

    Dataset dicts are loaded with :func:`get_sem_seg_dataset_dicts` (elements
    without annotations are filtered out). ``batch_size``, ``is_test``,
    ``shuffle``, ``drop_last``, and any additional keyword arguments are
    forwarded to :func:`build_data_loader`.
    """
    if isinstance(dataset_names, str):
        dataset_names = (dataset_names,)
    kwargs["batch_size"] = batch_size
    kwargs["is_test"] = is_test
    kwargs["shuffle"] = shuffle
    kwargs["drop_last"] = drop_last

    dataset_dicts = get_sem_seg_dataset_dicts(dataset_names, filter_empty=True)
    return build_data_loader(cfg, dataset_dicts, **kwargs)
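

# A minimal usage sketch (not part of the library): build a training loader for
# a registered dataset. The config class, dataset name, and argument values are
# assumptions for illustration; the object returned by build_data_loader
# determines how batches are actually consumed.
#
#   from medsegpy.config import UNetConfig  # hypothetical choice of config
#
#   cfg = UNetConfig()
#   train_loader = build_loader(
#       cfg,
#       dataset_names="oai_2d_train",
#       batch_size=16,
#       is_test=False,
#       shuffle=True,
#       drop_last=True,
#   )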