def check_partial_batch_sizes()

in kfac/python/ops/fisher_factors.py [0:0]


  def check_partial_batch_sizes(self):
    """Ensures partial batch sizes are equal across all towers and sources.

    The per-tower normalizations are summed across towers and divided by the
    number of towers, which is only correct if every tower contributes the
    same batch size — so any mismatch here indicates misregistered Tensors
    or misspecified batch_size / num_uses / num_timesteps arguments.

    Returns:
      A `tf.no_op` when the sizes are static ints (validated eagerly) or when
      there is only a single source and tower; otherwise a grouped op of
      `tf.assert_equal` checks that validates the dynamic sizes at run time.

    Raises:
      ValueError: If the sizes are static ints and any (source, tower) pair
        disagrees with the default partial batch size.
    """
    # Should make these messages use quote characters instead of parentheses
    # when the bug with quote character rendering in assertion messages is
    # fixed. See b/129476712
    msg = ("Inconsistent (partial) batch sizes detected for factor ({}) of type"
           " {}. This can be caused by passing Tensors with the wrong sizes to "
           "the registration functions, or misspecification of arguments like "
           "batch_size, num_uses, or num_timesteps.".format(
               self.name, utils.cls_name(self)))

    # Reference size that every (source, tower) combination must match.
    partial_batch_size = self._partial_batch_size()

    if self._num_sources > 1 or self._num_towers > 1:
      # BUG FIX: previously this used zip(range(num_sources),
      # range(num_towers)), which only visits the diagonal pairs
      # (0, 0), (1, 1), ... and truncates at the shorter range — e.g. with
      # one source and several towers, towers beyond index 0 were never
      # checked. We must compare every (source, tower) combination.
      source_tower_pairs = tuple(
          (source, tower)
          for source in range(self._num_sources)
          for tower in range(self._num_towers))

      if isinstance(partial_batch_size, int):
        # Static (Python int) sizes: validate eagerly at graph-build time.
        checks = tuple(
            partial_batch_size == self._partial_batch_size(source=source,
                                                           tower=tower)
            for source, tower in source_tower_pairs)
        if not all(checks):
          raise ValueError(msg)

        return tf.no_op()

      else:
        # Dynamic (Tensor) sizes: defer validation to run time via assert ops.
        asserts = tuple(
            tf.assert_equal(partial_batch_size,
                            self._partial_batch_size(source=source,
                                                     tower=tower),
                            message=msg)
            for source, tower in source_tower_pairs)
        return tf.group(asserts)

    # Single source and single tower: nothing to cross-check.
    return tf.no_op()