in kfac/python/ops/tensormatch/graph_search.py [0:0]
def register_records(layer_collection,
record_list_dict,
reuse=False,
batch_size=None):
"""Registers the given records to layer_collection.
Args:
layer_collection: A `LayerCollection` to use for registering layers.
record_list_dict: A dict mapping tuples of variables to lists of
`MatchRecord`s representing all of the places those variables are used
in the graph.
    reuse: (OPTIONAL) bool. If True, `layer_collection` reuses the
      previously registered block whose key matches the key derived from
      that block's `params`. If False, a new block is registered.
    batch_size: An `int` giving the batch size. Must be specified when
      registering generic variables that don't match any layer pattern, or
      when the time/uses dimension is folded into the batch dimension. In
      the latter case it is used to infer the number of uses/time-steps.

Raises:
    ValueError: If `record_list_dict` contains multiple record types for a
      single set of variables, if there are multiple records for a set of
      variables of a type that doesn't support shared parameters, or if a
      record has an unrecognized type.
AmbiguousRegistrationError: If a batch norm layer registration is required
but batch_size is not passed.
"""
mixed_record_type_errors = []
# TODO(b/69627702): Layers must be registered in a deterministic order, else
# FisherFactors may end up with different variable names.
params_list = sorted(record_list_dict.keys(), key=str)
for params in params_list:
record_list = record_list_dict[params]
# We don't support mixed types for the same params and probably never
# will.
if not all(record_list[0].record_type == record.record_type
for record in record_list):
mixed_record_type_errors.append(
'Detected variables {} with mixed record types: {}.'.format(
params, record_list))
continue
record_type = record_list[0].record_type
if batch_size:
      # If the time dimension has been folded into the batch dimension of
      # the inputs then we need to set `num_uses`.
first_dim = record_list[0].data['inputs'].shape.as_list()[0]
is_batch_time_folded = not (first_dim is None or first_dim == batch_size)
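      # Example: with per-replica batch size b and inputs reshaped to
      # [b * T, d] (time steps folded into the batch dimension), first_dim
      # is b * T and we infer num_uses = T below.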
if is_batch_time_folded:
num_uses = first_dim // batch_size
if num_uses == 0:
raise ValueError('It looks like the batch_size passed to the auto-'
'registration function was larger than expected. '
'The likely cause of this is that you passed in '
'the overall batch size instead of the per-replica '
'batch size. When using K-FAC with replication all '
'batch sizes passed to K-FAC and its helper modules '
'should be their per-replica sizes.')
if record_type is RecordType.fully_connected:
if len(record_list) > 1:
logging.info(
'Registering as multi fully-connected: {}'.format(params))
inputs = tuple(record.data['inputs'] for record in record_list)
outputs = tuple(record.data['outputs'] for record in record_list)
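        # Shared parameters: register a single Fisher block covering all
        # uses of these variables.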
layer_collection.register_fully_connected_multi(
params, inputs, outputs, reuse=reuse)
else:
logging.info('Registering as fully-connected: {}'.format(params))
record = record_list[0]
inputs = record.data['inputs']
outputs = record.data['outputs']
if batch_size and is_batch_time_folded:
layer_collection.register_fully_connected_multi(
params, inputs, outputs, num_uses=num_uses, reuse=reuse)
else:
layer_collection.register_fully_connected(
params, inputs, outputs, reuse=reuse)
elif record_type is RecordType.conv2d:
if len(record_list) > 1:
logging.info('Registering as multi conv2d: {}'.format(params))
inputs = tuple(record.data['inputs'] for record in record_list)
outputs = tuple(record.data['outputs'] for record in record_list)
strides = record_list[0].data['strides']
padding = record_list[0].data['padding']
data_format = record_list[0].data['data_format']
layer_collection.register_conv2d_multi(
params,
strides,
padding,
inputs,
outputs,
data_format=data_format,
reuse=reuse)
else:
logging.info('Registering as conv2d: {}'.format(params))
record = record_list[0]
inputs = record.data['inputs']
outputs = record.data['outputs']
strides = record.data['strides']
padding = record.data['padding']
data_format = record.data['data_format']
if batch_size and is_batch_time_folded:
layer_collection.register_conv2d_multi(
params,
strides,
padding,
inputs,
outputs,
data_format=data_format,
num_uses=num_uses,
reuse=reuse)
else:
layer_collection.register_conv2d(params, strides, padding, inputs,
outputs, data_format=data_format,
reuse=reuse)
elif record_type is RecordType.scale_and_shift:
      if len(record_list) > 1:
        raise ValueError('Multi-use registrations are currently not '
                         'supported for scale & shift operations.')
      logging.info('Registering as scale (& shift): {}'.format(params))
record = record_list[0]
inputs = record.data['inputs']
outputs = record.data['outputs']
layer_collection.register_scale_and_shift(params, inputs, outputs,
reuse=reuse)
elif record_type is RecordType.batch_norm:
logging.info('Registering as batch norm: {}'.format(params))
if batch_size is None:
raise AmbiguousRegistrationError(
'Tried to register a batch norm layer (as generic) without '
'knowledge of batch_size. You can pass batch_size in to fix this '
'error.')
      # This is a slight hack. Ideally `register_generic` would accept lists
      # of params, as it did before we switched to the "unflattened" cov
      # representation; then we wouldn't need to detect the approximation
      # type here.
will_use_diag = (
layer_collection._get_linked_approx(params) == 'diagonal' # pylint: disable=protected-access
or (layer_collection.default_generic_approximation == 'diagonal'
and layer_collection._get_linked_approx(params) is None) # pylint: disable=protected-access
)
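      # Under a diagonal approximation, register the two batch norm
      # parameters as separate generic blocks; otherwise register them
      # jointly as a single generic block.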
if will_use_diag:
layer_collection.register_generic(params[0], batch_size, reuse=reuse)
layer_collection.register_generic(params[1], batch_size, reuse=reuse)
else:
layer_collection.register_generic(params, batch_size, reuse=reuse)
    else:
      raise ValueError('Invalid record type: {}'.format(record_type))
if mixed_record_type_errors:
raise ValueError('\n'.join(mixed_record_type_errors))
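
# Illustrative usage sketch (assumptions, not documented API: the
# `LayerCollection` constructor call below and the exact shape of
# `record_list_dict`, which is normally built by the pattern-matching pass
# in this module):
#
#   lc = layer_collection.LayerCollection()
#   record_list_dict = ...  # e.g. {(weights, bias): [MatchRecord, ...], ...}
#   register_records(lc, record_list_dict, batch_size=32)
#
# `batch_size` must be the per-replica batch size; it is only needed when
# batch norm / generic variables are present or when the time dimension has
# been folded into the batch dimension.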