# Source code for medsegpy.engine.callbacks

import logging
from copy import deepcopy

from keras import callbacks as kc

from medsegpy.utils import env

try:
    import wandb
    import wandb.wandb_run

    _WANDB_AVAILABLE = True
except ImportError:  # pragma: no-cover
    wandb = None
    _WANDB_AVAILABLE = False


# Names exported by ``from medsegpy.engine.callbacks import *``.
__all__ = [
    "lr_callback",
    "LossHistory",
    "WandBLogger",
]

# Logger named after this module.
logger = logging.getLogger(name=__name__)


def lr_callback(optimizer):
    """Build a metric-style function that reports the optimizer's learning rate.

    The returned function takes ``(y_true, y_pred)`` like a Keras metric.
    It ignores both arguments and returns ``optimizer.lr``.

    Args:
        optimizer: Optimizer used for training; it must expose an ``lr``
            attribute.

    Returns:
        func: To be wrapped in metric or callback. Its ``__name__`` is ``"lr"``.
    """

    # Keep the inner function named ``lr``: Keras derives a metric's display
    # name from the function's ``__name__``.
    def lr(y_true, y_pred):
        # Labels and predictions are irrelevant; only optimizer state matters.
        del y_true, y_pred
        return optimizer.lr

    return lr
class LossHistory(kc.Callback):
    """A Keras callback to log training history.

    Attributes (created in ``on_train_begin``):
        losses (list): ``logs["loss"]`` at the end of each epoch (``None`` if
            the key was absent).
        val_losses (list): ``logs["val_loss"]`` at the end of each epoch, or
            ``nan`` if the key was absent.
        epoch (list): 1-based epoch numbers matching the entries above.
    """

    def on_train_begin(self, logs=None):
        self.val_losses = []
        self.losses = []
        self.epoch = []

    def on_epoch_end(self, epoch, logs=None):
        # ``logs`` defaults to None; treat that as "no metrics reported"
        # instead of crashing on ``None.get`` / ``None.items``.
        logs = logs or {}
        self.val_losses.append(logs.get("val_loss", float("nan")))
        self.losses.append(logs.get("loss"))
        # Store epochs 1-based (the ``epoch`` argument is 0-based).
        self.epoch.append(epoch + 1)

        metrics = " - ".join(
            "{}: {}".format(k, self._format_value(v)) for k, v in logs.items()
        )
        logger.info("Epoch %s - %s", epoch + 1, metrics)

    @staticmethod
    def _format_value(value):
        """Format a metric: 4 decimals, or scientific notation if |value| < 1e-3."""
        # Compare the magnitude so negative metrics are printed in fixed-point
        # just like positive ones; only tiny magnitudes use scientific notation.
        if abs(value) >= 1e-3:
            return "{:0.4f}".format(value)
        return "{:0.4e}".format(value)
class WandBLogger(kc.Callback):
    """A Keras callback to log to Weights & Biases.

    Non-empty batch ``logs`` are logged every ``period`` batches. Epoch
    ``logs`` (plus an ``"epoch"`` entry) are logged at the end of every epoch.
    Currently only supports logging scalars.
    """

    def __init__(self, period: int = 20, experiment="auto", **kwargs):
        """
        Args:
            period (int, optional): Logging period, in batches. Must be a
                positive integer.
            experiment (`wandb.wandb_run.Run` | `str` | `None`): The experiment run.
                If ``"auto"``, a run will only be created if ``wandb.run`` is None.
                If ``None``, a run will be created. If a run object is given,
                metrics are logged to that run.
            **kwargs: Options to pass to ``wandb.init()`` to create run.
                Ignored if ``experiment`` specified.

        Raises:
            ValueError: If Weights & Biases is not supported, or if ``period``
                or ``experiment`` is invalid.
        """
        super().__init__()
        if not env.supports_wandb():
            raise ValueError(
                "Weights & Biases is not supported. "
                "Install package via `pip install wandb`. "
                "See documentation https://docs.wandb.com/ "
            )

        # Validate every argument *before* calling ``wandb.init`` so an invalid
        # configuration cannot leave a freshly created run behind. Explicit
        # raises (not ``assert``) keep validation active under ``python -O``.
        if not (isinstance(period, int) and period > 0):
            raise ValueError("`period` must be int >0")
        is_run = isinstance(experiment, wandb.wandb_run.Run)
        if not (is_run or experiment in ("auto", None)):
            raise ValueError(
                "`experiment` must be a `wandb.wandb_run.Run`, 'auto', or None. "
                "Got {!r}".format(experiment)
            )

        if (not wandb.run and experiment == "auto") or experiment is None:
            wandb.init(**kwargs)

        # An explicitly supplied run receives the metrics; otherwise logging
        # goes through the module-level ``wandb.log``.
        self._experiment = experiment if is_run else None
        self._period = period
        # Also initialized here so the batch/epoch hooks work even if
        # ``on_train_begin`` was not invoked first.
        self._step = 0

    def _log(self, logs):
        """Send ``logs`` to W&B at the current step."""
        if self._experiment is not None:
            self._experiment.log(logs, step=self._step)
        else:
            wandb.log(logs, step=self._step)

    def on_train_begin(self, logs=None):
        self._step = 0

    def on_batch_end(self, batch_idx, logs=None):
        self._step += 1
        # Skip empty logs and all but every ``period``-th step.
        if not logs or self._step % self._period != 0:
            return
        self._log(logs)

    def on_epoch_end(self, epoch, logs=None):
        # Copy so the added ``epoch`` key does not leak into the caller's dict;
        # ``logs`` may be None, which previously crashed on item assignment.
        logs = deepcopy(logs) if logs else {}
        logs["epoch"] = epoch
        self._log(logs)