def _combine_distributed_scaffold()

in tensorflow_estimator/python/estimator/estimator.py [0:0]


def _combine_distributed_scaffold(grouped_scaffold, distribution):
  """Combines scaffold(s) returned from `call_for_each_replica`."""

  # TODO(anjalisridhar): Figure out how to resolve the following scaffold
  # parameters: init_feed_dict, init_fn.
  scaffold_list = distribution.experimental_local_results(grouped_scaffold)
  init_feed_dict = [
      s.init_feed_dict for s in scaffold_list if s.init_feed_dict is not None
  ]
  if init_feed_dict:
    init_feed_dict = distribution.group(init_feed_dict)
  else:
    init_feed_dict = None

  init_fn = [
      s._user_init_fn for s in scaffold_list if s._user_init_fn is not None  # pylint: disable=protected-access
  ]
  if init_fn:
    init_fn = init_fn[0]
  else:
    init_fn = None

  init_op = [s.init_op for s in scaffold_list if s.init_op is not None]
  if init_op:
    init_op = distribution.group(init_op)
  else:
    init_op = None

  def _unwrap_and_concat(value):
    value = tf.nest.flatten(distribution.experimental_local_results(value))
    if len(value) != 1:
      return tf.concat(value, 0)
    return value[0]

  ready_op = distribution.extended.call_for_each_replica(
      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))
  if ready_op is not None:
    ready_op = _unwrap_and_concat(ready_op)

  ready_for_local_init_op = distribution.extended.call_for_each_replica(
      create_per_replica_ready_for_local_init_op, args=(grouped_scaffold,))
  if ready_for_local_init_op is not None:
    ready_for_local_init_op = _unwrap_and_concat(ready_for_local_init_op)
  else:
    ready_for_local_init_op = None

  local_init_op = [
      s.local_init_op for s in scaffold_list if s.local_init_op is not None
  ]
  if local_init_op:
    local_init_op = distribution.group(local_init_op)
  else:
    local_init_op = None

  summary_op = [s.summary_op for s in scaffold_list if s.summary_op is not None]
  if summary_op:
    summary_op = distribution.group(summary_op)
  else:
    summary_op = None

  savers = [s.saver for s in scaffold_list if s.saver is not None]
  if savers:
    saver = savers[0]
  else:
    saver = None

  scaffold = tf.compat.v1.train.Scaffold(
      init_op=init_op,
      ready_op=ready_op,
      ready_for_local_init_op=ready_for_local_init_op,
      local_init_op=local_init_op,
      summary_op=summary_op,
      saver=saver,
      init_feed_dict=init_feed_dict,
      init_fn=init_fn)
  return scaffold