def unet_with_spatial_partition()

in mesh_tensorflow/experimental/unet.py [0:0]


def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
  """Builds the UNet model graph, predictions, loss and eval metrics.

  The network has `FLAGS.network_depth` levels. Going down, each level applies
  `FLAGS.n_conv_per_block` convolutions and (except at the deepest level) a
  max-pool with kernel size 2. Going up, each level applies an up-convolution,
  concatenates the result with the output saved from the matching down level,
  and applies `FLAGS.n_conv_per_block` convolutions again. A final convolution
  with filter size 1 produces `FLAGS.label_c` logits.

  Args:
    mesh: a MeshTensorflow.mesh object.
    mesh_impl: a mesh implementation, such as SimdMeshImpl and
      PlacementMeshImpl.
    dataset_str: a string of either train or eval. It selects the batch size
      flag (`FLAGS.batch_size_train` / `FLAGS.batch_size_eval`) and the
      `is_training` value passed to the conv/deconv helpers.
    images: a laid out Tensor with shape [batch, x, y, num_channels]
      or [batch, x, y, z, num_channels].
    labels: a laid out Tensor with shape [batch, x, y, num_classes]
      or [batch, x, y, z, num_classes].

  Returns:
    A tuple `(preds, loss, eval_metrics, all_bn_update_ops)`:
      preds: a list of mtf Tensors, in this order: the average-pooled softmax
        probability of class index 1 (liver), the average-pooled softmax
        probability of class index 2 (lesion), the average-pooled ground truth
        of class index 2 (only if `FLAGS.output_ground_truth`), and the
        per-batch-element `intersect` and `area_sum` of the predicted and
        true lesion masks.
      loss: an mtf Tensor, `FLAGS.dice_loss_weight` times the dice loss plus
        `1 - FLAGS.dice_loss_weight` times the weighted cross-entropy loss.
      eval_metrics: a dict mapping metric name to mtf Tensor.
      all_bn_update_ops: a list with every update op returned by the calls to
        `conv_with_spatial_partition`.

  Raises:
    ValueError: if `dataset_str` is neither 'train' nor 'eval'.
  """

  # Validate explicitly (an `assert` would be stripped under `python -O`) and
  # before any graph construction, so a bad argument leaves no partial graph.
  if dataset_str not in ('train', 'eval'):
    raise ValueError(
        "dataset_str must be 'train' or 'eval', got {!r}".format(dataset_str))

  is_training = (dataset_str == 'train')
  batch_dim = mtf.Dimension(
      'batch',
      FLAGS.batch_size_train if is_training else FLAGS.batch_size_eval)
  # x/y are split into n{x,y} blocks of s{x,y} = ct_resolution // n{x,y}
  # elements each; z is not split into blocks.
  image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block)
  image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block)
  image_sx_dim = mtf.Dimension('image_sx_block',
                               FLAGS.ct_resolution // FLAGS.image_nx_block)
  image_sy_dim = mtf.Dimension('image_sy_block',
                               FLAGS.ct_resolution // FLAGS.image_ny_block)
  image_sz_dim = mtf.Dimension('image_sz_block', FLAGS.ct_resolution)
  image_c_dim = mtf.Dimension('image_c', FLAGS.image_c)
  label_c_dim = mtf.Dimension('label_c', FLAGS.label_c)
  # Size-1 class dimension that remains after slicing one class out of
  # 'label_c'; it is reduced away wherever a single class is extracted.
  single_class_dim = mtf.Dimension('label_c', 1)
  mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str)

  mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype)
  variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype)

  # Import input features.
  x = mtf.import_laid_out_tensor(
      mesh,
      mesh_impl.LaidOutTensor(images),
      mtf_images_shape)
  x = mtf.cast(x, mtf_dtype)

  # Import ground truth labels.
  t = mtf.import_laid_out_tensor(
      mesh,
      mesh_impl.LaidOutTensor(labels),
      mtf_labels_shape)
  t = mtf.cast(t, mtf_dtype)

  # Transpose the blocks: block-count dims first, then within-block dims
  # (plus z for 3D volumes), then channels.
  spatial_dims = [image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim]
  if not FLAGS.sampled_2d_slices:
    spatial_dims.append(image_sz_dim)
  x = mtf.transpose(x, [batch_dim] + spatial_dims + [image_c_dim])
  t = mtf.transpose(t, [batch_dim] + spatial_dims + [label_c_dim])

  # Network.
  levels = []  # Output of each down level, reused as skip connection.
  all_bn_update_ops = []
  # add levels with convolution or down-sampling
  for depth in range(FLAGS.network_depth):
    n_filters = FLAGS.n_base_filters * (2**depth)
    for n_conv in range(FLAGS.n_conv_per_block):
      # no dropout in 1st layer.
      is_first_layer = (depth == 0 and n_conv == 0)
      dropout_keep_p = 1.0 if is_first_layer else FLAGS.dropout_keep_p
      x, bn_update_ops = conv_with_spatial_partition(
          x, FLAGS.sampled_2d_slices,
          image_nx_dim, image_ny_dim,
          n_filters,
          dropout_keep_p,
          FLAGS.with_batch_norm,
          is_training,
          'conv_{}_{}'.format(depth, n_conv),
          variable_dtype,
          'conv_down_{}_{}'.format(depth, n_conv))
      all_bn_update_ops.extend(bn_update_ops)
    levels.append(x)

    # Down-sample between levels, but not after the deepest one.
    if depth < FLAGS.network_depth - 1:
      if FLAGS.sampled_2d_slices:
        x = mtf.layers.max_pool2d(x, ksize=(2, 2))
      else:
        x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2))

  # add levels with up-convolution or up-sampling
  for depth in reversed(range(FLAGS.network_depth - 1)):
    n_filters = FLAGS.n_base_filters * (2**depth)
    # Same channel-dimension name as the last conv of the matching down level,
    # so the two tensors can be concatenated along it.
    skip_dim_name = 'conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1)
    x = deconv_with_spatial_partition(
        x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
        n_filters,
        FLAGS.dropout_keep_p,
        is_training,
        skip_dim_name,
        variable_dtype, 'deconv_{}_0'.format(depth))
    x = mtf.concat([x, levels[depth]], concat_dim_name=skip_dim_name)

    for n_conv in range(FLAGS.n_conv_per_block):
      x, bn_update_ops = conv_with_spatial_partition(
          x, FLAGS.sampled_2d_slices,
          image_nx_dim, image_ny_dim,
          n_filters,
          FLAGS.dropout_keep_p,
          FLAGS.with_batch_norm,
          is_training,
          'conv_{}_{}'.format(depth, n_conv),
          variable_dtype,
          'conv_up_{}_{}'.format(depth, n_conv))
      all_bn_update_ops.extend(bn_update_ops)

  # Final projection to class logits; no dropout in the final layer.
  final_conv_name = 'final_conv_{}'.format(FLAGS.label_c)
  if FLAGS.sampled_2d_slices:
    y = mtf.layers.conv2d_with_blocks(
        x, label_c_dim,
        filter_size=(1, 1), strides=(1, 1), padding='SAME',
        h_blocks_dim=image_nx_dim, w_blocks_dim=image_ny_dim,
        variable_dtype=variable_dtype,
        name=final_conv_name,
    )
  else:
    y = mtf.layers.conv3d_with_blocks(
        x, label_c_dim,
        filter_size=(1, 1, 1), strides=(1, 1, 1), padding='SAME',
        d_blocks_dim=image_nx_dim, h_blocks_dim=image_ny_dim,
        variable_dtype=variable_dtype,
        name=final_conv_name,
    )

  # use mtf.constant to make sure there is no CPU-side constants.
  def scalar(v, dtype):
    return mtf.constant(mesh, v, shape=[], dtype=dtype)

  # 0/1 masks of the ground-truth classes: index 1 (liver), index 2 (lesion).
  argmax_t = mtf.argmax(t, label_c_dim)
  liver_t = mtf.cast(mtf.equal(argmax_t, scalar(1, tf.int32)), mtf_dtype)
  lesion_t = mtf.cast(mtf.equal(argmax_t, scalar(2, tf.int32)), mtf_dtype)

  # 0/1 mask of positions predicted as class index 2 (lesion).
  argmax_y = mtf.argmax(y, label_c_dim)
  lesion_y = mtf.cast(mtf.equal(argmax_y, scalar(2, tf.int32)), mtf_dtype)

  # summary of class ratios.
  lesion_pred_ratio = mtf.reduce_mean(lesion_y)
  lesion_label_ratio = mtf.reduce_mean(lesion_t)

  # summary of accuracy.
  accuracy = mtf.reduce_mean(mtf.cast(mtf.equal(argmax_y, argmax_t), mtf_dtype))

  # Cross-entropy loss. Up-weight the liver region.
  # NOTE(review): liver_t and lesion_t are disjoint masks (argmax == 1 vs
  # argmax == 2), so a lesion position is weighted
  # 1 + xen_lesion_weight - xen_liver_weight, not xen_lesion_weight.
  # Confirm this is intended.
  pixel_loss = mtf.layers.softmax_cross_entropy_with_logits(y, t, label_c_dim)
  pixel_weight = (
      scalar(1, mtf_dtype) +
      liver_t * scalar(FLAGS.xen_liver_weight - 1, mtf_dtype) +
      lesion_t * scalar(FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight,
                        mtf_dtype))
  loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight)

  # Dice loss, computed on the soft lesion probability so it is
  # differentiable. Both sums keep only the batch dimension.
  y_prob = mtf.softmax(y, reduced_dim=label_c_dim)
  lesion_prob = mtf.reduce_sum(mtf.slice(y_prob, 2, 1, 'label_c'),
                               reduced_dim=single_class_dim)
  prob_intersect = mtf.reduce_sum(lesion_prob * lesion_t,
                                  output_shape=mtf.Shape([batch_dim]))
  prob_area_sum = mtf.reduce_sum(lesion_prob + lesion_t,
                                 output_shape=mtf.Shape([batch_dim]))
  # Per-case: mean over the batch of -2 * I / (A + eps).
  loss_dice_per_case = mtf.reduce_mean(
      scalar(-2, mtf_dtype) * prob_intersect / (
          prob_area_sum + scalar(FLAGS.dice_epsilon, mtf_dtype)))
  # Global: -2 * sum(I) / (sum(A) + eps) over the whole batch.
  loss_dice_global = scalar(-2, mtf_dtype) * mtf.reduce_sum(prob_intersect) / (
      mtf.reduce_sum(prob_area_sum) + scalar(FLAGS.dice_epsilon, mtf_dtype))

  loss_dice = (loss_dice_per_case + loss_dice_global) * scalar(0.5, mtf_dtype)

  loss = scalar(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + scalar(
      1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen

  # Hard (argmax-based) counterparts of the dice terms, used for metrics only.
  intersect = mtf.reduce_sum(lesion_y * lesion_t,
                             output_shape=mtf.Shape([batch_dim]))
  area_sum = mtf.reduce_sum(lesion_y + lesion_t,
                            output_shape=mtf.Shape([batch_dim]))
  # summary of dice. The small constant keeps the denominator non-zero when
  # both masks are empty.
  metric_epsilon = 0.000001
  dice_per_case = mtf.reduce_mean(scalar(2, mtf_dtype) * intersect / (
      area_sum + scalar(metric_epsilon, mtf_dtype)))
  dice_global = scalar(2, mtf_dtype) * mtf.reduce_sum(intersect) / (
      mtf.reduce_sum(area_sum) + scalar(metric_epsilon, mtf_dtype))

  eval_metrics = {
      'lesion_pred_ratio': lesion_pred_ratio,
      'lesion_label_ratio': lesion_label_ratio,
      'accuracy_of_all_classes': accuracy,
      'lesion_dice_per_case': dice_per_case,
      'lesion_dice_global': dice_global,
      'loss_xen': loss_xen,
      'loss_dice': loss_dice,
      'loss_dice_per_case': loss_dice_per_case,
      'loss_dice_global': loss_dice_global,
  }

  # Down-sample the outputs by FLAGS.pred_downsample along every spatial
  # axis (2 axes for sampled 2D slices, 3 for volumes).
  if FLAGS.sampled_2d_slices:
    avg_pool = mtf.layers.avg_pool2d
    pool_ksize = (FLAGS.pred_downsample,) * 2
  else:
    avg_pool = mtf.layers.avg_pool3d
    pool_ksize = (FLAGS.pred_downsample,) * 3

  y_prob_downsampled = avg_pool(y_prob, ksize=pool_ksize)
  if FLAGS.output_ground_truth:
    lesion_gt_downsampled = avg_pool(
        mtf.slice(t, 2, 1, 'label_c'), ksize=pool_ksize)

  liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c')
  lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c')
  preds = [
      mtf.reduce_sum(liver_prob_downsampled, reduced_dim=single_class_dim),
      mtf.reduce_sum(lesion_prob_downsampled, reduced_dim=single_class_dim)]

  if FLAGS.output_ground_truth:
    preds.append(mtf.reduce_sum(
        lesion_gt_downsampled, reduced_dim=single_class_dim))

  preds.extend([intersect, area_sum])

  return preds, loss, eval_metrics, all_bn_update_ops