in easy_rec/python/layers/keras/einsum_dense.py [0:0]
def _analyze_quantization_info(equation, input_shape):
def get_specs(equation, input_shape):
possible_labels = string.ascii_letters
dot_replaced_string = re.sub(r'\.\.\.', '0', equation)
# This is the case where no ellipses are present in the string.
split_string = re.match('([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)',
dot_replaced_string)
if split_string is not None:
input_spec = split_string.group(1)
weight_spec = split_string.group(2)
output_spec = split_string.group(3)
return input_spec, weight_spec, output_spec
# This is the case where ellipses are present on the left.
split_string = re.match('0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)',
dot_replaced_string)
if split_string is not None:
input_spec = split_string.group(1)
weight_spec = split_string.group(2)
output_spec = split_string.group(3)
elided = len(input_shape) - len(input_spec)
possible_labels = sorted(
set(possible_labels) - set(input_spec) - set(weight_spec) -
set(output_spec))
# Pad labels on the left to `input_spec` and `output_spec`
for i in range(elided):
input_spec = possible_labels[i] + input_spec
output_spec = possible_labels[i] + output_spec
return input_spec, weight_spec, output_spec
# This is the case where ellipses are present on the right.
split_string = re.match('([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0',
dot_replaced_string)
if split_string is not None:
input_spec = split_string.group(1)
weight_spec = split_string.group(2)
output_spec = split_string.group(3)
elided = len(input_shape) - len(input_spec)
possible_labels = sorted(
set(possible_labels) - set(input_spec) - set(weight_spec) -
set(output_spec))
# Pad labels on the right to `input_spec` and `output_spec`
for i in range(elided):
input_spec = input_spec + possible_labels[i]
output_spec = output_spec + possible_labels[i]
return input_spec, weight_spec, output_spec
raise ValueError(
"Invalid einsum equation '{equation}'. Equations must be in the "
'form [X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....'.format(
equation=equation))
input_spec, weight_spec, output_spec = get_specs(equation, input_shape)
# Determine the axes that should be reduced by the quantizer
input_reduced_axes = []
weight_reduced_axes = []
for i, label in enumerate(input_spec):
index = output_spec.find(label)
if index == -1:
input_reduced_axes.append(i)
for i, label in enumerate(weight_spec):
index = output_spec.find(label)
if index == -1:
weight_reduced_axes.append(i)
# Determine the axes of `ops.expand_dims`
input_expand_axes = []
weight_expand_axes = []
for i, label in enumerate(output_spec):
index_input = input_spec.find(label)
index_weight = weight_spec.find(label)
if index_input == -1:
input_expand_axes.append(i)
if index_weight == -1:
weight_expand_axes.append(i)
# Determine the axes of `ops.transpose`
input_transpose_axes = []
weight_transpose_axes = []
for i, label in enumerate(output_spec):
index_input = input_spec.find(label)
index_weight = weight_spec.find(label)
if index_input != -1:
input_transpose_axes.append(index_input)
if index_weight != -1:
weight_transpose_axes.append(index_weight)
# Postprocess the information:
# 1. Add dummy axes (1) to transpose_axes
# 2. Add axis to squeeze_axes if 1. failed
input_squeeze_axes = []
weight_squeeze_axes = []
for ori_index in input_reduced_axes:
try:
index = input_expand_axes.pop(0)
except IndexError:
input_squeeze_axes.append(ori_index)
input_transpose_axes.insert(index, ori_index)
for ori_index in weight_reduced_axes:
try:
index = weight_expand_axes.pop(0)
except IndexError:
weight_squeeze_axes.append(ori_index)
weight_transpose_axes.insert(index, ori_index)
# Prepare equation for `einsum_with_inputs_gradient`
custom_gradient_equation = '{output_spec},{weight_spec}->{input_spec}'.format(
output_spec=output_spec, input_spec=input_spec, weight_spec=weight_spec)
weight_reverse_transpose_axes = [
i for (_, i) in sorted((v, i)
for (i, v) in enumerate(weight_transpose_axes))
]
return (
input_reduced_axes,
weight_reduced_axes,
input_transpose_axes,
weight_transpose_axes,
input_expand_axes,
weight_expand_axes,
input_squeeze_axes,
weight_squeeze_axes,
custom_gradient_equation,
weight_reverse_transpose_axes,
)