import copy
import logging
import os
import pickle
import warnings
from typing import Tuple
import tensorflow as tf
from keras import callbacks as kc
from keras.utils import plot_model
from medsegpy import config, solver
from medsegpy.data import build_loader, im_gens
from medsegpy.engine.callbacks import LossHistory, WandBLogger, lr_callback
from medsegpy.evaluation import build_evaluator, inference_on_dataset
from medsegpy.losses import build_loss, dice_loss
from medsegpy.modeling.meta_arch import build_model
from medsegpy.utils import dl_utils, env, io_utils
# `tensorflow.distribute` is an optional capability: when it (or
# `MirroredStrategy`) cannot be imported, the trainer falls back to the
# non-distributed code paths guarded by `_SUPPORTS_DISTRIBUTED`.
try:
    from tensorflow.distribute import MirroredStrategy
except ImportError:
    # `ImportError` also covers `ModuleNotFoundError` (its subclass), so both
    # "module missing" and "name missing from module" disable the feature.
    _SUPPORTS_DISTRIBUTED = False
    MirroredStrategy = None
else:
    # Only advertise support once the import has actually succeeded.
    _SUPPORTS_DISTRIBUTED = True

logger = logging.getLogger(__name__)
class DefaultTrainer(object):
    """Default trainer for medical semantic segmentation.

    Constructing the trainer builds the model (optionally loading initial
    weights), writes ``model.png`` and ``model.json`` into ``cfg.OUTPUT_DIR``
    and builds the train/validation data loaders. :meth:`train` then fits the
    model and, when ``cfg.TEST_DATASET`` is set, runs :meth:`test`.
    """

    def __init__(self, cfg: config.Config, run_eagerly=None, strategy=None):
        """
        Args:
            cfg (Config): An experiment config.
            run_eagerly (bool, optional): If `True`, runs eagerly.
                Only available in tensorflow>=2.0.
            strategy (tf.distribute.Strategy, optional): The strategy to use for training.
                Only available if `tf.distribute` package is available (tensorflow>=1.14).
        """
        self._cfg = cfg
        # Populated by `build_callbacks`; read back after training completes.
        self._loss_history = None
        self._default_strategy = (
            tf.distribute.get_strategy() if _SUPPORTS_DISTRIBUTED else dl_utils.NoOpStrategy()
        )

        num_gpus = dl_utils.num_gpus()
        if not env.is_tf2() and run_eagerly is not None:
            warnings.warn("`run_eagerly` can only be specified in Tensorflow >=2.0. Ignoring...")
            run_eagerly = None
        self._run_eagerly = run_eagerly

        # Only pick a strategy automatically when the caller did not supply one.
        if strategy is None:
            strategy = self._default_strategy
            if _SUPPORTS_DISTRIBUTED and num_gpus > 1:
                logger.info("Running multi gpu model")
                strategy = MirroredStrategy()
        self.strategy = strategy

        # The model (and any weight loading/freezing) is created under the
        # strategy scope.
        with self.strategy.scope():
            model = self.build_model(cfg)
            if cfg.INIT_WEIGHTS:
                self._init_model(model)

        # Persist a diagram, a logged summary and the JSON architecture.
        plot_model(model, to_file=os.path.join(cfg.OUTPUT_DIR, "model.png"), show_shapes=True)
        model.summary(line_length=120, print_fn=lambda x: logger.info(x))
        model_json = model.to_json()
        model_json_save_path = os.path.join(cfg.OUTPUT_DIR, "model.json")
        with open(model_json_save_path, "w") as json_file:
            json_file.write(model_json)

        # Replicate model on multiple gpus when tensorflow.distribute module not available.
        # Note this does not solve issue of having too large of a model
        if not _SUPPORTS_DISTRIBUTED and num_gpus > 1:
            logger.info("Running multi gpu model")
            model = dl_utils.ModelMGPU(model, gpus=num_gpus)

        self._train_loader, self._val_loader = self._build_data_loaders(cfg)
        self._model = model

    def train(self):
        """Train model specified by config.

        Do not call this under a strategy scope. Instead, set `self.strategy`.

        Returns:
            The result of :meth:`test` when ``cfg.TEST_DATASET`` is set and the
            active strategy is supported for testing; otherwise an empty dict.
        """
        cfg = self._cfg

        with self.strategy.scope():
            self._train_model()

        if not cfg.TEST_DATASET:
            return {}

        # Specialized strategies are not currently supported for testing.
        supported_strategies = (dl_utils.NoOpStrategy, type(self._default_strategy))
        if _SUPPORTS_DISTRIBUTED and not isinstance(self.strategy, supported_strategies):
            logger.error(
                f"Strategy '{type(self.strategy).__name__}' not currently "
                f"supported for testing. "
                f"Please run testing separately on a single gpu."
            )
            return {}

        return self.test(cfg, self._model)

    def _init_model(self, model):
        """Initialize model with weights and apply any freezing necessary.

        Loads ``cfg.INIT_WEIGHTS`` into ``model``. If ``cfg.FREEZE_LAYERS`` is
        set, its first two entries are used as a half-open ``[start, stop)``
        range of layer indices whose ``trainable`` flag is set to ``False``.
        """
        cfg = self._cfg
        logger.info("Loading weights from {}".format(cfg.INIT_WEIGHTS))
        model.load_weights(cfg.INIT_WEIGHTS)

        frozen_layers = cfg.FREEZE_LAYERS
        if frozen_layers:
            fl = range(frozen_layers[0], frozen_layers[1])
            logger.info("Freezing layers [{}, {})".format(fl.start, fl.stop))
            for i in fl:
                model.layers[i].trainable = False

    def build_callbacks(self):
        """Build the list of Keras callbacks used during training.

        Adds an LR scheduler and early stopping when enabled in the config,
        followed by checkpointing, TensorBoard, optional W&B logging, CSV
        logging and a fresh :class:`LossHistory` (also stored on
        ``self._loss_history``).

        Returns:
            list: The callbacks, with unavailable (``None``) entries removed.
        """
        cfg = self._cfg
        output_dir = cfg.OUTPUT_DIR

        callbacks = []
        if cfg.LR_SCHEDULER_NAME:
            callbacks.append(solver.build_lr_scheduler(cfg))
        if cfg.USE_EARLY_STOPPING:
            callbacks.append(
                kc.EarlyStopping(
                    monitor=cfg.EARLY_STOPPING_CRITERION,
                    min_delta=cfg.EARLY_STOPPING_MIN_DELTA,
                    patience=cfg.EARLY_STOPPING_PATIENCE,
                )
            )

        self._loss_history = LossHistory()
        # `update_freq` is only passed to TensorBoard under TF2.
        tb_kwargs = dict(update_freq="batch") if env.is_tf2() else {}
        callbacks.extend(
            [
                kc.ModelCheckpoint(
                    os.path.join(output_dir, "weights.{epoch:03d}-{val_loss:.4f}.h5"),
                    save_best_only=True,
                    save_weights_only=True,
                ),
                kc.TensorBoard(output_dir, write_grads=False, write_images=False, **tb_kwargs),
                WandBLogger() if env.supports_wandb() else None,
                kc.CSVLogger(os.path.join(output_dir, "metrics.log")),
                self._loss_history,
            ]
        )
        # Drop placeholders for optional callbacks that are not available.
        return [x for x in callbacks if x is not None]

    def build_loss(self):
        """Builds loss function used with ``model.compile(loss=...)``."""
        return build_loss(self._cfg)

    def _train_model(self):
        """Train model.

        If multi-gpu training and distributed training is supported (tensorflow>=1.15),
        call this function with the appropriate strategy scope::

            with self.strategy.scope():
                self._train_model()

        After fitting, the optimizer state is saved to ``cfg.OUTPUT_DIR`` and
        the recorded loss history is pickled to ``pik_data.dat`` there.
        """
        cfg = self._cfg
        n_epochs = cfg.N_EPOCHS
        num_workers = cfg.NUM_WORKERS
        output_dir = cfg.OUTPUT_DIR
        model = self._model

        # TODO: Add more options for metrics.
        optimizer = solver.build_optimizer(cfg)
        loss_func = self.build_loss()
        metrics = [lr_callback(optimizer), dice_loss]

        callbacks = self.build_callbacks()
        # A loss that is itself a callback is registered first and its
        # `criterion` is tracked as an extra metric.
        if isinstance(loss_func, kc.Callback):
            callbacks.insert(0, loss_func)
            metrics.append(loss_func.criterion)

        model.compile(optimizer=optimizer, loss=loss_func, metrics=metrics)
        if env.is_tf2():
            run_eagerly = tf.executing_eagerly() if self._run_eagerly is None else self._run_eagerly
            model.run_eagerly = run_eagerly

        train_loader, val_loader = self._train_loader, self._val_loader
        use_multiprocessing = num_workers > 1

        # Start training
        model.fit_generator(
            train_loader,
            epochs=n_epochs,
            validation_data=val_loader,
            callbacks=callbacks,
            workers=num_workers,
            use_multiprocessing=use_multiprocessing,
            verbose=1,
            shuffle=False,
        )

        # Save optimizer state
        io_utils.save_optimizer(model.optimizer, output_dir)

        # Save files to write as output
        # TODO: refactor to save dataframe.
        hist_cb = self._loss_history
        if hist_cb is None:
            # An overridden `build_callbacks` may not install a LossHistory;
            # do not fail a finished training run over the missing history.
            logger.warning("No loss history was recorded; skipping 'pik_data.dat'.")
            return

        data = [hist_cb.epoch, hist_cb.losses, hist_cb.val_losses]
        pik_data_path = os.path.join(output_dir, "pik_data.dat")
        with open(pik_data_path, "wb") as f:
            pickle.dump(data, f)

    @classmethod
    def test(cls, cfg: config.Config, model):
        """Run inference on ``cfg.TEST_DATASET`` and return the evaluation results.

        Works on a deep copy of ``cfg`` switched to test mode. Weights come
        from ``cfg.TEST_WEIGHT_PATH`` when set, otherwise from
        ``dl_utils.get_weights(cfg.OUTPUT_DIR)``, and are loaded into
        ``model`` in place.

        Args:
            cfg (Config): An experiment config. Not modified.
            model: The model to evaluate; its weights are overwritten.

        Returns:
            The value returned by ``inference_on_dataset``.
        """
        logger.info("Beginning testing...")
        cfg = copy.deepcopy(cfg)  # will be modified below.
        cfg.change_to_test()

        weights = cfg.TEST_WEIGHT_PATH
        if not cfg.TEST_WEIGHT_PATH:
            weights = dl_utils.get_weights(cfg.OUTPUT_DIR)
        logger.info("Best weights: {}".format(weights))
        model.load_weights(weights)

        test_dataset = cfg.TEST_DATASET
        test_gen = cls.build_test_data_loader(cfg)
        evaluator = build_evaluator(test_dataset, cfg, save_raw_data=True)
        return inference_on_dataset(model, test_gen, evaluator)

    @classmethod
    def build_model(cls, cfg):
        """Build the model described by ``cfg``.

        Tries ``build_model`` first. On ``KeyError`` it falls back to the
        legacy ``medsegpy.modeling.get_model``, which is not supported
        under TF2.

        Raises:
            ValueError: If the fallback is needed while running under TF2.
        """
        try:
            return build_model(cfg)
        except KeyError:
            # TODO (TF2.X)
            if env.is_tf2():
                raise ValueError(
                    "`get_model` not currently supported for tf2. "
                    "We are working on backwards compatibility"
                )
            # Imported lazily: only needed on the legacy fallback path.
            from medsegpy.modeling import get_model

            return get_model(cfg)

    def _build_data_loaders(self, cfg) -> Tuple[im_gens.Generator, im_gens.Generator]:
        """Builds train and val data loaders.

        Both loaders drop the last incomplete batch; only the training loader
        shuffles.
        """
        train_loader = build_loader(
            cfg,
            dataset_names=cfg.TRAIN_DATASET,
            batch_size=cfg.TRAIN_BATCH_SIZE,
            drop_last=True,
            is_test=False,
            shuffle=True,
        )
        val_loader = build_loader(
            cfg,
            dataset_names=cfg.VAL_DATASET,
            batch_size=cfg.VALID_BATCH_SIZE,
            drop_last=True,
            is_test=False,
            shuffle=False,
        )
        return train_loader, val_loader

    @classmethod
    def build_test_data_loader(cls, cfg):
        """Build the test data loader (no shuffling, last partial batch kept)."""
        return build_loader(
            cfg,
            dataset_names=cfg.TEST_DATASET,
            batch_size=cfg.TEST_BATCH_SIZE,
            drop_last=False,
            is_test=True,
            shuffle=False,
        )