def infer_return_type_func()

in sdks/python/apache_beam/typehints/trivial_inference.py [0:0]


def infer_return_type_func(f, input_types, debug=False, depth=0):
  """Analyses a function to deduce its return type.

  The bytecode of ``f`` is walked instruction by instruction and emulated on
  a FrameState whose stack and locals hold type hints rather than values.
  States flowing along jumps are merged (``|``) into the state recorded for
  the jump target, and the types reaching RETURN_VALUE / RETURN_CONST /
  YIELD_VALUE are collected to form the result.

  Args:
    f: A Python function object to infer the return type of.
    input_types: A sequence of inputs corresponding to the input types.
    debug: Whether to print verbose debugging information.
    depth: Maximum inspection depth during type inference. Callees are only
      inferred recursively (with ``depth - 1``) while this is positive;
      otherwise their result is treated as ``Any``.

  Returns:
    A TypeConstraint that the return value of this function will (likely)
    satisfy given the specified inputs. If any YIELD_VALUE is reached, the
    result is ``Iterable`` over the union of the yielded types.

  Raises:
    TypeInferenceError: if no type can be inferred, e.g. an unhandled opcode
      is encountered or no return/yield is reachable.
  """
  # Interpreter version, compared against repeatedly while decoding opcodes.
  py_version = (sys.version_info.major, sys.version_info.minor)
  if debug:
    print()
    print(f, id(f), input_types)
    if py_version >= (3, 11):
      dis.dis(f, show_caches=True)
    else:
      dis.dis(f)
  from . import opcodes
  # Upper-cased names from the opcodes module (e.g. 'LOAD_CONST') mapped to
  # their handlers; dispatched via the `opname in simple_ops` branch below.
  simple_ops = {k.upper(): v for k, v in opcodes.__dict__.items()}
  from . import intrinsic_one_ops

  co = f.__code__
  end = len(co.co_code)
  pc = 0

  yields = set()
  returns = set()
  # TODO(robertwb): Default args via inspect module.
  # Locals not covered by input_types start out as the empty Union[()].
  local_vars = list(input_types) + [typehints.Union[()]] * (
      len(co.co_varnames) - len(input_types))
  state = FrameState(f, local_vars)
  # Jump-target offset -> merged FrameState of all jumps to that offset.
  states = collections.defaultdict(lambda: None)
  # Counts how often a backward jump (keyed by the pc after it) was re-taken.
  jumps = collections.defaultdict(int)

  # In Python 3, use dis library functions to disassemble bytecode and handle
  # EXTENDED_ARGs.
  ofs_table = {}  # offset -> instruction
  if py_version >= (3, 11):
    dis_ints = dis.get_instructions(f, show_caches=True)
  else:
    dis_ints = dis.get_instructions(f)

  for instruction in dis_ints:
    ofs_table[instruction.offset] = instruction

  # Python 3.6+: 1 byte opcode + 1 byte arg (2 bytes, arg may be ignored).
  inst_size = 2
  opt_arg_size = 0

  # Python 3.10: bpo-27129 changes jump offsets to use instruction offsets,
  # not byte offsets. The offsets were halved (16 bits for instructions vs 8
  # bits for bytes), so we have to double the value of arg.
  if py_version >= (3, 10):
    jump_multiplier = 2
  else:
    jump_multiplier = 1

  last_pc = -1
  last_real_opname = opname = None
  while pc < end:  # pylint: disable=too-many-nested-blocks
    # Track the previous opcode, skipping PRECALL/CACHE filler, for the
    # comprehension special case in the CALL branch.
    if opname not in ('PRECALL', 'CACHE'):
      last_real_opname = opname
    start = pc
    instruction = ofs_table[pc]
    op = instruction.opcode
    if debug:
      print('-->' if pc == last_pc else '    ', end=' ')
      print(repr(pc).rjust(4), end=' ')
      print(dis.opname[op].ljust(20), end=' ')

    pc += inst_size
    arg = None
    if op >= dis.HAVE_ARGUMENT:
      arg = instruction.arg
      pc += opt_arg_size
      if debug:
        print(str(arg).rjust(5), end=' ')
        # Use dis's own decoded operand (argval) rather than re-deriving it
        # from `arg`: the encoding of name/compare/free operands and jump
        # targets differs across versions, and `arg` must stay unmodified
        # for the emulation below.
        if op in dis.hasconst:
          print('(' + repr(co.co_consts[arg]) + ')', end=' ')
        elif op in dis.hasjrel:
          print('(to ' + repr(instruction.argval) + ')', end=' ')
        elif (op in dis.hasname or op in dis.haslocal or
              op in dis.hascompare or op in dis.hasfree):
          print('(' + str(instruction.argval) + ')', end=' ')

    # Actually emulate the op.
    if state is None and states[start] is None:
      # No control reaches here (yet).
      if debug:
        print()
      continue
    # Merge in whatever state earlier jumps recorded for this offset.
    state |= states[start]

    opname = dis.opname[op]
    jmp = jmp_state = None
    if opname.startswith('CALL_FUNCTION'):
      if opname == 'CALL_FUNCTION':
        pop_count = arg + 1
        if depth <= 0:
          return_type = Any
        elif isinstance(state.stack[-pop_count], Const):
          return_type = infer_return_type(
              state.stack[-pop_count].value,
              state.stack[1 - pop_count:],
              debug=debug,
              depth=depth - 1)
        else:
          return_type = Any
      elif opname == 'CALL_FUNCTION_KW':
        # TODO(BEAM-24755): Handle keyword arguments. Requires passing them by
        #   name to infer_return_type.
        pop_count = arg + 2
        if isinstance(state.stack[-pop_count], Const):
          from apache_beam.pvalue import Row
          if state.stack[-pop_count].value == Row:
            fields = state.stack[-1].value
            return_type = row_type.RowTypeConstraint.from_fields(
                list(
                    zip(
                        fields,
                        Const.unwrap_all(state.stack[-pop_count + 1:-1]))))
          else:
            return_type = Any
        else:
          return_type = Any
      elif opname == 'CALL_FUNCTION_EX':
        # stack[-has_kwargs]: Map of keyword args.
        # stack[-1 - has_kwargs]: Iterable of positional args.
        # stack[-2 - has_kwargs]: Function to call.
        has_kwargs: int = arg & 1
        pop_count = has_kwargs + 2
        if has_kwargs:
          # TODO(BEAM-24755): Unimplemented. Requires same functionality as a
          #   CALL_FUNCTION_KW implementation.
          return_type = Any
        else:
          args = state.stack[-1]
          _callable = state.stack[-2]
          if depth <= 0 or not isinstance(_callable, Const):
            # Same policy as CALL_FUNCTION: stop recursing once depth is
            # exhausted, and a non-Const callee has no `.value` to inspect.
            return_type = Any
          else:
            if isinstance(args, typehints.ListConstraint):
              # Case where there's a single var_arg argument.
              args = [args]
            elif isinstance(args, typehints.TupleConstraint):
              args = list(args._inner_types())
            elif isinstance(args, typehints.SequenceTypeConstraint):
              args = [element_type(args)] * len(
                  inspect.getfullargspec(_callable.value).args)
            return_type = infer_return_type(
                _callable.value, args, debug=debug, depth=depth - 1)
      else:
        raise TypeInferenceError('unable to handle %s' % opname)
      state.stack[-pop_count:] = [return_type]
    elif opname == 'CALL_METHOD':
      pop_count = 1 + arg
      # LOAD_METHOD will return a non-Const (Any) if loading from an Any.
      if isinstance(state.stack[-pop_count], Const) and depth > 0:
        return_type = infer_return_type(
            state.stack[-pop_count].value,
            state.stack[1 - pop_count:],
            debug=debug,
            depth=depth - 1)
      else:
        return_type = typehints.Any
      state.stack[-pop_count:] = [return_type]
    elif opname == 'CALL':
      pop_count = 1 + arg
      # Keyword Args case
      if state.kw_names is not None:
        if isinstance(state.stack[-pop_count], Const):
          from apache_beam.pvalue import Row
          if state.stack[-pop_count].value == Row:
            fields = state.kw_names
            return_type = row_type.RowTypeConstraint.from_fields(
                list(
                    zip(fields,
                        Const.unwrap_all(state.stack[-pop_count + 1:]))))
          else:
            return_type = Any
        else:
          # Unknown callee; without this, return_type would be unbound or
          # left over from a previous call.
          return_type = Any
        state.kw_names = None
      else:
        # Handle comprehensions always having an arg of 0 for CALL
        # See https://github.com/python/cpython/issues/102403 for context.
        if (pop_count == 1 and last_real_opname == 'GET_ITER' and
            len(state.stack) > 1 and isinstance(state.stack[-2], Const) and
            getattr(state.stack[-2].value, '__name__', None) in (
                '<listcomp>', '<dictcomp>', '<setcomp>', '<genexpr>')):
          pop_count += 1
        if depth <= 0 or pop_count > len(state.stack):
          return_type = Any
        elif isinstance(state.stack[-pop_count], Const):
          return_type = infer_return_type(
              state.stack[-pop_count].value,
              state.stack[1 - pop_count:],
              debug=debug,
              depth=depth - 1)
        else:
          return_type = Any
      state.stack[-pop_count:] = [return_type]
    elif opname in simple_ops:
      if debug:
        print("Executing simple op " + opname)
      simple_ops[opname](state, arg)
    elif opname == 'RETURN_VALUE':
      returns.add(state.stack[-1])
      state = None
    elif opname == 'YIELD_VALUE':
      yields.add(state.stack[-1])
    elif opname == 'JUMP_FORWARD':
      jmp = pc + arg * jump_multiplier
      jmp_state = state
      state = None
    elif opname in ('JUMP_BACKWARD', 'JUMP_BACKWARD_NO_INTERRUPT'):
      jmp = pc - (arg * jump_multiplier)
      jmp_state = state
      state = None
    elif opname == 'JUMP_ABSOLUTE':
      jmp = arg * jump_multiplier
      jmp_state = state
      state = None
    elif opname in ('POP_JUMP_IF_TRUE', 'POP_JUMP_IF_FALSE'):
      state.stack.pop()
      # The arg was changed to be a relative delta instead of an absolute
      # in 3.11, and became a full instruction instead of a
      # pseudo-instruction in 3.12
      if py_version >= (3, 12):
        jmp = pc + arg * jump_multiplier
      else:
        jmp = arg * jump_multiplier
      jmp_state = state.copy()
    elif opname in ('POP_JUMP_FORWARD_IF_TRUE', 'POP_JUMP_FORWARD_IF_FALSE'):
      state.stack.pop()
      jmp = pc + arg * jump_multiplier
      jmp_state = state.copy()
    elif opname in ('POP_JUMP_BACKWARD_IF_TRUE', 'POP_JUMP_BACKWARD_IF_FALSE'):
      state.stack.pop()
      jmp = pc - (arg * jump_multiplier)
      jmp_state = state.copy()
    elif opname in ('POP_JUMP_FORWARD_IF_NONE', 'POP_JUMP_FORWARD_IF_NOT_NONE'):
      state.stack.pop()
      jmp = pc + arg * jump_multiplier
      jmp_state = state.copy()
    elif opname in ('POP_JUMP_BACKWARD_IF_NONE',
                    'POP_JUMP_BACKWARD_IF_NOT_NONE'):
      state.stack.pop()
      jmp = pc - (arg * jump_multiplier)
      jmp_state = state.copy()
    elif opname in ('JUMP_IF_TRUE_OR_POP', 'JUMP_IF_FALSE_OR_POP'):
      # The arg was changed to be a relative delta instead of an absolute
      # in 3.11
      if py_version >= (3, 11):
        jmp = pc + arg * jump_multiplier
      else:
        jmp = arg * jump_multiplier
      jmp_state = state.copy()
      state.stack.pop()
    elif opname == 'FOR_ITER':
      jmp = pc + arg * jump_multiplier
      if py_version >= (3, 12):
        # The jump is relative to the next instruction after a cache call,
        # so jump 4 more bytes.
        jmp += 4
      jmp_state = state.copy()
      jmp_state.stack.pop()
      state.stack.append(element_type(state.stack[-1]))
    elif opname == 'COPY_FREE_VARS':
      # Helps with calling closures, but since we aren't executing
      # them we can treat this as a no-op
      pass
    elif opname == 'KW_NAMES':
      tup = co.co_consts[arg]
      state.kw_names = tup
    elif opname == 'RESUME':
      # RESUME is a no-op
      pass
    elif opname == 'PUSH_NULL':
      # We're treating this as a no-op to avoid having to check
      # for extra None values on the stack when we extract return
      # values
      pass
    elif opname == 'PRECALL':
      # PRECALL is a no-op.
      pass
    elif opname == 'MAKE_CELL':
      # TODO: see if we need to implement cells like this
      pass
    elif opname == 'RETURN_GENERATOR':
      # TODO: see what this behavior is supposed to be beyond
      # putting something on the stack to be popped off
      state.stack.append(None)
    elif opname == 'CACHE':
      # No-op introduced in 3.11. Without handling this some
      # instructions have functionally > 2 byte size.
      pass
    elif opname == 'RETURN_CONST':
      # Introduced in 3.12. Handles returning constants directly
      # instead of having a LOAD_CONST before a RETURN_VALUE.
      returns.add(state.const_type(arg))
      state = None
    elif opname == 'CALL_INTRINSIC_1':
      # Introduced in 3.12. The arg is an index into a table of
      # operations reproduced in INT_ONE_OPS. Not all ops are
      # relevant for our type checking infrastructure.
      int_op = intrinsic_one_ops.INT_ONE_OPS[arg]
      if debug:
        print("Executing intrinsic one op", int_op.__name__.upper())
      int_op(state, arg)

    else:
      raise TypeInferenceError('unable to handle %s' % opname)

    if jmp is not None:
      # TODO(robertwb): Is this guaranteed to converge?
      new_state = states[jmp] | jmp_state
      # Re-walk a loop body when a backward jump brings new type info, but
      # at most 5 times per jump site to bound the iteration.
      if jmp < pc and new_state != states[jmp] and jumps[pc] < 5:
        jumps[pc] += 1
        pc = jmp
      states[jmp] = new_state

    if debug:
      print()
      print(state)
      pprint.pprint(dict(item for item in states.items() if item[1]))

  if yields:
    result = typehints.Iterable[reduce(union, Const.unwrap_all(yields))]
  elif returns:
    result = reduce(union, Const.unwrap_all(returns))
  else:
    # Nothing returns or yields (e.g. every path raises). reduce() over an
    # empty sequence would raise an unrelated TypeError, so report the
    # documented TypeInferenceError instead.
    raise TypeInferenceError('unable to infer a return type for %s' % (f, ))
  finalize_hints(result)

  if debug:
    print(f, id(f), input_types, '->', result)
  return result