In this example, we build the original U-Net (Ronneberger et al., 2015) in TensorFlow and train it to segment our platelet data into 7 classes. See Guay et al. (2019) for more details.
import logging
import os
import random
import sys
import time
import matplotlib.pyplot as plt
import tensorflow as tf
# Utilities for this example, abstracting away a messy in-lab codebase
from src.demos import demo_data, demo_segmentation
from src.download import download_data_if_missing
from src.initialize_instance import initialize_instance
from src.create_network import create_network
from src.train_and_eval import train_and_eval
from src.segment import segment
# Logging verbosity: INFO for TensorFlow's logger, and INFO for the level
# that is later passed to `initialize_instance`
logger_level = logging.INFO
tf.logging.set_verbosity(tf.logging.INFO)
# Folders under the user's home directory: where the data is downloaded to,
# and the root under which this example's output folders are created
download_dir, output_dir = map(
    os.path.expanduser,
    ('~/examples/data', '~/examples/example3/output'))
If the data already exists in the download folder, the download will be skipped.
# Folder containing the train/eval/test images
data_dir = os.path.join(download_dir, 'platelet_data')
# Fetch the data into the download folder unless it is already there
download_data_if_missing(download_dir)
# Demo helper from `src.demos`, run on the data folder
demo_data(data_dir)
# Number of training epochs
n_epochs = 20
# Network input size; 572 x 572 is what the original U-Net used
input_shape = [572, 572]
# Early stopping criterion, given as (metric, threshold, epoch):
#   metric    - either 'mean_iou' or 'adj_rand_idx'
#   threshold - threshold for that metric
#   epoch     - the epoch at which to begin testing the criterion
stop_criterion = 'mean_iou', 0.3, 10
# Save directory for this instance of the example: a subfolder of
# `output_dir` named after the current month and day ('%m%d')
save_dir = os.path.join(output_dir, time.strftime('%m%d'))
# RNG seeds: the first for initializing the trainable weights, the second
# for the order in which training data is presented
weight_seed, data_seed = 12345, 2468
# Training windows are 2D but are cut from a larger 3D volume. The volume is
# divided into overlapping regions of training-window size, and the top-left
# corners of those regions sit (approximately) this far apart along each
# axis, listed in (z, x, y) order.
train_window_spacing = [1, 80, 80]
# U-Net architecture settings
net_settings = dict(
    # 0 selects 2D ops only
    spatial_mode=0,
    # Number of convolution blocks in the encoder and in the decoder
    n_blocks=4,
    # Number of convolutions per block
    n_comps=2,
    # Number of convolution kernels in the first encoder block
    n_kernels=64,
    # log10 of the L2 regularization on weights (ignored if < -10)
    log_gamma2=-5,
    # log10 of the L1 regularization on weights (ignored if < -10)
    log_gamma1=-11,
    # Either 'valid' or 'same'; the original U-Net used 'valid'
    padding_type='valid',
    # Either 'maxpool' or 'conv'; the original U-Net used 'maxpool'
    pooling_type='maxpool',
)
# Settings for the ADAM optimizer (https://arxiv.org/pdf/1412.6980.pdf).
# Entries prefixed with `log_` hold the log10 of the quantity they name.
optim_settings = dict(
    # log10 of the learning rate
    log_learning_rate=-3.3,
    # log10 of alpha1 := 1 - beta1
    log_alpha1=-1.5,
    # log10 of alpha2 := 1 - beta2
    log_alpha2=-2.1,
    # log10 of epsilon
    log_epsilon=-7.0,
    # Minimum weight value for the weighted cross-entropy loss
    weight_floor=0.01,
    # Exponential decay applied to the learning rate
    exponential_decay_rate=1,
    # log10 of the exponential decay step count
    log_decay_steps=10,
)
# All of the settings for this run, gathered into one dict so that they can
# be saved during archival
instance_settings = dict(
    data_dir=data_dir,
    # Placeholder; defined in `initialize_instance()`
    instance_dir=None,
    input_shape=input_shape,
    n_epochs=n_epochs,
    stop_criterion=stop_criterion,
    weight_seed=weight_seed,
    data_seed=data_seed,
    train_window_spacing=train_window_spacing,
    net_settings=net_settings,
    optim_settings=optim_settings,
)
Create a folder within `save_dir` for the output of this instance's run, and archive all example source code.
# Updates the 'instance_dir' key in `instance_settings` so that it holds the
# directory within `save_dir` where this example instance's output will be
# saved. Also returns the logger that is later handed to `train_and_eval`.
instance_settings, logger = initialize_instance(save_dir, instance_settings, logger_level)
Create a network object - a wrapper around a TensorFlow computation graph that implements training, evaluation, and inference with a U-Net - as well as settings dicts for the training and evaluation processes.
# Build the network object, plus the settings dicts for training and evaluation,
# from `instance_settings`
net, train_settings, eval_settings = create_network(instance_settings)
Note: With default settings and an NVIDIA Tesla P100 GPU, this takes about 1.7 hours.
# Train and evaluate using the settings dicts and logger created above.
# The trained network is returned and rebound to `net`.
net = train_and_eval(net, train_settings, eval_settings, logger)
# Set this to a directory to have the segmentation images saved to disk there
output_dir = None
# Segment the test images with the trained network
test_image_path = os.path.join(data_dir, 'test-images.tif')
test_label_path = os.path.join(data_dir, 'test-labels.tif')
test_seg, test_probs = segment(
    net_sources=net,
    image_source=test_image_path,
    label_source=test_label_path,
    output_dir=output_dir)
# Display entry 75 of the predicted segmentation. The color scale is pinned
# to the range 0-6 (the network is trained to segment into 7 classes).
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(400 / 110, 609 / 110))
ax.imshow(test_seg[75], cmap='jet', vmin=0, vmax=6)
ax.set_title('Test prediction')
# Hide the axes (the return value is bound to `h` rather than discarded)
h = ax.axis('off')
Compare the predictions with the ground-truth labels, and visualize the inference probability maps for each class.
# Demo helper from `src.demos`, given the trained network and the data folder
demo_segmentation(net, data_dir)