in coremltools/converters/mil/frontend/tensorflow/ops.py [0:0]
def StridedSlice(context, node):
    """Convert a TensorFlow ``StridedSlice`` node into MIL ops.

    Reads the input tensor and the ``begin``/``end``/``strides`` operands
    from ``context``, decodes the ``begin_mask``, ``end_mask``,
    ``shrink_axis_mask``, ``ellipsis_mask`` and ``new_axis_mask`` integer
    attributes into per-axis boolean lists, pads everything to a common
    rank, inserts new axes with ``mb.expand_dims`` when requested, and
    registers the resulting ``mb.slice_by_index`` op under ``node.name``.

    Raises:
        ValueError: if ``begin`` or ``end`` is not a 1-D tensor, or if the
            strides are not known at conversion time.
    """
    x = context[node.inputs[0]]
    begin = context[node.inputs[1]]
    end = context[node.inputs[2]]
    stride = context[node.inputs[3]]

    def bitmask_to_array(bit):
        # Bit i of the integer attribute becomes element i of the list; the
        # list stops at the highest set bit, so 0 decodes to [].
        arr = []
        while bit > 0:
            arr.append(bool(bit & 1))
            bit >>= 1
        return arr

    begin_mask = bitmask_to_array(node.attr.get("begin_mask", 0))
    end_mask = bitmask_to_array(node.attr.get("end_mask", 0))
    squeeze_mask = bitmask_to_array(node.attr.get("shrink_axis_mask", 0))
    ellipsis_mask = bitmask_to_array(node.attr.get("ellipsis_mask", 0))
    new_axis_mask = bitmask_to_array(node.attr.get("new_axis_mask", 0))

    def _pad_mask(
        x,
        begin,
        end,
        stride,
        begin_mask,
        end_mask,
        squeeze_mask,
        ellipsis_mask,
        new_axis_mask,
    ):
        # Pad the masks, stride, begin and end to the same rank as the input
        # tensor (including any axes added through new_axis_mask).
        if begin.rank != 1:
            raise ValueError(
                "begin should be 1-D tensor, got {}-D tensor instead".format(begin.rank)
            )
        if end.rank != 1:
            raise ValueError(
                "end should be 1-D tensor, got {}-D tensor instead".format(end.rank)
            )

        # begin/end may only be known at run time; keep the original Vars so
        # they can be passed through unchanged in that case.
        begin_cache = begin
        end_cache = end
        begin_is_const = begin.val is not None
        end_is_const = end.val is not None
        begin = begin.val.tolist() if begin_is_const else []
        end = end.val.tolist() if end_is_const else []
        if stride is None:
            stride = []
        elif stride.val is None:
            # The stride list is edited in place below, so its values must be
            # available at conversion time.
            raise ValueError(
                "stride should be a compile-time constant for StridedSlice"
            )
        else:
            stride = stride.val.tolist()

        # Every new axis adds one dimension to the tensor being sliced.
        x_rank = x.rank + sum(new_axis_mask)

        has_ellipsis = bool(ellipsis_mask)
        mask_list = [
            begin_mask,
            end_mask,
            squeeze_mask,
            ellipsis_mask,
            new_axis_mask,
            stride,
            begin,
            end,
        ]
        max_rank = max(len(arr) for arr in mask_list)

        # When an ellipsis is given, the last element of ellipsis_mask is True
        # (bitmask_to_array stops at the highest set bit), so its index is the
        # ellipsis position. Otherwise each mask is simply padded at the end.
        if has_ellipsis:
            rank = max_rank
            idx = len(ellipsis_mask) - 1
        else:
            rank = x_rank
            idx = -1

        def pad_array(arr, max_rank, idx, default_value):
            """
            Lay ``arr`` out over the sliced dimensions.

            The first ``max_rank`` entries of ``arr`` (padded with
            ``default_value``) are kept in order, except that entry ``idx``,
            the position where ellipsis_mask is True (or -1 when there is no
            ellipsis), is repeated ``x_rank - max_rank + 1`` times. That can be
            zero times when the ellipsis covers no dimensions.
            """
            # Pad to max_rank so every index read in the loop exists. Padding
            # only to x_rank raised IndexError when the ellipsis expands to
            # zero dimensions (e.g. x[..., 0] on a 1-D tensor).
            mask = list(arr) + [default_value] * (max_rank - len(arr))
            new_mask = []
            for i in range(max_rank):
                num = x_rank - max_rank + 1 if i == idx else 1
                new_mask += [mask[i]] * num
            return new_mask

        begin_mask = pad_array(begin_mask, rank, idx, False)
        end_mask = pad_array(end_mask, rank, idx, False)
        squeeze_mask = pad_array(squeeze_mask, rank, idx, False)
        ellipsis_mask = pad_array(ellipsis_mask, rank, idx, False)
        new_axis_mask = pad_array(new_axis_mask, rank, idx, False)
        stride = pad_array(stride, rank, idx, 1)

        # Only compile-time begin/end can be padded; run-time values are
        # passed through as the original Vars.
        # NOTE(review): a run-time begin/end is not padded, so presumably its
        # length must already equal the sliced rank. Confirm against
        # slice_by_index.
        begin = pad_array(begin, rank, idx, 0) if begin_is_const else begin_cache
        end = pad_array(end, rank, idx, 0) if end_is_const else end_cache

        # Ellipsis and new-axis positions take the whole dimension: force
        # begin_mask/end_mask to True and stride to 1 there.
        for i, (is_ellipsis, is_new_axis) in enumerate(
            zip(ellipsis_mask, new_axis_mask)
        ):
            if is_ellipsis or is_new_axis:
                begin_mask[i] = True
                end_mask[i] = True
                stride[i] = 1

        # Without an ellipsis, a slice spec shorter than the tensor rank is
        # implicitly padded with full slices: indexing a 3D tensor A with
        # A[:1, :1] is equivalent to A[:1, :1, :]. So mark the trailing axes
        # as fully taken in both begin_mask and end_mask.
        if not has_ellipsis:
            for i in range(max_rank, x_rank):
                begin_mask[i] = True
                end_mask[i] = True

        return begin, end, stride, begin_mask, end_mask, squeeze_mask, new_axis_mask

    begin, end, stride, begin_mask, end_mask, squeeze_mask, new_axis_mask = _pad_mask(
        x,
        begin,
        end,
        stride,
        begin_mask,
        end_mask,
        squeeze_mask,
        ellipsis_mask,
        new_axis_mask,
    )

    if any(new_axis_mask):
        axes = [i for i, is_new_axis in enumerate(new_axis_mask) if is_new_axis]
        x = mb.expand_dims(x=x, axes=axes, name=node.name + "_new_axes")

    x = mb.slice_by_index(
        x=x,
        name=node.name,
        begin=begin,
        end=end,
        stride=stride,
        begin_mask=begin_mask,
        end_mask=end_mask,
        squeeze_mask=squeeze_mask,
    )
    context.add(node.name, x)