in easy_rec/python/layers/keras/einsum_dense.py [0:0]
def _analyze_split_string(split_string,
bias_axes,
input_shape,
output_shape,
left_elided=False):
"""Analyze an pre-split einsum string to find the weight shape."""
input_spec = split_string.group(1)
weight_spec = split_string.group(2)
output_spec = split_string.group(3)
elided = len(input_shape) - len(input_spec)
if isinstance(output_shape, int):
output_shape = [output_shape]
else:
output_shape = list(output_shape)
output_shape.insert(0, input_shape[0])
if elided > 0 and left_elided:
for i in range(1, elided):
# We already inserted the 0th input dimension at dim 0, so we need
# to start at location 1 here.
output_shape.insert(1, input_shape[i])
elif elided > 0 and not left_elided:
for i in range(len(input_shape) - elided, len(input_shape)):
output_shape.append(input_shape[i])
if left_elided:
# If we have beginning dimensions elided, we need to use negative
# indexing to determine where in the input dimension our values are.
input_dim_map = {
dim: (i + elided) - len(input_shape) for i, dim in enumerate(input_spec)
}
# Because we've constructed the full output shape already, we don't need
# to do negative indexing.
output_dim_map = {dim: (i + elided) for i, dim in enumerate(output_spec)}
else:
input_dim_map = {dim: i for i, dim in enumerate(input_spec)}
output_dim_map = {dim: i for i, dim in enumerate(output_spec)}
for dim in input_spec:
input_shape_at_dim = input_shape[input_dim_map[dim]]
if dim in output_dim_map:
output_shape_at_dim = output_shape[output_dim_map[dim]]
if (output_shape_at_dim is not None and
output_shape_at_dim != input_shape_at_dim):
raise ValueError(
'Input shape and output shape do not match at shared '
"dimension '{dim}'. Input shape is {input_shape_at_dim}, "
'and output shape is {output_shape}.'.format(
dim=dim,
input_shape_at_dim=input_shape_at_dim,
output_shape=output_shape[output_dim_map[dim]]))
for dim in output_spec:
if dim not in input_spec and dim not in weight_spec:
raise ValueError(
"Dimension '{dim}' was specified in the output "
"'{output_spec}' but has no corresponding dim in the input "
"spec '{input_spec}' or weight spec '{output_spec}'".format(
dim=dim, output_spec=output_spec, input_spec=input_spec))
weight_shape = []
for dim in weight_spec:
if dim in input_dim_map:
weight_shape.append(input_shape[input_dim_map[dim]])
elif dim in output_dim_map:
weight_shape.append(output_shape[output_dim_map[dim]])
else:
raise ValueError(
"Weight dimension '{dim}' did not have a match in either "
"the input spec '{input_spec}' or the output "
"spec '{output_spec}'. For this layer, the weight must "
'be fully specified.'.format(
dim=dim, input_spec=input_spec, output_spec=output_spec))
if bias_axes is not None:
num_left_elided = elided if left_elided else 0
idx_map = {
char: output_shape[i + num_left_elided]
for i, char in enumerate(output_spec)
}
for char in bias_axes:
if char not in output_spec:
raise ValueError(
"Bias dimension '{char}' was requested, but is not part "
"of the output spec '{output_spec}'".format(
char=char, output_spec=output_spec))
first_bias_location = min([output_spec.find(char) for char in bias_axes])
bias_output_spec = output_spec[first_bias_location:]
bias_shape = [
idx_map[char] if char in bias_axes else 1 for char in bias_output_spec
]
if not left_elided:
for _ in range(elided):
bias_shape.append(1)
else:
bias_shape = None
return weight_shape, bias_shape, output_shape