in tensorflow_lattice/python/rtl_layer.py [0:0]
def _get_rtl_structure(self, input_shape):
  """Returns the RTL structure for the given input_shape.

  The flattened inputs are shuffled with a `np.random.RandomState` seeded by
  `self.random_seed`, repeated to fill `self.num_lattices` lattices of rank
  `self.lattice_rank`, and, if `self.avoid_intragroup_interaction` is set,
  swapped between lattices to try to keep inputs of the same group out of the
  same lattice.

  Args:
    input_shape: Input shape to the layer. Must be a dict matching the format
      described in the layer description; any non-dict value is treated as the
      shape of the 'unconstrained' inputs.

  Raises:
    ValueError: If the input contains a key other than 'unconstrained' or
      'increasing', if it contains no input features at all, or if the
      structure is too small to include all the inputs.

  Returns:
    A list of `(monotonicities, lattices)` tuples sorted by `monotonicities`,
    where `monotonicities` is the tuple of lattice monotonicities, and
    `lattices` is a list of lists of indices into the flattened input to the
    layer.
  """
  if not isinstance(input_shape, dict):
    input_shape = {'unconstrained': input_shape}

  # Calculate the flattened input to the RTL layer. rtl_inputs will be a list
  # of _RTLInput items, each including information about the monotonicity,
  # input group and input index for each input to the layer.
  # The order for flattening should match the order in the call method.
  rtl_inputs = []
  group = 0  # group id for the input
  input_index = 0  # index into the flattened input
  for input_key in sorted(input_shape.keys()):
    shapes = input_shape[input_key]
    if input_key == 'unconstrained':
      monotonicity = 0
    elif input_key == 'increasing':
      monotonicity = 1
    else:
      raise ValueError(
          'Unrecognized key in the input to the RTL layer: {}'.format(
              input_key))

    if not isinstance(shapes, list):
      # Get the shape after a split. See single dense tensor input format in
      # the layer comments.
      shapes = [(shapes[0], 1)] * shapes[1]

    for shape in shapes:
      # All shape[1] flattened inputs produced by one shape share a group id.
      for _ in range(shape[1]):
        rtl_inputs.append(
            _RTLInput(
                monotonicity=monotonicity,
                group=group,
                input_index=input_index))
        input_index += 1
      group += 1

  if not rtl_inputs:
    # Without this check the repeat count computed below would divide by zero.
    raise ValueError('RTL layer received no input features')

  total_usage = self.num_lattices * self.lattice_rank
  if total_usage < len(rtl_inputs):
    raise ValueError(
        'RTL layer with {}x{}D lattices is too small to use all the {} input '
        'features'.format(self.num_lattices, self.lattice_rank,
                          len(rtl_inputs)))

  # Repeat the features to fill all the slots in the RTL layer.
  # NOTE: the number and order of the shuffles below determine which inputs
  # end up in which lattice for a given seed; keep them stable.
  rs = np.random.RandomState(self.random_seed)
  rs.shuffle(rtl_inputs)
  rtl_inputs = rtl_inputs * (1 + total_usage // len(rtl_inputs))
  rtl_inputs = rtl_inputs[:total_usage]
  rs.shuffle(rtl_inputs)

  # Start with random lattices, possibly with repeated groups in lattices.
  lattices = [
      rtl_inputs[lattice_index * self.lattice_rank:(lattice_index + 1) *
                 self.lattice_rank]
      for lattice_index in range(self.num_lattices)
  ]

  # Swap features between lattices to make sure only a single input from each
  # group is used in each lattice. Every accepted swap strictly reduces the
  # number of repeated groups in lattice_0 without adding a repeat anywhere.
  # `iteration` counts full passes over all lattice pairs (not single swaps).
  # NOTE(review): a swap is only triggered by a repeat inside lattice_0, and
  # itertools.combinations yields each pair once with lattice_0 earlier in the
  # list, so repeats in later lattices (in particular the last one) are not
  # directly targeted. Making the condition symmetric would change the
  # structure produced for existing seeds, so it is intentionally left as is.
  changed = True
  iteration = 0
  while changed and self.avoid_intragroup_interaction:
    if iteration > _MAX_RTL_SWAPS:
      logging.info('Some lattices in the RTL layer might use features from '
                   'the same input group')
      break
    changed = False
    iteration += 1
    for lattice_0, lattice_1 in itertools.combinations(lattices, 2):
      # For every pair of lattices: lattice_0, lattice_1
      for index_0, index_1 in itertools.product(
          range(len(lattice_0)), range(len(lattice_1))):
        # Consider swapping lattice_0[index_0] with lattice_1[index_1]. The
        # lattices are re-read each step since earlier swaps mutate them.
        rest_lattice_0 = list(lattice_0)
        rest_lattice_1 = list(lattice_1)
        feature_0 = rest_lattice_0.pop(index_0)
        feature_1 = rest_lattice_1.pop(index_1)
        if feature_0.group == feature_1.group:
          continue

        # Swap if feature_0's group is repeated in lattice_0 and moving it
        # (and bringing feature_1 in) introduces no repeat in either lattice.
        rest_lattice_groups_0 = {
            lattice_input.group for lattice_input in rest_lattice_0}
        rest_lattice_groups_1 = {
            lattice_input.group for lattice_input in rest_lattice_1}
        if ((feature_0.group in rest_lattice_groups_0) and
            (feature_0.group not in rest_lattice_groups_1) and
            (feature_1.group not in rest_lattice_groups_0)):
          lattice_0[index_0], lattice_1[index_1] = (lattice_1[index_1],
                                                    lattice_0[index_0])
          changed = True

  # Arrange into combined lattices layers. Lattices with identical
  # monotonicities can use the same tfl.layers.Lattice layer.
  # Create a dict: monotonicities tuple -> list of list of input indices.
  lattices_for_monotonicities = collections.defaultdict(list)
  for lattice in lattices:
    # list.sort is stable: inputs with equal monotonicity keep their order.
    lattice.sort(key=lambda lattice_input: lattice_input.monotonicity)
    monotonicities = tuple(
        lattice_input.monotonicity for lattice_input in lattice)
    lattice_input_indices = [
        lattice_input.input_index for lattice_input in lattice
    ]
    lattices_for_monotonicities[monotonicities].append(lattice_input_indices)
  return sorted(lattices_for_monotonicities.items())