in tensorflow/tensorflow/python/framework/op_def_library.py [0:0]
def _apply_op_helper(self, op_type_name, name=None, **keywords):
  """Implementation of apply_op that returns output_structure, op.

  Looks up the registered OpDef for `op_type_name`, converts the keyword
  arguments into input tensors (inferring type/number attrs from them where
  the OpDef says so), converts the remaining attrs into `AttrValue` protos,
  and adds the op to the graph determined from the inputs.

  Args:
    op_type_name: Name of a registered op; used as a key into `self._ops`.
    name: Optional name for the op; defaults to `op_type_name`. It opens a
      name scope, and the resulting scope name is passed as the op name.
    **keywords: Values for the op's inputs and for attrs that are not
      inferred from inputs, keyed by OpDef argument/attr name. A trailing
      "_" is also accepted (for names that clash with Python keywords or
      builtins). Every entry must be consumed, otherwise TypeError.

  Returns:
    A tuple `(output_structure, is_stateful, op)`:
      output_structure: list with one entry per OpDef `output_arg`: the
        value of its `number_attr`, the length of its `type_list_attr`
        list, or `None` for a single-tensor output.
      is_stateful: `op_def.is_stateful`.
      op: the operation created via `g.create_op`.

  Raises:
    RuntimeError: If `op_type_name` is not registered, or no graph can be
      determined from the inputs.
    NotImplementedError: If the op's deprecation version is less than or
      equal to the graph's GraphDef producer version.
    TypeError: On missing or unexpected arguments, or inputs/attrs whose
      types do not match the OpDef.
    ValueError: On list-length mismatches, values below a declared minimum,
      strings outside `allowed_values`, or inputs that cannot be converted
      to tensors.
  """
  op_info = self._ops.get(op_type_name, None)
  if op_info is None:
    raise RuntimeError("Unrecognized Op name " + op_type_name)
  op_def = op_info.op_def

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  except AssertionError as e:
    # Format the exception itself: Python 3 exceptions have no `.message`
    # attribute, and reading it here would mask the real cause.
    raise RuntimeError(
        "Cannot determine graph for Op '%s' due to: %s"
        % (op_type_name, e))

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          ("Op %s is not available in GraphDef version %d. "
           "It has been removed in version %d. %s.") %
          (op_type_name, producer, deprecation_version,
           op_def.deprecation.explanation))

  # Fill in the list of default types for all "type" attrs. This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  #
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other. Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []  # Flattened: list-typed inputs contribute one entry each.
  input_types = []  # Parallel to `inputs`.
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    inferred_from = {}  # attr name -> input (or "Default in OpDef") it came from
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a keyword or built-in
        # for Python so we use the name + _ instead.
        input_name += "_"
        values = keywords.pop(input_name)
      else:
        raise TypeError("No argument for input " + input_name)

      # Goals:
      # * Convert values to Tensors if it contains constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.

      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              "Expected list for '%s' argument to '%s' Op, not %s." %
              (input_name, op_type_name, values))
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          else:
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype
                break

          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.internal_convert_n_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype if dtype else None,
              preferred_dtype=default_dtype,
              as_ref=input_arg.is_ref)
          if input_arg.number_attr and len(
              set(v.dtype.base_dtype for v in values)) > 1:
            raise TypeError()  # All types should match.
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
            try:
              converted_value = ops.internal_convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
              observed_types.append(converted_value.dtype.base_dtype.name)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)
          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            # `dtype` is only normalized to a DType above for non-ref
            # inputs; for ref inputs it may still be the raw DataType enum
            # (an int), so normalize it before reading `.name`.
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, dtypes.as_dtype(dtype).name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, dtypes.as_dtype(dtype).name))
            else:
              raise TypeError("%s that don't all match." % prefix)
          else:
            raise TypeError(
                "%s that are invalid. Tensors: %s" % (prefix, values))

        types = [x.dtype for x in values]
        inputs.extend(values)
      else:
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we prefer the attr's default, so code that adds a new attr
          # with a default is backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          values = ops.internal_convert_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype,
              as_ref=input_arg.is_ref,
              preferred_dtype=default_dtype)
        except TypeError as err:
          if dtype is None:
            raise
          raise TypeError(
              "Expected %s passed to parameter '%s' of op '%s', got %s of "
              "type '%s' instead. Error: %s" %
              (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name,
               repr(values), type(values).__name__, err))
        except ValueError:
          # What type does convert_to_tensor think it has?
          try:
            observed = ops.internal_convert_to_tensor(
                values, as_ref=input_arg.is_ref).dtype.name
          except ValueError as err:
            raise ValueError(
                "Tried to convert '%s' to a tensor and failed. Error: %s" %
                (input_name, err))
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError("%s expected type of %s." %
                            (prefix, dtypes.as_dtype(input_arg.type).name))
          else:
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
                if k not in inferred_from:
                  inferred_from[k] = "Default in OpDef"

            raise TypeError(
                "%s type %s of argument '%s'." %
                (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                 inferred_from[input_arg.type_attr]))

        types = [values.dtype]
        inputs.append(values)
      base_types = [x.base_dtype for x in types]

      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d must match "
                "length %d of argument '%s'." %
                (input_name, op_type_name, len(values),
                 attrs[input_arg.number_attr],
                 inferred_from[input_arg.number_attr]))
        else:
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d shorter "
                "than minimum length %d." %
                (input_name, op_type_name, len(values), num_attr.minimum))
        # All tensors must have the same base type.
        if any(bt != base_types[0] for bt in base_types):
          raise TypeError(
              "All tensors passed to '%s' of '%s' Op "
              "must have the same type." %
              (input_name, op_type_name))
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
        else:
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            raise TypeError(
                "Don't know how to infer type variable from empty input "
                "list passed to input '%s' of '%s' Op." %
                (input_name, op_type_name))
          attrs[input_arg.type_attr] = base_types[0]
          inferred_from[input_arg.type_attr] = input_name
          type_attr = _Attr(op_def, input_arg.type_attr)
          _SatisfiesTypeConstraint(base_types[0], type_attr,
                                   param_name=input_name)
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            raise TypeError(
                "Input '%s' of '%s' Op has type %s that does not "
                "match type %s of argument '%s'." %
                (input_name, op_type_name, dtypes.as_dtype(attr_value).name,
                 dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                 inferred_from[input_arg.type_attr]))
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_attr),
                                     param_name=input_name)
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            raise TypeError(
                "Input '%s' of '%s' Op has type list of %s that does not "
                "match type list %s of argument '%s'." %
                (input_name, op_type_name,
                 ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                 ", ".join(dtypes.as_dtype(x).name
                           for x in attrs[input_arg.type_list_attr]),
                 inferred_from[input_arg.type_list_attr]))
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_list_attr),
                                     param_name=input_name)
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
      else:
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      if input_arg.is_ref:
        if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
          raise TypeError(
              ("'%s' Op requires that input '%s' be a mutable tensor "
               "(e.g.: a tf.Variable)") % (op_type_name, input_name))
        input_types.extend(types)
      else:
        input_types.extend(base_types)

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              "Should not specify value for inferred attr '%s'." % attr.name)
        continue
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      else:
        raise TypeError("No argument for attr " + attr.name)

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]
      attr_value = attr_value_pb2.AttrValue()
      if attr_def.HasField("default_value") and value is None:
        attr_value.CopyFrom(attr_def.default_value)
        attr_protos[key] = attr_value
        continue
      if attr_def.type.startswith("list("):
        if not _IsListValue(value):
          raise TypeError("Expected list for attr " + key)
        if attr_def.has_minimum:
          if len(value) < attr_def.minimum:
            raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                             "less than minimum %d." %
                             (key, op_type_name, len(value),
                              attr_def.minimum))
        attr_value.list.SetInParent()
      if attr_def.type == "string":
        attr_value.s = _MakeStr(value, key)
        if attr_def.HasField("allowed_values"):
          if attr_value.s not in attr_def.allowed_values.list.s:
            raise ValueError(
                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                (key, op_type_name, compat.as_text(attr_value.s),
                 '", "'.join(map(compat.as_text,
                                 attr_def.allowed_values.list.s))))
      elif attr_def.type == "list(string)":
        attr_value.list.s.extend([_MakeStr(x, key) for x in value])
        if attr_def.HasField("allowed_values"):
          for x in attr_value.list.s:
            if x not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(x),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
      elif attr_def.type == "int":
        attr_value.i = _MakeInt(value, key)
        if attr_def.has_minimum:
          if attr_value.i < attr_def.minimum:
            raise ValueError(
                "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                (key, op_type_name, attr_value.i, attr_def.minimum))
      elif attr_def.type == "list(int)":
        attr_value.list.i.extend([_MakeInt(x, key) for x in value])
      elif attr_def.type == "float":
        attr_value.f = _MakeFloat(value, key)
      elif attr_def.type == "list(float)":
        attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
      elif attr_def.type == "bool":
        attr_value.b = _MakeBool(value, key)
      elif attr_def.type == "list(bool)":
        attr_value.list.b.extend([_MakeBool(x, key) for x in value])
      elif attr_def.type == "type":
        attr_value.type = _MakeType(value, attr_def)
      elif attr_def.type == "list(type)":
        attr_value.list.type.extend(
            [_MakeType(x, attr_def) for x in value])
      elif attr_def.type == "shape":
        attr_value.shape.CopyFrom(_MakeShape(value, key))
      elif attr_def.type == "list(shape)":
        attr_value.list.shape.extend(
            [_MakeShape(x, key) for x in value])
      elif attr_def.type == "tensor":
        attr_value.tensor.CopyFrom(_MakeTensor(value, key))
      elif attr_def.type == "list(tensor)":
        attr_value.list.tensor.extend(
            [_MakeTensor(x, key) for x in value])
      elif attr_def.type == "func":
        attr_value.func.CopyFrom(_MakeFunc(value, key))
      elif attr_def.type == "list(func)":
        attr_value.list.func.extend([_MakeFunc(x, key) for x in value])
      else:
        raise TypeError("Unrecognized Attr type " + attr_def.type)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    output_structure = []
    for arg in op_def.output_arg:
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr).i
        output_structure.append(n)
      elif arg.type_attr:
        t = _AttrValue(attr_protos, arg.type_attr)
        output_structure.append(None)
      elif arg.type_list_attr:
        t = _AttrValue(attr_protos, arg.type_list_attr)
        output_structure.append(len(t.list.type))
      else:
        output_structure.append(None)

    if keywords:
      raise TypeError("apply_op() got unexpected keyword arguments: " +
                      ", ".join(sorted(keywords.keys())))

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    # `inputs` is flattened (each list-typed input contributes one tensor per
    # element), so it cannot be zipped with `op_def.input_arg` without
    # misaligning. `input_types` is flattened identically and holds ref
    # dtypes exactly for tensors of `is_ref` inputs (base dtypes otherwise),
    # so select the ref-typed tensors through it.
    # pylint: disable=protected-access
    must_colocate_inputs = [
        val for val, input_type in zip(inputs, input_types)
        if input_type._is_ref_dtype]
    # pylint: enable=protected-access
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      op = g.create_op(op_type_name, inputs, dtypes=None, name=scope,
                       input_types=input_types, attrs=attr_protos,
                       op_def=op_def)

    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(op.outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        # Replace the op's output tensors in place with the callback's.
        for slot_index, callback_output in enumerate(callback_outputs):
          op.outputs[slot_index] = callback_output

    return output_structure, op_def.is_stateful, op