def _analyze_quantization_info()

in easy_rec/python/layers/keras/einsum_dense.py [0:0]


def _analyze_quantization_info(equation, input_shape):

  def get_specs(equation, input_shape):
    possible_labels = string.ascii_letters
    dot_replaced_string = re.sub(r'\.\.\.', '0', equation)

    # This is the case where no ellipses are present in the string.
    split_string = re.match('([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)',
                            dot_replaced_string)
    if split_string is not None:
      input_spec = split_string.group(1)
      weight_spec = split_string.group(2)
      output_spec = split_string.group(3)
      return input_spec, weight_spec, output_spec

    # This is the case where ellipses are present on the left.
    split_string = re.match('0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)',
                            dot_replaced_string)
    if split_string is not None:
      input_spec = split_string.group(1)
      weight_spec = split_string.group(2)
      output_spec = split_string.group(3)
      elided = len(input_shape) - len(input_spec)
      possible_labels = sorted(
          set(possible_labels) - set(input_spec) - set(weight_spec) -
          set(output_spec))
      # Pad labels on the left to `input_spec` and `output_spec`
      for i in range(elided):
        input_spec = possible_labels[i] + input_spec
        output_spec = possible_labels[i] + output_spec
      return input_spec, weight_spec, output_spec

    # This is the case where ellipses are present on the right.
    split_string = re.match('([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0',
                            dot_replaced_string)
    if split_string is not None:
      input_spec = split_string.group(1)
      weight_spec = split_string.group(2)
      output_spec = split_string.group(3)
      elided = len(input_shape) - len(input_spec)
      possible_labels = sorted(
          set(possible_labels) - set(input_spec) - set(weight_spec) -
          set(output_spec))
      # Pad labels on the right to `input_spec` and `output_spec`
      for i in range(elided):
        input_spec = input_spec + possible_labels[i]
        output_spec = output_spec + possible_labels[i]
      return input_spec, weight_spec, output_spec

    raise ValueError(
        "Invalid einsum equation '{equation}'. Equations must be in the "
        'form [X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....'.format(
            equation=equation))

  input_spec, weight_spec, output_spec = get_specs(equation, input_shape)

  # Determine the axes that should be reduced by the quantizer
  input_reduced_axes = []
  weight_reduced_axes = []
  for i, label in enumerate(input_spec):
    index = output_spec.find(label)
    if index == -1:
      input_reduced_axes.append(i)
  for i, label in enumerate(weight_spec):
    index = output_spec.find(label)
    if index == -1:
      weight_reduced_axes.append(i)

  # Determine the axes of `ops.expand_dims`
  input_expand_axes = []
  weight_expand_axes = []
  for i, label in enumerate(output_spec):
    index_input = input_spec.find(label)
    index_weight = weight_spec.find(label)
    if index_input == -1:
      input_expand_axes.append(i)
    if index_weight == -1:
      weight_expand_axes.append(i)

  # Determine the axes of `ops.transpose`
  input_transpose_axes = []
  weight_transpose_axes = []
  for i, label in enumerate(output_spec):
    index_input = input_spec.find(label)
    index_weight = weight_spec.find(label)
    if index_input != -1:
      input_transpose_axes.append(index_input)
    if index_weight != -1:
      weight_transpose_axes.append(index_weight)
  # Postprocess the information:
  # 1. Add dummy axes (1) to transpose_axes
  # 2. Add axis to squeeze_axes if 1. failed
  input_squeeze_axes = []
  weight_squeeze_axes = []
  for ori_index in input_reduced_axes:
    try:
      index = input_expand_axes.pop(0)
    except IndexError:
      input_squeeze_axes.append(ori_index)
    input_transpose_axes.insert(index, ori_index)
  for ori_index in weight_reduced_axes:
    try:
      index = weight_expand_axes.pop(0)
    except IndexError:
      weight_squeeze_axes.append(ori_index)
    weight_transpose_axes.insert(index, ori_index)
  # Prepare equation for `einsum_with_inputs_gradient`
  custom_gradient_equation = '{output_spec},{weight_spec}->{input_spec}'.format(
      output_spec=output_spec, input_spec=input_spec, weight_spec=weight_spec)
  weight_reverse_transpose_axes = [
      i for (_, i) in sorted((v, i)
                             for (i, v) in enumerate(weight_transpose_axes))
  ]
  return (
      input_reduced_axes,
      weight_reduced_axes,
      input_transpose_axes,
      weight_transpose_axes,
      input_expand_axes,
      weight_expand_axes,
      input_squeeze_axes,
      weight_squeeze_axes,
      custom_gradient_equation,
      weight_reverse_transpose_axes,
  )