in sdks/python/apache_beam/typehints/trivial_inference.py [0:0]
def infer_return_type_func(f, input_types, debug=False, depth=0):
  """Analyses a function to deduce its return type.

  The function's bytecode is emulated abstractly: a FrameState holding the
  types of the locals and of the value stack is pushed through each
  instruction, states are merged wherever control flow joins, and every
  value reaching a return (or yield) is collected and unioned.

  Args:
    f: A Python function object to infer the return type of.
    input_types: A sequence of inputs corresponding to the input types.
    debug: Whether to print verbose debugging information.
    depth: Maximum inspection depth during type inference. Plain call opcodes
      (CALL_FUNCTION, CALL_METHOD, CALL) stop recursing into the callee and
      yield Any once this is <= 0.

  Returns:
    A TypeConstraint that the return value of this function will (likely)
    satisfy given the specified inputs. If the function yields, the result
    is an Iterable of the union of the yielded types.

  Raises:
    TypeInferenceError: if the bytecode contains an opcode that cannot be
      handled.
  """
  # (major, minor) of the running interpreter; bytecode layout differs between
  # releases, so several decisions below depend on it.
  py_version = sys.version_info[:2]
  if debug:
    print()
    print(f, id(f), input_types)
    if py_version >= (3, 11):
      dis.dis(f, show_caches=True)
    else:
      dis.dis(f)
  from . import opcodes
  # Handlers keyed by upper-cased opcode name, each called as (state, arg).
  simple_ops = {k.upper(): v for k, v in opcodes.__dict__.items()}
  from . import intrinsic_one_ops

  co = f.__code__
  end = len(co.co_code)
  pc = 0
  free = None
  yields = set()
  returns = set()
  # TODO(robertwb): Default args via inspect module.
  # Locals beyond the supplied inputs start out as the empty Union.
  local_vars = list(input_types) + [typehints.Union[()]] * (
      len(co.co_varnames) - len(input_types))
  state = FrameState(f, local_vars)
  # Merged state pending at each jump target; None means "not reached yet".
  states = collections.defaultdict(lambda: None)
  # Number of backward jumps taken from each pc; bounds re-analysis of loops.
  jumps = collections.defaultdict(int)

  # In Python 3, use dis library functions to disassemble bytecode and handle
  # EXTENDED_ARGs.
  if py_version >= (3, 11):
    dis_ints = dis.get_instructions(f, show_caches=True)
  else:
    dis_ints = dis.get_instructions(f)
  ofs_table = {ins.offset: ins for ins in dis_ints}  # offset -> instruction

  # Python 3.6+: 1 byte opcode + 1 byte arg (2 bytes, arg may be ignored).
  inst_size = 2

  # Python 3.10: bpo-27129 changes jump offsets to use instruction offsets,
  # not byte offsets. The offsets were halved (16 bits for instructions vs 8
  # bits for bytes), so we have to double the value of arg.
  jump_multiplier = 2 if py_version >= (3, 10) else 1

  last_pc = -1
  last_real_opname = opname = None
  while pc < end:  # pylint: disable=too-many-nested-blocks
    # Remember the previous opcode that was not call/cache bookkeeping.
    if opname not in ('PRECALL', 'CACHE'):
      last_real_opname = opname
    start = pc
    instruction = ofs_table[pc]
    op = instruction.opcode
    if debug:
      print('-->' if pc == last_pc else ' ', end=' ')
      print(repr(pc).rjust(4), end=' ')
      print(dis.opname[op].ljust(20), end=' ')

    pc += inst_size
    arg = None
    if op >= dis.HAVE_ARGUMENT:
      arg = instruction.arg
      if debug:
        # Trace output only: never modify `arg` here, it is the operand that
        # the emulation below consumes.
        print(str(arg).rjust(5), end=' ')
        if op in dis.hasconst:
          print('(' + repr(co.co_consts[arg]) + ')', end=' ')
        elif op in dis.hasname:
          if py_version >= (3, 11):
            # Pre-emptively bit-shift so the print doesn't go out of index
            print_arg = arg >> 1
          else:
            print_arg = arg
          print('(' + co.co_names[print_arg] + ')', end=' ')
        elif op in dis.hasjrel:
          print('(to ' + repr(pc + (arg * jump_multiplier)) + ')', end=' ')
        elif op in dis.haslocal:
          print('(' + co.co_varnames[arg] + ')', end=' ')
        elif op in dis.hascompare:
          print_arg = arg
          if py_version >= (3, 12):
            # In 3.12 this arg was bit-shifted. Shifting it back avoids an
            # out-of-index.
            print_arg = arg >> 4
          print('(' + dis.cmp_op[print_arg] + ')', end=' ')
        elif op in dis.hasfree:
          if free is None:
            free = co.co_cellvars + co.co_freevars
          # From 3.11 on the arg is no longer offset by len(co_varnames)
          # so we adjust it back
          print_arg = arg
          if py_version >= (3, 11):
            print_arg = arg - len(co.co_varnames)
          print('(' + free[print_arg] + ')', end=' ')

    # Actually emulate the op.
    if state is None and states[start] is None:
      # No control reaches here (yet).
      if debug:
        print()
      continue
    # Merge the fall-through state with whatever jumped to this offset.
    state |= states[start]

    opname = dis.opname[op]
    jmp = jmp_state = None
    if opname.startswith('CALL_FUNCTION'):
      if opname == 'CALL_FUNCTION':
        pop_count = arg + 1
        if depth <= 0:
          return_type = Any
        elif isinstance(state.stack[-pop_count], Const):
          return_type = infer_return_type(
              state.stack[-pop_count].value,
              state.stack[1 - pop_count:],
              debug=debug,
              depth=depth - 1)
        else:
          return_type = Any
      elif opname == 'CALL_FUNCTION_KW':
        # TODO(BEAM-24755): Handle keyword arguments. Requires passing them by
        # name to infer_return_type.
        pop_count = arg + 2
        if isinstance(state.stack[-pop_count], Const):
          from apache_beam.pvalue import Row
          if state.stack[-pop_count].value == Row:
            fields = state.stack[-1].value
            return_type = row_type.RowTypeConstraint.from_fields(
                list(
                    zip(
                        fields,
                        Const.unwrap_all(state.stack[-pop_count + 1:-1]))))
          else:
            return_type = Any
        else:
          return_type = Any
      elif opname == 'CALL_FUNCTION_EX':
        # stack[-has_kwargs]: Map of keyword args.
        # stack[-1 - has_kwargs]: Iterable of positional args.
        # stack[-2 - has_kwargs]: Function to call.
        has_kwargs: int = arg & 1
        pop_count = has_kwargs + 2
        if has_kwargs:
          # TODO(BEAM-24755): Unimplemented. Requires same functionality as a
          # CALL_FUNCTION_KW implementation.
          return_type = Any
        else:
          args = state.stack[-1]
          _callable = state.stack[-2]
          if not isinstance(_callable, Const):
            # Only a constant callable carries a `.value` to analyse; as in
            # the other call opcodes, anything else may return anything.
            return_type = Any
          else:
            if isinstance(args, typehints.ListConstraint):
              # Case where there's a single var_arg argument.
              args = [args]
            elif isinstance(args, typehints.TupleConstraint):
              args = list(args._inner_types())
            elif isinstance(args, typehints.SequenceTypeConstraint):
              # Unknown length: supply one element type per declared
              # positional parameter of the callee.
              args = [element_type(args)] * len(
                  inspect.getfullargspec(_callable.value).args)
            return_type = infer_return_type(
                _callable.value, args, debug=debug, depth=depth - 1)
      else:
        raise TypeInferenceError('unable to handle %s' % opname)
      state.stack[-pop_count:] = [return_type]
    elif opname == 'CALL_METHOD':
      pop_count = 1 + arg
      # LOAD_METHOD will return a non-Const (Any) if loading from an Any.
      if isinstance(state.stack[-pop_count], Const) and depth > 0:
        return_type = infer_return_type(
            state.stack[-pop_count].value,
            state.stack[1 - pop_count:],
            debug=debug,
            depth=depth - 1)
      else:
        return_type = typehints.Any
      state.stack[-pop_count:] = [return_type]
    elif opname == 'CALL':
      pop_count = 1 + arg
      # Keyword Args case
      if state.kw_names is not None:
        # Default to Any so that return_type is always bound, including when
        # the callable is not a Const; otherwise a value left over from an
        # earlier call opcode (or nothing at all) would be pushed below.
        return_type = Any
        if isinstance(state.stack[-pop_count], Const):
          from apache_beam.pvalue import Row
          if state.stack[-pop_count].value == Row:
            fields = state.kw_names
            return_type = row_type.RowTypeConstraint.from_fields(
                list(
                    zip(fields,
                        Const.unwrap_all(state.stack[-pop_count + 1:]))))
        # The names recorded by KW_NAMES apply to this call only.
        state.kw_names = None
      else:
        # Handle comprehensions always having an arg of 0 for CALL
        # See https://github.com/python/cpython/issues/102403 for context.
        if (pop_count == 1 and last_real_opname == 'GET_ITER' and
            len(state.stack) > 1 and isinstance(state.stack[-2], Const) and
            getattr(state.stack[-2].value, '__name__', None) in (
                '<listcomp>', '<dictcomp>', '<setcomp>', '<genexpr>')):
          pop_count += 1
        if depth <= 0 or pop_count > len(state.stack):
          return_type = Any
        elif isinstance(state.stack[-pop_count], Const):
          return_type = infer_return_type(
              state.stack[-pop_count].value,
              state.stack[1 - pop_count:],
              debug=debug,
              depth=depth - 1)
        else:
          return_type = Any
      state.stack[-pop_count:] = [return_type]
    elif opname in simple_ops:
      if debug:
        print("Executing simple op " + opname)
      simple_ops[opname](state, arg)
    elif opname == 'RETURN_VALUE':
      returns.add(state.stack[-1])
      # Nothing falls through a return.
      state = None
    elif opname == 'YIELD_VALUE':
      yields.add(state.stack[-1])
    elif opname == 'JUMP_FORWARD':
      jmp = pc + arg * jump_multiplier
      jmp_state = state
      state = None
    elif opname in ('JUMP_BACKWARD', 'JUMP_BACKWARD_NO_INTERRUPT'):
      jmp = pc - (arg * jump_multiplier)
      jmp_state = state
      state = None
    elif opname == 'JUMP_ABSOLUTE':
      jmp = arg * jump_multiplier
      jmp_state = state
      state = None
    elif opname in ('POP_JUMP_IF_TRUE', 'POP_JUMP_IF_FALSE'):
      state.stack.pop()
      # The arg was changed to be a relative delta instead of an absolute
      # in 3.11, and became a full instruction instead of a
      # pseudo-instruction in 3.12
      if py_version >= (3, 12):
        jmp = pc + arg * jump_multiplier
      else:
        jmp = arg * jump_multiplier
      jmp_state = state.copy()
    elif opname in ('POP_JUMP_FORWARD_IF_TRUE',
                    'POP_JUMP_FORWARD_IF_FALSE',
                    'POP_JUMP_FORWARD_IF_NONE',
                    'POP_JUMP_FORWARD_IF_NOT_NONE'):
      # Conditional forward jumps: the tested value is popped on both paths.
      state.stack.pop()
      jmp = pc + arg * jump_multiplier
      jmp_state = state.copy()
    elif opname in ('POP_JUMP_BACKWARD_IF_TRUE',
                    'POP_JUMP_BACKWARD_IF_FALSE',
                    'POP_JUMP_BACKWARD_IF_NONE',
                    'POP_JUMP_BACKWARD_IF_NOT_NONE'):
      # Conditional backward jumps: the tested value is popped on both paths.
      state.stack.pop()
      jmp = pc - (arg * jump_multiplier)
      jmp_state = state.copy()
    elif opname in ('JUMP_IF_TRUE_OR_POP', 'JUMP_IF_FALSE_OR_POP'):
      # The arg was changed to be a relative delta instead of an absolute
      # in 3.11
      if py_version >= (3, 11):
        jmp = pc + arg * jump_multiplier
      else:
        jmp = arg * jump_multiplier
      # The jump path keeps the tested value; only fall-through pops it.
      jmp_state = state.copy()
      state.stack.pop()
    elif opname == 'FOR_ITER':
      jmp = pc + arg * jump_multiplier
      if py_version >= (3, 12):
        # The jump is relative to the next instruction after a cache call,
        # so jump 4 more bytes.
        jmp += 4
      # Loop exit: the iterator is popped. Loop body: the next element is
      # pushed on top of the iterator.
      jmp_state = state.copy()
      jmp_state.stack.pop()
      state.stack.append(element_type(state.stack[-1]))
    elif opname == 'KW_NAMES':
      # Record the keyword names for the CALL that follows.
      state.kw_names = co.co_consts[arg]
    elif opname in ('COPY_FREE_VARS',
                    'RESUME',
                    'PUSH_NULL',
                    'PRECALL',
                    'MAKE_CELL',
                    'CACHE'):
      # All treated as no-ops:
      # - COPY_FREE_VARS helps with calling closures, but since we aren't
      #   executing them it has no effect here.
      # - PUSH_NULL is skipped to avoid having to check for extra None values
      #   on the stack when we extract return values.
      # - MAKE_CELL. TODO: see if we need to implement cells like this.
      # - CACHE was introduced in 3.11. Without handling this some
      #   instructions have functionally > 2 byte size.
      pass
    elif opname == 'RETURN_GENERATOR':
      # TODO: see what this behavior is supposed to be beyond
      # putting something on the stack to be popped off
      state.stack.append(None)
    elif opname == 'RETURN_CONST':
      # Introduced in 3.12. Handles returning constants directly
      # instead of having a LOAD_CONST before a RETURN_VALUE.
      returns.add(state.const_type(arg))
      state = None
    elif opname == 'CALL_INTRINSIC_1':
      # Introduced in 3.12. The arg is an index into a table of
      # operations reproduced in INT_ONE_OPS. Not all ops are
      # relevant for our type checking infrastructure.
      int_op = intrinsic_one_ops.INT_ONE_OPS[arg]
      if debug:
        print("Executing intrinsic one op", int_op.__name__.upper())
      int_op(state, arg)
    else:
      raise TypeInferenceError('unable to handle %s' % opname)

    if jmp is not None:
      # TODO(robertwb): Is this guaranteed to converge?
      new_state = states[jmp] | jmp_state
      # Re-run a loop body only while its entry state keeps changing, and at
      # most five times per jump instruction.
      if jmp < pc and new_state != states[jmp] and jumps[pc] < 5:
        jumps[pc] += 1
        pc = jmp
      states[jmp] = new_state

    if debug:
      print()
      print(state)
      pprint.pprint(dict(item for item in states.items() if item[1]))

  if yields:
    result = typehints.Iterable[reduce(union, Const.unwrap_all(yields))]
  else:
    result = reduce(union, Const.unwrap_all(returns))
  finalize_hints(result)

  if debug:
    print(f, id(f), input_types, '->', result)
  return result