in kfac/python/ops/fisher_factors.py [0:0]
def check_partial_batch_sizes(self):
  """Ensures partial batch sizes are equal across towers and sources.

  Compares the partial batch size returned by `self._partial_batch_size()`
  (called with no arguments) against the partial batch size reported for every
  (source, tower) combination, i.e. every pair in
  `range(self._num_sources) x range(self._num_towers)`.

  Returns:
    `tf.no_op()` when there is only one source and one tower, or when the
    reference partial batch size is a Python `int` and all comparisons pass
    statically. Otherwise, a `tf.group` of `tf.assert_equal` ops that perform
    the comparisons when run.

  Raises:
    ValueError: If the reference partial batch size is a Python `int` and any
      (source, tower) pair reports a different partial batch size.
  """
  # While it could be okay in principle to have different batch sizes for
  # different towers, the way the code has been written isn't compatible with
  # this. Basically, the normalizations occur for each tower and then the
  # results are summed across towers and divided by the number of towers.
  # The only way this is correct is if the towers all have the same batch
  # size.

  # NOTE: These messages should use quote characters instead of parentheses
  # once the bug with quote-character rendering in assertion messages is
  # fixed. See b/129476712.
  msg = ("Inconsistent (partial) batch sizes detected for factor ({}) of type"
         " {}. This can be caused by passing Tensors with the wrong sizes to "
         "the registration functions, or misspecification of arguments like "
         "batch_size, num_uses, or num_timesteps.".format(
             self.name, utils.cls_name(self)))

  partial_batch_size = self._partial_batch_size()

  if self._num_sources > 1 or self._num_towers > 1:
    # Visit the full Cartesian product of sources and towers. Zipping the two
    # ranges would only check the diagonal pairs (0, 0), (1, 1), ... and would
    # stop at the shorter range (e.g. with a single source, no tower beyond
    # tower 0 would ever be checked).
    other_sizes = [
        self._partial_batch_size(source=source, tower=tower)
        for source in range(self._num_sources)
        for tower in range(self._num_towers)
    ]

    if isinstance(partial_batch_size, int):
      # Static sizes: the mismatch can be reported immediately.
      if not all(partial_batch_size == size for size in other_sizes):
        raise ValueError(msg)
      return tf.no_op()

    # Dynamic sizes: defer the comparison to graph execution time.
    asserts = tuple(
        tf.assert_equal(partial_batch_size, size, message=msg)
        for size in other_sizes)
    return tf.group(asserts)

  return tf.no_op()