in python/tvm/relay/frontend/tensorflow_ops.py [0:0]
def _stridedSlice():
    """Build the converter for the TensorFlow ``StridedSlice`` operator.

    Returns
    -------
    _impl : callable
        Converter with signature ``(inputs, attr, params, mod)`` that returns
        the Relay expression implementing the slice.
    """

    def _impl(inputs, attr, params, mod):
        """Strided Slice.
        Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice
        Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
        tensorflow/core/util/strided_slice_op.cc#L147-L368

        Parameters
        ----------
        inputs : list
            ``inputs[0]`` is the data expression; ``inputs[1]``, ``inputs[2]``
            and ``inputs[3]`` are resolved through ``_get_list_param`` into the
            begin, end and stride lists.
        attr : dict-like
            Operator attributes. The integer bit masks ``begin_mask``,
            ``end_mask``, ``ellipsis_mask``, ``new_axis_mask`` and
            ``shrink_axis_mask`` are read from it, each defaulting to 0.
        params : object
            Forwarded unchanged to ``_get_list_param``.
        mod : object
            Forwarded unchanged to the type / shape inference helpers.

        Returns
        -------
        ret : relay expression
            A constant when the result can be computed eagerly (scalar
            constant input, or a statically known slice of ``shape_of``),
            otherwise a ``strided_slice`` optionally followed by
            ``reshape`` / ``squeeze``.
        """
        begin = _get_list_param(params, inputs[1], mod)
        end = _get_list_param(params, inputs[2], mod)
        stride = _get_list_param(params, inputs[3], mod)

        begin_mask = int(attr.get("begin_mask", 0))
        end_mask = int(attr.get("end_mask", 0))
        ellipsis_mask = int(attr.get("ellipsis_mask", 0))
        new_axis_mask = int(attr.get("new_axis_mask", 0))
        shrink_axis_mask = int(attr.get("shrink_axis_mask", 0))

        in_type = _infer_type(inputs[0], mod)
        data_shape = get_const_tuple(in_type.checked_type.shape)
        data_dim = len(data_shape)
        stride_dim = len(stride)

        # A rank-0 constant is returned directly as a one-element constant.
        if data_dim == 0 and isinstance(inputs[0], _expr.Constant):
            new_data = inputs[0].data.numpy().reshape(1)
            return _expr.const(new_data, inputs[0].data.dtype)

        # This is a special routine to handle strided_slice after shape_of.
        # We need this since in some cases we want to do strided_slice on
        # a partial symbolic shape, such as (1, ?), and get a static shape
        # (1,). Directly slice on shape_of will result in fully dynamic shape.
        # TODO(kevinthesun): Can we generalize this process with partial eval?
        if isinstance(inputs[0], _expr.Call) and inputs[0].op == _op.get("shape_of"):
            bg = begin[0]
            ed = end[0]
            st = stride[0]

            if ed <= 0 < st:
                ed += data_shape[0]
            # A negative begin counts from the back. Without normalising it the
            # walk below would index in_shape with negative values and then
            # keep going from 0, collecting too many entries (e.g. shape[-1]).
            if bg < 0 < st:
                bg += data_shape[0]

            in_shape = _infer_shape(inputs[0].args[0], mod)
            dtype = in_type.checked_type.dtype
            out_data = []
            idx = bg
            while idx < ed:
                if isinstance(in_shape[idx], int):
                    out_data.append(in_shape[idx])
                else:
                    # Non-static dimension: give up on the eager path.
                    break
                idx += st

            # Only return when in_shape is fully static in the range from begin to end.
            if idx >= ed:
                ret = _expr.const(out_data, dtype)
                if shrink_axis_mask:
                    ret = _op.squeeze(ret)
                return ret

        def _transform_mask(stride_dim, ellipsis_mask):
            """Handle mask inputs to create new begin, end, stride and output shape.

            Returns ``(m_begin, m_end, m_stride, fshape_indices)``. The first
            three have one entry per data axis. ``fshape_indices`` describes
            the final output axes in order: a non-negative value is an axis of
            the sliced tensor, ``-1`` marks an inserted axis of length 1
            (new_axis_mask) and ``-2`` marks an axis to be dropped
            (shrink_axis_mask).
            """
            m_begin = [0] * data_dim
            m_end = [0] * data_dim
            m_stride = [0] * data_dim
            fshape_indices = []

            # Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
            ellipsis_seen = False
            new_axes_after_ellipsis = 0
            for i in range(stride_dim):
                mask = 1 << i
                if ellipsis_seen and (mask & new_axis_mask) != 0:
                    new_axes_after_ellipsis += 1
                if (mask & ellipsis_mask) != 0:
                    ellipsis_seen = True
            if not ellipsis_seen:
                # Used later for extending the stride attributes in the below loop.
                ellipsis_mask |= 1 << stride_dim
                stride_dim += 1

            final_index = 0
            for index in range(stride_dim):
                mask = 1 << index
                if mask & ellipsis_mask:
                    # Identify the end index for applying ellipsis_mask
                    to_index = min(
                        ((data_dim - (stride_dim - index)) + 1 + new_axes_after_ellipsis),
                        data_dim,
                    )
                    # Every axis covered by the ellipsis is taken in full.
                    for _ in range(final_index, to_index):
                        m_begin[final_index] = 0
                        m_end[final_index] = data_shape[final_index]
                        m_stride[final_index] = 1
                        fshape_indices.append(final_index)
                        final_index += 1
                elif mask & new_axis_mask:
                    fshape_indices.append(-1)
                else:
                    if final_index == len(m_begin):
                        # All data axes are already assigned.
                        break
                    if mask & begin_mask:
                        m_begin[final_index] = -1 if stride[index] < 0 else 0
                    elif begin[index]:
                        m_begin[final_index] = begin[index]
                    if mask & end_mask:
                        m_end[final_index] = (
                            -(data_shape[final_index] + 1)
                            if stride[index] < 0
                            else data_shape[final_index]
                        )
                    elif end[index]:
                        m_end[final_index] = end[index]
                    m_stride[final_index] = stride[index]
                    if mask & shrink_axis_mask:
                        # Tensorflow make axis with shrink_axis_mask as dimension 1
                        m_begin[final_index] = (
                            data_shape[final_index] + begin[index]
                            if begin[index] < 0
                            else begin[index]
                        )
                        # Derive end from the normalised begin. Using the raw
                        # begin would give end == 0 for begin == -1 and thus
                        # an empty slice instead of the last element.
                        m_end[final_index] = m_begin[final_index] + 1
                        m_stride[final_index] = 1
                        fshape_indices.append(-2)
                    else:
                        fshape_indices.append(final_index)

                    final_index += 1
            return m_begin, m_end, m_stride, fshape_indices

        fshape_indices = None
        if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
            begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
        out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
        out_shape = _infer_shape(out, mod=mod)
        if not fshape_indices:
            fshape_indices = range(len(out_shape))

        # Create final output shape.
        final_output = []
        for gather_index in fshape_indices:
            if gather_index == -1:
                # Inserted axis.
                final_output.append(1)
            elif gather_index != -2:
                # Regular axis; -2 (shrunk axis) contributes nothing.
                final_output.append(out_shape[gather_index])

        if final_output:
            return _op.reshape(out, newshape=tuple(final_output))
        if not shrink_axis_mask:
            return out

        # Every output axis was shrunk away: drop the unit dimensions.
        final_shape = [dim for dim in out_shape if dim != 1]
        if not final_shape:
            return _op.squeeze(out)
        # We need reshape to handle dynamic shape.
        return _op.reshape(out, newshape=tuple(final_shape))

    return _impl