Run Hooks


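SessionRunHooks are useful to track training, report progress, request early stopping and more. Users can attach an arbitrary number of hooks to an estimator. SessionRunHooks use the observer pattern and notify at the following points:

- when a session starts being used
- before a call to session.run()
- after a call to session.run()
- when the session is closed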

A SessionRunHook encapsulates a piece of reusable, composable computation that can piggyback on a call to MonitoredSession.run(). A hook can add any ops, tensors, or feeds to the run call, and when the run call finishes successfully it receives the outputs it requested. Hooks are allowed to add ops to the graph in hook.begin(); the graph is finalized after the begin() method is called.
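The notification points above can be mimicked in a few lines of plain R. The sketch below assumes nothing from TensorFlow: `run_with_hooks()` and the `logger` hook are hypothetical names invented for illustration, not part of tfestimators. It shows a driver that notifies every attached hook at each phase, so a hook only needs to define the callbacks it cares about.

```r
# Hypothetical toy driver illustrating the observer pattern behind run hooks.
# Each hook is a named list of callbacks; this is NOT the tfestimators API.
run_with_hooks <- function(hooks, n_steps) {
  events <- character(0)
  # Call the callback `name` on every hook that defines it, then log the event.
  notify <- function(name, ...) {
    for (h in hooks) {
      if (!is.null(h[[name]])) h[[name]](...)
    }
    events <<- c(events, name)
  }
  notify("begin")                    # session starts being used
  for (step in seq_len(n_steps)) {
    notify("before_run", step)       # before each run() call
    loss <- 100 / step               # stand-in for the real computation
    notify("after_run", step, loss)  # after each run() call
  }
  notify("end")                      # session is closed
  events
}

# A hook that only cares about after_run; the other callbacks are omitted.
logger <- list(
  after_run = function(step, loss) cat("step", step, "loss", loss, "\n")
)

events <- run_with_hooks(list(logger), n_steps = 2)
print(events)
```

Running it prints the two logged steps and returns the order in which hooks were notified: begin, then before_run and after_run for each step, then end.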

Built-in Run Hooks

A number of pre-defined run hooks are available, for example:

Method                       Description
hook_checkpoint_saver()      Saves checkpoints every N steps or seconds.
hook_global_step_waiter()    Delays execution until the global step reaches wait_until_step.
hook_history_saver()         Saves the training metrics history.
hook_logging_tensor()        Prints the given tensors once every N local steps or once every N seconds.
hook_nan_tensor()            Monitors the loss and stops training if it becomes NaN.
hook_progress_bar()          Creates and updates a progress bar.
hook_step_counter()          Monitors the number of steps per second.
hook_stop_at_step()          Requests a stop at a specified step.
hook_summary_saver()         Saves summaries every N steps.
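Conceptually, monitors such as hook_stop_at_step() and hook_nan_tensor() run a check after each step and either request a stop or abort training. The plain-R sketch below only illustrates that idea; `train_loop()`, its `last_step` argument, and the error message are hypothetical and do not reflect how the TensorFlow hooks are implemented.

```r
# Hypothetical sketch of the per-step checks performed by stop-at-step and
# NaN-loss monitors; the real hooks run inside TensorFlow, not like this.
train_loop <- function(losses, last_step) {
  step <- 0
  for (loss in losses) {
    step <- step + 1
    if (is.nan(loss)) {               # NaN monitor: abort training on NaN
      stop("Model diverged with loss = NaN at step ", step)
    }
    if (step >= last_step) {          # stop-at-step monitor: request a stop
      message("Stop requested at step ", step)
      break
    }
  }
  step                                # number of steps actually run
}

steps_run <- train_loop(c(3.2, 2.9, 2.5, 2.1), last_step = 3)
res <- tryCatch(
  train_loop(c(3.2, NaN, 2.5), last_step = 10),
  error = function(e) conditionMessage(e)
)
print(steps_run)
print(res)
```

The first call stops early at step 3 even though four losses were supplied; the second aborts as soon as a NaN appears.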

For example, we can use hook_progress_bar() to attach a hook that creates and updates a progress bar during model training:

library(tfestimators)

fcs <- feature_columns(column_numeric("drat"))
input <- input_fn(mtcars, response = "mpg", features = c("drat", "cyl"), batch_size = 8L)
lr <- linear_regressor(
  feature_columns = fcs
) %>% train(
  input_fn = input,
  steps = 2,
  hooks = list(
    hook_progress_bar()
  )
)

Training 2/2 [======================================] - ETA:  0s - loss: 3136.10
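A progress-bar hook follows the same lifecycle as any other hook: create the bar when training begins, advance it after each step, and close it at the end. As a rough analogy only (this is not the implementation of hook_progress_bar()), base R's utils::txtProgressBar() can be driven the same way:

```r
# Analogy only: base R progress bar driven through hook-like phases.
# This is not how hook_progress_bar() is implemented.
library(utils)

n_steps <- 4
pb <- txtProgressBar(min = 0, max = n_steps, style = 3)  # like begin()
for (step in seq_len(n_steps)) {
  Sys.sleep(0.1)                # stand-in for one training step
  setTxtProgressBar(pb, step)   # like after_run(): advance the bar
}
final_value <- getTxtProgressBar(pb)
close(pb)                       # like end(): finish the bar
```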

Another example is hook_history_saver(), which saves the training history every 2 training steps:

lr <- linear_regressor(feature_columns = fcs)
training_history <- train(
  lr,
  input_fn = input,
  steps = 4,
  hooks = list(
    hook_history_saver(every_n_step = 2)
  )
)

train() returns the saved training metrics history:

> training_history
  mean_losses total_losses steps
1    343.9690     2751.752     2
2    419.7618     3358.094     4
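In this output, each mean_losses value appears to equal total_losses divided by the batch size of 8 passed to input_fn(), which suggests the total loss is summed over the examples in a batch. A quick check on the printed numbers (an observation about this run, not a documented guarantee):

```r
# Values copied from the training_history output above.
history_df <- data.frame(
  mean_losses  = c(343.9690, 419.7618),
  total_losses = c(2751.752, 3358.094),
  steps        = c(2, 4)
)
batch_size <- 8  # matches batch_size = 8L in input_fn()

# Difference between total_losses / batch_size and the reported mean_losses.
# The printed values are rounded, so the differences should be tiny.
diffs <- history_df$total_losses / batch_size - history_df$mean_losses
print(diffs)
```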

Custom Run Hooks

Users can also create custom run hooks by defining the behaviors of the hook in different phases of a session.

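We can implement a custom run hook by supplying callback functions as arguments to session_run_hook(). It accepts the following optional parameters, each of which can be overridden by a custom function:

- begin: called once before the session is used.
- after_create_session: called once the new TensorFlow session has been created.
- before_run: called before each run call.
- after_run: called after each run call.
- end: called at the end of the session.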
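The idea of optional callbacks with do-nothing defaults can be sketched in plain R. `make_hook()` below is a hypothetical helper, not the tfestimators implementation of session_run_hook(); it borrows the phase names of TensorFlow's SessionRunHook (begin, after_create_session, before_run, after_run, end) and lets user-supplied functions replace the defaults by name via modifyList().

```r
# Hypothetical helper: every phase gets a no-op callback unless the user
# supplies one. Not the actual session_run_hook() implementation.
make_hook <- function(...) {
  defaults <- list(
    begin                = function() invisible(NULL),
    after_create_session = function(session, coord) invisible(NULL),
    before_run           = function(context) invisible(NULL),
    after_run            = function(context, values) invisible(NULL),
    end                  = function(session) invisible(NULL)
  )
  # User-supplied callbacks replace the matching defaults by name.
  modifyList(defaults, list(...))
}

hook <- make_hook(after_run = function(context, values) "custom after_run")
print(names(hook))
print(hook$after_run(NULL, NULL))  # the overridden callback
print(is.null(hook$begin()))       # the default no-op returns NULL
```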

For example, let’s implement our own version of the hook_history_saver() shown in the previous section. We first initialize an iter_count variable that tracks the number of steps run so far, and increment it in after_run() after each call. Inside before_run(), we use the context to fetch the current losses from the graph and request them as a run argument named “losses”, so that after_run() can access their evaluated values via values$results$losses. Finally, every every_n_step steps we compute the mean of the raw losses and append it to a global variable named “mean_losses_history”.

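mean_losses_history <<- NULL
hook_history_saver_custom <- function(every_n_step) {
  session_run_hook(
    # Reset the step counter when the session starts.
    begin = function() {
      iter_count <<- 0
    },

    # Ask the session to also evaluate the tensors in the "losses" collection.
    before_run = function(context) {
      losses <- context$session$graph$get_collection("losses")
      session_run_args(losses = losses)
    },

    # Count the step and record the mean loss every `every_n_step` steps.
    after_run = function(context, values) {
      iter_count <<- iter_count + 1
      print(paste0("Running step: ", iter_count))
      if (iter_count %% every_n_step == 0) {
        raw_losses <- values$results$losses[[1]]
        mean_losses_history <<- c(mean_losses_history, mean(raw_losses))
      }
    }
  )
}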
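Before attaching the hook, its bookkeeping can be checked without TensorFlow. The sketch below reuses the body of after_run() on hand-built `values` lists shaped like values$results$losses; the loss vectors are made-up numbers, `every_n_step = 2` is chosen only for illustration, and everything runs inside local() so the global mean_losses_history is left untouched.

```r
# Pure-R simulation of the history saver's after_run() bookkeeping.
# local() keeps these variables out of the global environment.
sim <- local({
  mean_losses_history <- NULL
  iter_count <- 0
  every_n_step <- 2

  after_run <- function(context, values) {
    iter_count <<- iter_count + 1
    if (iter_count %% every_n_step == 0) {
      raw_losses <- values$results$losses[[1]]
      mean_losses_history <<- c(mean_losses_history, mean(raw_losses))
    }
  }

  # Made-up per-step loss vectors, wrapped like values$results$losses.
  fake_losses <- list(c(1, 3), c(2, 4), c(5, 7), c(6, 8))
  for (l in fake_losses) {
    after_run(context = NULL, values = list(results = list(losses = list(l))))
  }
  list(history = mean_losses_history, steps = iter_count)
})
print(sim$history)
```

Only steps 2 and 4 are recorded, giving the means of c(2, 4) and c(6, 8).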

Next, we attach this hook to our estimator:

lr <- linear_regressor(
  feature_columns = fcs
) %>% train(
  input_fn = input,
  steps = 4,
  hooks = list(
    hook_history_saver_custom(every_n_step = 1)
  )
)

[1] "Running step: 1"
[1] "Running step: 2"
[1] "Running step: 3"
[1] "Running step: 4"

Because every_n_step = 1, the mean loss was saved at every step. Let’s check the recorded values:

> mean_losses_history
[1] 415.8088 452.2128 376.7346 331.6045
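Because mean_losses_history is a global variable, running training again with this hook appends to the existing values unless you reset it first. One alternative, sketched below in plain R, keeps the state inside a closure so each hook instance has its own history. `make_history_saver()` and its `get_history()` accessor are hypothetical, and the loss vectors are made up.

```r
# Hypothetical closure-based variant: state lives in the factory's
# environment rather than in global variables.
make_history_saver <- function(every_n_step) {
  iter_count <- 0
  losses_history <- numeric(0)
  list(
    # Reset the counter when a session starts.
    begin = function() {
      iter_count <<- 0
    },
    # Record the mean loss every `every_n_step` steps.
    after_run = function(context, values) {
      iter_count <<- iter_count + 1
      if (iter_count %% every_n_step == 0) {
        losses_history <<- c(losses_history, mean(values$results$losses[[1]]))
      }
    },
    # Accessor for the recorded history.
    get_history = function() losses_history
  )
}

saver <- make_history_saver(every_n_step = 1)
saver$begin()
for (l in list(c(1, 3), c(4, 6))) {  # made-up loss vectors
  saver$after_run(NULL, list(results = list(losses = list(l))))
}
print(saver$get_history())
```

To use this with an estimator, saver$begin and saver$after_run would be passed to session_run_hook() together with a before_run() that requests the losses, as in the custom hook above.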