in mesh_tensorflow/experimental/unet.py
def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
"""Builds the UNet model graph, train op and eval metrics.
Args:
mesh: a MeshTensorflow.mesh object.
mesh_impl: a mesh implementation, such as SimdMeshImpl and
PlacementMeshImpl.
dataset_str: a string of either train or eval. This is used for batch_norm.
images: a laid out Tensor with shape [batch, x, y, num_channels]
or [batch, x, y, z, num_channels].
labels: a laid out Tensor with shape [batch, x, y, num_classes]
or [batch, x, y, z, num_classes].
Returns:
Prediction and loss.
"""
  is_training = (dataset_str == 'train')
  if dataset_str == 'train':
    batch_dim = mtf.Dimension('batch', FLAGS.batch_size_train)
  else:
    assert dataset_str == 'eval'
    batch_dim = mtf.Dimension('batch', FLAGS.batch_size_eval)
  image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block)
  image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block)
  image_sx_dim = mtf.Dimension('image_sx_block',
                               FLAGS.ct_resolution // FLAGS.image_nx_block)
  image_sy_dim = mtf.Dimension('image_sy_block',
                               FLAGS.ct_resolution // FLAGS.image_ny_block)
  image_sz_dim = mtf.Dimension('image_sz_block', FLAGS.ct_resolution)
  image_c_dim = mtf.Dimension('image_c', FLAGS.image_c)
  label_c_dim = mtf.Dimension('label_c', FLAGS.label_c)
  mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str)

  mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype)
  variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype)
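  # mtf.VariableDType takes (master_dtype, slice_dtype, activation_dtype);
  # using one dtype for all three keeps master variables, their slices, and
  # activations in FLAGS.mtf_dtype (e.g. bfloat16 on TPU).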

  # Import input features.
  x = mtf.import_laid_out_tensor(
      mesh,
      mesh_impl.LaidOutTensor(images),
      mtf_images_shape)
  x = mtf.cast(x, mtf_dtype)

  # Import ground truth labels.
  t = mtf.import_laid_out_tensor(
      mesh,
      mesh_impl.LaidOutTensor(labels),
      mtf_labels_shape)
  t = mtf.cast(t, mtf_dtype)
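  # mtf.import_laid_out_tensor wraps tensors that are already laid out across
  # devices into the Mesh TensorFlow graph under the given mtf shapes, rather
  # than re-distributing them from a single host tensor.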

  # Transpose the blocks.
  if FLAGS.sampled_2d_slices:
    x = mtf.transpose(x, [batch_dim,
                          image_nx_dim, image_ny_dim,
                          image_sx_dim, image_sy_dim,
                          image_c_dim])
    t = mtf.transpose(t, [batch_dim,
                          image_nx_dim, image_ny_dim,
                          image_sx_dim, image_sy_dim,
                          label_c_dim])
  else:
    x = mtf.transpose(x, [batch_dim,
                          image_nx_dim, image_ny_dim,
                          image_sx_dim, image_sy_dim,
                          image_sz_dim, image_c_dim])
    t = mtf.transpose(t, [batch_dim,
                          image_nx_dim, image_ny_dim,
                          image_sx_dim, image_sy_dim,
                          image_sz_dim, label_c_dim])
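  # After the transpose, tensors are laid out as
  # [batch, nx, ny, sx, sy(, sz), channels]: the nx/ny block dimensions can be
  # split across mesh devices, while sx/sy(/sz) index within-block voxels.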

  # Network.
  levels = []
  all_bn_update_ops = []

  # Add levels with convolution or down-sampling.
  for depth in range(FLAGS.network_depth):
    for n_conv in range(FLAGS.n_conv_per_block):
      if depth == 0 and n_conv == 0:
        # No dropout in the first layer.
        dropout_keep_p = 1.0
      else:
        dropout_keep_p = FLAGS.dropout_keep_p
      x, bn_update_ops = conv_with_spatial_partition(
          x, FLAGS.sampled_2d_slices,
          image_nx_dim, image_ny_dim,
          FLAGS.n_base_filters * (2**depth),
          dropout_keep_p,
          FLAGS.with_batch_norm,
          is_training,
          'conv_{}_{}'.format(depth, n_conv),
          variable_dtype,
          'conv_down_{}_{}'.format(depth, n_conv))
      all_bn_update_ops.extend(bn_update_ops)
    levels.append(x)

    if depth < FLAGS.network_depth - 1:
      if FLAGS.sampled_2d_slices:
        x = mtf.layers.max_pool2d(x, ksize=(2, 2))
      else:
        x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2))
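  # This is the standard UNet contracting path: each level doubles the filter
  # count, and all but the deepest level halve the within-block spatial size
  # with 2x max-pooling.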

  # Add levels with up-convolution or up-sampling.
  for depth in range(FLAGS.network_depth - 1)[::-1]:
    x = deconv_with_spatial_partition(
        x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
        FLAGS.n_base_filters * (2**depth),
        FLAGS.dropout_keep_p,
        is_training,
        'conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1),
        variable_dtype, 'deconv_{}_0'.format(depth))
    x = mtf.concat(
        [x, levels[depth]],
        concat_dim_name='conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1))
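    # The concat above is the UNet skip connection: up-sampled features are
    # joined, along the channel dimension, with the same-depth features saved
    # on the contracting path.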

    for n_conv in range(FLAGS.n_conv_per_block):
      x, bn_update_ops = conv_with_spatial_partition(
          x, FLAGS.sampled_2d_slices,
          image_nx_dim, image_ny_dim,
          FLAGS.n_base_filters * (2**depth),
          FLAGS.dropout_keep_p,
          FLAGS.with_batch_norm,
          is_training,
          'conv_{}_{}'.format(depth, n_conv),
          variable_dtype,
          'conv_up_{}_{}'.format(depth, n_conv))
      all_bn_update_ops.extend(bn_update_ops)

  # No dropout in the final layer.
  if FLAGS.sampled_2d_slices:
    y = mtf.layers.conv2d_with_blocks(
        x, mtf.Dimension('label_c', FLAGS.label_c),
        filter_size=(1, 1), strides=(1, 1), padding='SAME',
        h_blocks_dim=image_nx_dim, w_blocks_dim=image_ny_dim,
        variable_dtype=variable_dtype,
        name='final_conv_{}'.format(FLAGS.label_c),
    )
  else:
    y = mtf.layers.conv3d_with_blocks(
        x, mtf.Dimension('label_c', FLAGS.label_c),
        filter_size=(1, 1, 1), strides=(1, 1, 1), padding='SAME',
        d_blocks_dim=image_nx_dim, h_blocks_dim=image_ny_dim,
        variable_dtype=variable_dtype,
        name='final_conv_{}'.format(FLAGS.label_c),
    )
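  # The 1x1(x1) convolution above maps the feature channels to per-class
  # logits `y`; it mixes no spatial information.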

  # Use mtf.constant to make sure there are no CPU-side constants.
  def scalar(v, dtype):
    return mtf.constant(mesh, v, shape=[], dtype=dtype)
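
  # Class indices implied by the masks below: 0 = background, 1 = liver,
  # 2 = lesion.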
  argmax_t = mtf.argmax(t, label_c_dim)
  liver_t = mtf.cast(mtf.equal(argmax_t, scalar(1, tf.int32)), mtf_dtype)
  lesion_t = mtf.cast(mtf.equal(argmax_t, scalar(2, tf.int32)), mtf_dtype)
  argmax_y = mtf.argmax(y, label_c_dim)
  lesion_y = mtf.cast(mtf.equal(argmax_y, scalar(2, tf.int32)), mtf_dtype)

  # Summary of class ratios.
  lesion_pred_ratio = mtf.reduce_mean(lesion_y)
  lesion_label_ratio = mtf.reduce_mean(lesion_t)

  # Summary of accuracy.
  accuracy = mtf.reduce_mean(mtf.cast(mtf.equal(argmax_y, argmax_t), mtf_dtype))

  # Cross-entropy loss. Up-weight the liver region.
  pixel_loss = mtf.layers.softmax_cross_entropy_with_logits(y, t, label_c_dim)
  pixel_weight = scalar(1, mtf_dtype) + \
      liver_t * scalar(FLAGS.xen_liver_weight - 1, mtf_dtype) + \
      lesion_t * scalar(FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight,
                        mtf_dtype)
  loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight)
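  # liver_t and lesion_t are mutually exclusive 0/1 masks, so the effective
  # weights are: background 1, liver FLAGS.xen_liver_weight, and lesion
  # 1 + FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight.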

  # Dice loss.
  y_prob = mtf.softmax(y, reduced_dim=label_c_dim)
  lesion_prob = mtf.reduce_sum(mtf.slice(y_prob, 2, 1, 'label_c'),
                               reduced_dim=mtf.Dimension('label_c', 1))
  prob_intersect = mtf.reduce_sum(lesion_prob * lesion_t,
                                  output_shape=mtf.Shape([batch_dim]))
  prob_area_sum = mtf.reduce_sum(lesion_prob + lesion_t,
                                 output_shape=mtf.Shape([batch_dim]))
  loss_dice_per_case = mtf.reduce_mean(
      scalar(-2, mtf_dtype) * prob_intersect / (
          prob_area_sum + scalar(FLAGS.dice_epsilon, mtf_dtype)))
  loss_dice_global = scalar(-2, mtf_dtype) * mtf.reduce_sum(prob_intersect) / (
      mtf.reduce_sum(prob_area_sum) + scalar(FLAGS.dice_epsilon, mtf_dtype))
  loss_dice = (loss_dice_per_case + loss_dice_global) * scalar(0.5, mtf_dtype)
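  # Both terms are negated soft Dice scores, dice(A, B) = 2|A∩B| / (|A| + |B|),
  # computed on the lesion probability map: `per_case` averages the score over
  # cases in the batch, while `global` pools the whole batch into one score.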

  loss = scalar(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + scalar(
      1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen

  intersect = mtf.reduce_sum(lesion_y * lesion_t,
                             output_shape=mtf.Shape([batch_dim]))
  area_sum = mtf.reduce_sum(lesion_y + lesion_t,
                            output_shape=mtf.Shape([batch_dim]))

  # Summary of dice.
  dice_per_case = mtf.reduce_mean(scalar(2, mtf_dtype) * intersect / (
      area_sum + scalar(0.000001, mtf_dtype)))
  dice_global = scalar(2, mtf_dtype) * mtf.reduce_sum(intersect) / (
      mtf.reduce_sum(area_sum) + scalar(0.000001, mtf_dtype))

  eval_metrics = {
      'lesion_pred_ratio': lesion_pred_ratio,
      'lesion_label_ratio': lesion_label_ratio,
      'accuracy_of_all_classes': accuracy,
      'lesion_dice_per_case': dice_per_case,
      'lesion_dice_global': dice_global,
      'loss_xen': loss_xen,
      'loss_dice': loss_dice,
      'loss_dice_per_case': loss_dice_per_case,
      'loss_dice_global': loss_dice_global,
  }

  if FLAGS.sampled_2d_slices:
    y_prob_downsampled = mtf.layers.avg_pool2d(
        y_prob, ksize=(FLAGS.pred_downsample,) * 2)
    if FLAGS.output_ground_truth:
      lesion_gt_downsampled = mtf.layers.avg_pool2d(
          mtf.slice(t, 2, 1, 'label_c'), ksize=(FLAGS.pred_downsample,) * 2)
  else:
    y_prob_downsampled = mtf.layers.avg_pool3d(
        y_prob, ksize=(FLAGS.pred_downsample,) * 3)
    if FLAGS.output_ground_truth:
      lesion_gt_downsampled = mtf.layers.avg_pool3d(
          mtf.slice(t, 2, 1, 'label_c'), ksize=(FLAGS.pred_downsample,) * 3)

  liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c')
  lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c')
  preds = [
      mtf.reduce_sum(liver_prob_downsampled,
                     reduced_dim=mtf.Dimension('label_c', 1)),
      mtf.reduce_sum(lesion_prob_downsampled,
                     reduced_dim=mtf.Dimension('label_c', 1))]

  if FLAGS.output_ground_truth:
    preds.append(mtf.reduce_sum(
        lesion_gt_downsampled, reduced_dim=mtf.Dimension('label_c', 1)))

  preds.extend([intersect, area_sum])

  return preds, loss, eval_metrics, all_bn_update_ops
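

# A minimal sketch of how this builder might be driven, assuming the mesh
# setup and laid-out input tensors are created elsewhere (the names below are
# illustrative, not exports of this module):
#
#   graph = mtf.Graph()
#   mesh = mtf.Mesh(graph, 'my_mesh')
#   mesh_impl = ...  # e.g. a SimdMeshImpl on TPU, or a PlacementMeshImpl.
#   preds, loss, eval_metrics, bn_update_ops = unet_with_spatial_partition(
#       mesh, mesh_impl, 'train', images, labels)
#   lowering = mtf.Lowering(graph, {mesh: mesh_impl})
#   tf_loss = lowering.export_to_tf_tensor(loss)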