Example 3: U-Net training

In this example, we build the original U-Net (Ronneberger et al., 2015) in TensorFlow and train it to segment our platelet data into 7 classes. See Guay et al. (2019) for more details.


Setup

In [1]:
import logging
import os
import random
import sys
import time

import matplotlib.pyplot as plt
import tensorflow as tf

# Utilities for this example, abstracting away a messy in-lab codebase
from src.demos import demo_data, demo_segmentation
from src.download import download_data_if_missing
from src.initialize_instance import initialize_instance
from src.create_network import create_network
from src.train_and_eval import train_and_eval
from src.segment import segment

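# tf.logging is deprecated in late TF 1.x releases; the tf.compat.v1 alias is equivalent
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)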
logger_level = logging.INFO

Data download and example output directories

In [2]:
download_dir = os.path.expanduser('~/examples/data')
output_dir = os.path.expanduser('~/examples/example3/output')

Download data

If the data already exists in the download folder, the download will be skipped.

In [3]:
download_data_if_missing(download_dir)

# Dir containing train/eval/test images
data_dir = os.path.join(download_dir, 'platelet_data')
Downloading https://www.dropbox.com/s/68yclbraqq1diza/platelet_data_1219.zip?dl=1 to /home/matthew/examples/unet/data/platelet_data_1219.zip
100%|█████████▉| 188661760/188671885 [00:13<00:00, 16678806.27it/s]
Extracting /home/matthew/examples/unet/data/platelet_data_1219.zip to /home/matthew/examples/unet/data
Finished
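
download_data_if_missing() comes from the example's src package, and its code isn't shown here. The skip-if-present behaviour described above can be sketched with the standard library alone. The function name `download_data_if_missing_sketch`, its `marker_name` argument, and the choice to treat an existing extracted `platelet_data` folder as "data present" are assumptions for illustration, not the src implementation.

```python
import os
import urllib.request
import zipfile


def download_data_if_missing_sketch(download_dir, url, marker_name='platelet_data'):
    """Hypothetical stand-in for src.download.download_data_if_missing.

    Downloads a zip archive from `url` and extracts it into `download_dir`,
    unless `download_dir/marker_name` already exists.
    Returns True if a download happened, False if it was skipped.
    """
    target = os.path.join(download_dir, marker_name)
    if os.path.exists(target):
        return False  # Data already present: skip the download
    os.makedirs(download_dir, exist_ok=True)
    zip_path = os.path.join(download_dir, marker_name + '.zip')
    urllib.request.urlretrieve(url, zip_path)  # Fetch the archive to disk
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(download_dir)  # Assumes the archive holds a marker_name/ folder
    os.remove(zip_path)  # Keep only the extracted data
    return True
```

Called with the Dropbox URL from the log above, this would download once and return False on every later run, presumably matching the behaviour of the real helper.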

Data demo

In [4]:
demo_data(data_dir)

Training settings

In [5]:
# Original U-Net input size
input_shape = [572, 572]

# Number of training epochs
n_epochs = 20

# Early stopping criterion: a metric ('mean_iou' or 'adj_rand_idx'), a threshold
# for that metric, and the epoch at which to begin testing the criterion
stop_criterion = ('mean_iou', 0.3, 10)

# Save directory for this instance of the example
save_dir = os.path.join(output_dir, time.strftime('%m%d'))

# Trainable weight initialization RNG seed
weight_seed = 12345

# Training data presentation order RNG seed
data_seed = 2468

# 2D training windows are taken from a larger 3D volume, which is
# divided up into overlapping training-window-size regions with
# top-left corners spaced (approximately) this far apart:
train_window_spacing = [1, 80, 80]  # (z, x, y) order

# U-Net settings
net_settings = {
    'spatial_mode': 0,          # 2D ops only
    'n_blocks': 4,              # num convolution blocks in the encoder and decoder
    'n_comps': 2,               # num convolutions per block
    'n_kernels': 64,            # num convolution kernels in first encoder block
    'log_gamma2': -5,           # log10 of L2 regularization on weights (ignored if < -10)
    'log_gamma1': -11,          # log10 of L1 regularization on weights (ignored if < -10)
    'padding_type': 'valid',    # Either 'valid' or 'same'. Original U-Net used valid
    'pooling_type': 'maxpool'   # Either 'maxpool' or 'conv'. Original U-Net used maxpool
}

# Optimization settings (Adam optimizer: https://arxiv.org/pdf/1412.6980.pdf)
optim_settings = {
    'log_learning_rate': -3.3,    # log10 of learning rate
    'log_alpha1': -1.5,           # log10 of alpha1 := 1 - beta1
    'log_alpha2': -2.1,           # log10 of alpha2 := 1 - beta2
    'log_epsilon': -7.,           # log10 of epsilon
    'weight_floor': 0.01,         # minimum weight value for weighted cross-entropy loss
    'exponential_decay_rate': 1,  # exponential decay rate for the learning rate (1 = no decay)
    'log_decay_steps': 10         # log10 of exponential decay step count
}
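
Most of the numeric settings above are given as base-10 logarithms. The sketch below converts them to plain values following the comments in the cell; it is an illustration under that assumption, not the src code. The names `decode_settings` and `decayed_learning_rate` are hypothetical, and the decay formula is assumed to take the same (non-staircase) form as TF1's `tf.train.exponential_decay`.

```python
def decode_settings(net_settings, optim_settings):
    """Convert the log10-scale entries of the settings dicts to plain values.

    Assumption: conversions follow the comments in the settings cell.
    """
    def reg_weight(log_gamma):
        # Per the comments, a log10 value below -10 disables the penalty
        return 0.0 if log_gamma < -10 else 10.0 ** log_gamma

    return {
        'learning_rate': 10.0 ** optim_settings['log_learning_rate'],
        'beta1': 1.0 - 10.0 ** optim_settings['log_alpha1'],  # alpha1 := 1 - beta1
        'beta2': 1.0 - 10.0 ** optim_settings['log_alpha2'],  # alpha2 := 1 - beta2
        'epsilon': 10.0 ** optim_settings['log_epsilon'],
        'decay_steps': 10.0 ** optim_settings['log_decay_steps'],
        'l2_weight': reg_weight(net_settings['log_gamma2']),
        'l1_weight': reg_weight(net_settings['log_gamma1']),
    }


def decayed_learning_rate(lr0, decay_rate, step, decay_steps):
    # Exponential decay: lr0 * decay_rate ** (step / decay_steps)
    return lr0 * decay_rate ** (step / decay_steps)
```

With the values above this gives a learning rate of about 5.0e-4, beta1 of about 0.968, beta2 of about 0.992, epsilon of 1e-7, an L2 weight of 1e-5, and no L1 penalty. Since `exponential_decay_rate` is 1, the learning rate stays constant for the whole run.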

In [6]:
# Collect all settings in one dict so they can be archived with this run
instance_settings = {
    'data_dir': data_dir,
    'instance_dir': None,  # Defined in `initialize_instance()`
    'input_shape': input_shape,
    'n_epochs': n_epochs,
    'stop_criterion': stop_criterion,
    'weight_seed': weight_seed,
    'data_seed': data_seed,
    'train_window_spacing': train_window_spacing,
    'net_settings': net_settings,
    'optim_settings': optim_settings
}
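
The `train_window_spacing` setting collected above describes overlapping windows whose top-left corners are spaced approximately a fixed distance apart. One way to produce such a grid, sketched below, places the last window flush with the volume edge and spreads the corners evenly, so the actual spacing is at most the requested one. The helper names and the flush-edge rule are assumptions; the src code may tile differently.

```python
import itertools

import numpy as np


def window_corners(volume_len, window_len, spacing):
    """Start indices of windows of length `window_len` along one axis.

    Corners are spaced about `spacing` apart; the last window ends exactly
    at the volume edge, so the spacing shrinks slightly to fit.
    """
    if window_len > volume_len:
        raise ValueError('window is larger than the volume')
    n_windows = int(np.ceil((volume_len - window_len) / spacing)) + 1
    return np.linspace(0, volume_len - window_len, n_windows).round().astype(int)


def window_grid(volume_shape, window_shape, spacing):
    """All top-left corners, in the same axis order as the inputs."""
    axes = [window_corners(v, w, s) for v, w, s in zip(volume_shape, window_shape, spacing)]
    return [tuple(int(c) for c in corner) for corner in itertools.product(*axes)]
```

For a hypothetical 800-pixel axis, 572-pixel windows and a requested spacing of 80, this yields corners at 0, 76, 152 and 228, so neighbouring windows overlap heavily and every pixel is covered.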

Network training

Initialization

Create a folder within save_dir for the output of this instance's run, and archive all example source code.

In [7]:
# Updates key 'instance_dir' in `instance_settings`, containing the directory within `save_dir`
# where this example instance's output will be saved
instance_settings, logger = initialize_instance(save_dir, instance_settings, logger_level)
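
The checkpoint paths logged later in this notebook (`.../output/0103/0/model/...`) suggest that each instance gets an integer-numbered subfolder of `save_dir`. The internals of initialize_instance() aren't shown, so the following is only a hypothetical sketch of that pattern plus a JSON archive of the settings dict. The function names and the JSON filename are assumptions.

```python
import json
import os


def make_instance_dir(save_dir):
    """Create and return the next unused integer-named subdirectory of save_dir."""
    os.makedirs(save_dir, exist_ok=True)
    existing = [int(name) for name in os.listdir(save_dir)
                if name.isdigit() and os.path.isdir(os.path.join(save_dir, name))]
    instance_dir = os.path.join(save_dir, str(max(existing, default=-1) + 1))
    os.makedirs(instance_dir)
    return instance_dir


def archive_settings(instance_dir, settings, filename='instance_settings.json'):
    """Write the settings dict as JSON (tuples such as stop_criterion become lists)."""
    path = os.path.join(instance_dir, filename)
    with open(path, 'w') as f:
        json.dump(settings, f, indent=2)
    return path
```

Running the example twice on the same day would then produce `save_dir/0` and `save_dir/1`, each with its own copy of the settings.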

Network creation

Create a network object (a wrapper around a TensorFlow computation graph that implements training, evaluation, and inference with a U-Net), along with settings dicts for the training and evaluation processes.

In [ ]:
net, train_settings, eval_settings = create_network(instance_settings)

Network training and evaluation

Note: With default settings and an NVIDIA Tesla P100 GPU, this takes about 1.7 hours.

In [ ]:
# Train with periodic evaluation; returns the trained network
net = train_and_eval(net, train_settings, eval_settings, logger)
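
train_and_eval() applies `stop_criterion` during training, but its code isn't shown and the settings comment doesn't say which side of the threshold triggers a stop. One plausible reading, assumed in this sketch, is that a run is abandoned if the metric is still below the threshold at or after the start epoch. The loop and its `train_one_epoch` / `evaluate` callbacks are hypothetical placeholders.

```python
def should_stop(epoch, metrics, stop_criterion):
    """Assumed rule: stop once `epoch` >= start epoch and the metric is below threshold."""
    metric_name, threshold, start_epoch = stop_criterion
    if epoch < start_epoch:
        return False
    return metrics[metric_name] < threshold


def run_epochs(n_epochs, stop_criterion, train_one_epoch, evaluate):
    """Hypothetical training loop; returns the list of per-epoch metric dicts."""
    history = []
    for epoch in range(1, n_epochs + 1):
        train_one_epoch(epoch)
        metrics = evaluate(epoch)  # e.g. {'mean_iou': ..., 'adj_rand_idx': ...}
        history.append(metrics)
        if should_stop(epoch, metrics, stop_criterion):
            break
    return history
```

Under this reading, with `stop_criterion = ('mean_iou', 0.3, 10)` a run stuck at a mean IoU of 0.2 would end after epoch 10 rather than using all 20 epochs.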

Segmentation with a trained network

Using a trained network to segment an image stack

In [10]:
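# Set to a directory path to also save the segmentation images to disk;
# None keeps the results in memory only. A separate name avoids
# overwriting the example's `output_dir` defined earlier.
seg_output_dir = None

test_seg, test_probs = segment(
    net_sources=net,
    image_source=os.path.join(data_dir, 'test-images.tif'),
    label_source=os.path.join(data_dir, 'test-labels.tif'),
    output_dir=seg_output_dir)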
01/03/2020 07:44:39 - genenet.DataHandler - INFO: Loaded eval data from /home/matthew/examples/unet/data/platelet_data
01/03/2020 07:44:39 - genenet.DataHandler - INFO: Seeded DataHandler.random_state with 601997554
Net input shape: [572, 572]. Net output shape: [388, 388]
[7, 121, 609, 400]
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Summary name fifo_queue_DequeueUpTo:2_b0_f0 is illegal; using fifo_queue_DequeueUpTo_2_b0_f0 instead.
INFO:tensorflow:Summary name classes/classes:0_b0_f0 is illegal; using classes/classes_0_b0_f0 instead.
INFO:tensorflow:Summary name probabilities/probabilities:0_b0_f0 is illegal; using probabilities/probabilities_0_b0_f0 instead.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /home/matthew/examples/unet/output/0103/0/model/model.ckpt-36000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
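
The log above reports that a 572 x 572 input yields a 388 x 388 output. With 'valid' padding, every unpadded 3 x 3 convolution trims one pixel from each border, and the pooling and up-sampling steps scale what remains. The sketch below reproduces that arithmetic along one axis, assuming the original U-Net layout (two 3 x 3 convolutions per block, four 2 x 2 max-pooling steps, 2 x 2 up-convolutions). Whether `n_pools` here matches how the `n_blocks` setting is counted is an assumption.

```python
def unet_valid_output_size(input_size, n_pools=4, n_convs=2, kernel_size=3):
    """Output size of a 'valid'-padded U-Net along one spatial axis."""
    shrink = n_convs * (kernel_size - 1)  # pixels lost per block of valid convolutions
    size = input_size
    for _ in range(n_pools):  # encoder: convolutions, then 2x2 pooling
        size -= shrink
        if size <= 0 or size % 2:
            raise ValueError(f'input size {input_size} does not pool evenly')
        size //= 2
    size -= shrink  # bottleneck block
    if size <= 0:
        raise ValueError(f'input size {input_size} is too small')
    for _ in range(n_pools):  # decoder: 2x2 up-convolution doubles, then convolutions
        size = 2 * size - shrink
    return size
```

For 572 this gives 388, so the input extends 92 pixels beyond the output on each side; not every input size is usable, since each encoder level must have an even size before pooling.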

Visualizing the segmentation

In [11]:
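# Size the figure to one 609 x 400 (height x width) slice at about 110 pixels per inch
fig, ax = plt.subplots(1, 1, figsize=(400/110, 609/110))

# Show slice 75 of the predicted label volume; vmin/vmax pin the colormap to class labels 0-6
ax.imshow(test_seg[75], cmap='jet', vmin=0, vmax=6)
ax.set_title('Test prediction')
h = ax.axis('off')  # Assigning the result suppresses the notebook's echo of it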
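
The `segment()` call printed `[7, 121, 609, 400]`, which appears to be the probability output laid out as (classes, z, y, x); that layout, and the idea that `test_seg` is derived from it by picking the most likely class, are assumptions since the wrapper's return format isn't documented here. Under them, a label volume and per-class pixel fractions can be computed with NumPy. The helper names are hypothetical.

```python
import numpy as np


def probs_to_labels(probs, class_axis=0):
    """Hard segmentation: the most likely class at each voxel."""
    return np.argmax(probs, axis=class_axis)


def class_fractions(labels, n_classes=7):
    """Fraction of voxels assigned to each class label 0..n_classes-1."""
    counts = np.bincount(labels.ravel(), minlength=n_classes)
    return counts / counts.sum()
```

With 7 classes the labels run from 0 to 6, which is why the plot above fixes `vmin=0, vmax=6`: colours then mean the same class on every slice.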

Demo of network segmentation on train and eval data

Comparisons with ground-truth labels, along with visualizations of the inference probability maps for each class.

In [13]:
demo_segmentation(net, data_dir)
01/03/2020 13:18:50 - genenet.DataHandler - INFO: Loaded eval data from /home/matthew/examples/unet/data/platelet_data
01/03/2020 13:18:50 - genenet.DataHandler - INFO: Seeded DataHandler.random_state with 619565160
Net input shape: [572, 572]. Net output shape: [388, 388]
[7, 50, 800, 800]
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Summary name fifo_queue_DequeueUpTo:2_b0_f0 is illegal; using fifo_queue_DequeueUpTo_2_b0_f0 instead.
INFO:tensorflow:Summary name classes/classes:0_b0_f0 is illegal; using classes/classes_0_b0_f0 instead.
INFO:tensorflow:Summary name probabilities/probabilities:0_b0_f0 is illegal; using probabilities/probabilities_0_b0_f0 instead.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /home/matthew/examples/unet/output/0103/0/model/model.ckpt-36000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
01/03/2020 13:19:24 - genenet.DataHandler - INFO: Loaded eval data from /home/matthew/examples/unet/data/platelet_data
01/03/2020 13:19:24 - genenet.DataHandler - INFO: Seeded DataHandler.random_state with 482580065
Net input shape: [572, 572]. Net output shape: [388, 388]
[7, 24, 800, 800]
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Summary name fifo_queue_DequeueUpTo:2_b0_f0 is illegal; using fifo_queue_DequeueUpTo_2_b0_f0 instead.
INFO:tensorflow:Summary name classes/classes:0_b0_f0 is illegal; using classes/classes_0_b0_f0 instead.
INFO:tensorflow:Summary name probabilities/probabilities:0_b0_f0 is illegal; using probabilities/probabilities_0_b0_f0 instead.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /home/matthew/examples/unet/output/0103/0/model/model.ckpt-36000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
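
demo_segmentation() compares predictions with ground-truth labels, and 'mean_iou' is also one of the stop-criterion metrics. The exact averaging used by the src code isn't shown; a common definition, sketched below, computes intersection over union for each class and averages over the classes present in either volume. The function names are hypothetical.

```python
import numpy as np


def per_class_iou(pred, truth, n_classes=7):
    """IoU = |pred & truth| / |pred | truth| per class; NaN if a class is absent from both."""
    ious = np.full(n_classes, np.nan)
    for c in range(n_classes):
        p, t = pred == c, truth == c
        union = np.logical_or(p, t).sum()
        if union > 0:
            ious[c] = np.logical_and(p, t).sum() / union
    return ious


def mean_iou(pred, truth, n_classes=7):
    """Average IoU over classes that occur in the prediction or the ground truth."""
    return float(np.nanmean(per_class_iou(pred, truth, n_classes)))
```

Skipping absent classes keeps a small volume that lacks some of the 7 classes from being penalized for them; other conventions (for example, counting such classes as IoU 1) would give different numbers.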