def StridedSlice()

in coremltools/converters/mil/frontend/tensorflow/ops.py [0:0]


def StridedSlice(context, node):
    """Convert a TensorFlow ``StridedSlice`` node into MIL ops.

    The TF op encodes per-dimension behaviour in five bitmask attributes
    (``begin_mask``, ``end_mask``, ``shrink_axis_mask``, ``ellipsis_mask`` and
    ``new_axis_mask``). These are unpacked into boolean lists and, together
    with ``begin``, ``end`` and ``stride``, expanded to one entry per
    dimension of the (possibly axis-expanded) input. New axes are inserted
    with ``mb.expand_dims`` and the slice is emitted as ``mb.slice_by_index``
    under ``node.name``.

    Args:
        context: conversion context; indexed by input name to fetch the input
            vars, and updated with the result under ``node.name``.
        node: the TF node. ``node.inputs`` names ``x``, ``begin``, ``end`` and
            ``stride`` (in that order); ``node.attr`` holds the mask
            attributes, each defaulting to 0 when absent.

    Raises:
        ValueError: if ``begin`` or ``end`` is not a 1-D tensor.
        NotImplementedError: if ``stride`` is not known at conversion time.
    """
    x = context[node.inputs[0]]
    begin = context[node.inputs[1]]
    end = context[node.inputs[2]]
    stride = context[node.inputs[3]]

    def bitmask_to_array(bit):
        # Bit i of the mask becomes element i. Iteration stops at the highest
        # set bit, so the list ends there (and is empty for 0).
        arr = []
        while bit > 0:
            arr.append(bool(bit & 1))
            bit >>= 1
        return arr

    begin_mask = bitmask_to_array(node.attr.get("begin_mask", 0))
    end_mask = bitmask_to_array(node.attr.get("end_mask", 0))
    squeeze_mask = bitmask_to_array(node.attr.get("shrink_axis_mask", 0))
    ellipsis_mask = bitmask_to_array(node.attr.get("ellipsis_mask", 0))
    new_axis_mask = bitmask_to_array(node.attr.get("new_axis_mask", 0))

    def _pad_mask(
        x,
        begin,
        end,
        stride,
        begin_mask,
        end_mask,
        squeeze_mask,
        ellipsis_mask,
        new_axis_mask,
    ):
        """Pad the masks, stride, begin and end to the rank of the sliced tensor.

        ``begin``/``end`` whose values are known at conversion time are
        returned as padded Python lists. Run-time-determined ones are
        returned as the original vars, unchanged.
        """
        if begin.rank != 1:
            raise ValueError(
                "begin should be 1-D tensor, got {}-D tensor instead".format(begin.rank)
            )
        if end.rank != 1:
            raise ValueError(
                "end should be 1-D tensor, got {}-D tensor instead".format(end.rank)
            )

        # Record explicitly whether begin/end are compile-time constants.
        # Inferring this from an empty list would misclassify a constant
        # that happens to be empty.
        begin_var, end_var = begin, end
        begin_is_const = begin.val is not None
        end_is_const = end.val is not None
        begin = begin.val.tolist() if begin_is_const else []
        end = end.val.tolist() if end_is_const else []
        if stride is None:
            stride = []
        elif stride.val is None:
            # The masks below are adjusted element-wise against concrete
            # strides, so a run-time stride cannot be handled here.
            raise NotImplementedError(
                "StridedSlice with a stride that is not known at conversion time "
                "is not supported"
            )
        else:
            stride = stride.val.tolist()

        # Every new axis adds one dimension to the tensor being sliced.
        x_rank = x.rank + sum(new_axis_mask)

        def pad_array(arr, max_rank, idx, default_value):
            """Expand ``arr`` to one entry per dimension of the x_rank-D tensor.

            ``arr`` is padded with ``default_value``. Then its first
            ``max_rank`` entries are emitted, with the entry at ``idx`` (the
            ellipsis position, or -1 for none) repeated
            ``x_rank - max_rank + 1`` times. That may be zero times when the
            ellipsis covers no dimensions.
            """
            # Pad to cover both x_rank and max_rank. When the ellipsis covers
            # no dimensions, max_rank == x_rank + 1, and indexing up to
            # max_rank - 1 would otherwise run past the padded list.
            padded = list(arr)
            padded += [default_value] * (max(x_rank, max_rank) - len(padded))
            out = []
            for i in range(max_rank):
                repeat = x_rank - max_rank + 1 if i == idx else 1
                out += [padded[i]] * repeat
            return out

        mask_list = [
            begin_mask,
            end_mask,
            squeeze_mask,
            ellipsis_mask,
            new_axis_mask,
            stride,
            begin,
            end,
        ]
        max_rank = max(len(arr) for arr in mask_list)

        # If ellipsis_mask is given, its last element is True (it marks the
        # highest set bit). Otherwise, simply pad each mask with defaults.
        if ellipsis_mask:
            rank = max_rank
            idx = len(ellipsis_mask) - 1
        else:
            rank = x_rank
            idx = -1

        begin_mask = pad_array(begin_mask, rank, idx, False)
        end_mask = pad_array(end_mask, rank, idx, False)
        squeeze_mask = pad_array(squeeze_mask, rank, idx, False)
        ellipsis_mask = pad_array(ellipsis_mask, rank, idx, False)
        new_axis_mask = pad_array(new_axis_mask, rank, idx, False)
        stride = pad_array(stride, rank, idx, 1)

        # Only compile-time begin/end can be padded element-wise.
        if begin_is_const:
            begin = pad_array(begin, rank, idx, 0)
        if end_is_const:
            end = pad_array(end, rank, idx, 0)

        # Dimensions covered by the ellipsis or created by a new axis are
        # taken whole: begin_mask/end_mask True and unit stride.
        for i, (is_ellipsis, is_new_axis) in enumerate(
            zip(ellipsis_mask, new_axis_mask)
        ):
            if is_ellipsis or is_new_axis:
                begin_mask[i] = True
                end_mask[i] = True
                stride[i] = 1

        # Run-time-determined begin/end are passed through unchanged.
        if not begin_is_const:
            begin = begin_var
        if not end_is_const:
            end = end_var

        # Without an ellipsis, index specs shorter than the tensor rank leave
        # the trailing dimensions unsliced. For instance, indexing a 3-D
        # tensor A with A[:1, :1] is equivalent to A[:1, :1, :], so those
        # trailing dimensions get begin_mask/end_mask True.
        if idx == -1:
            for i in range(max_rank, x_rank):
                begin_mask[i] = True
                end_mask[i] = True

        return begin, end, stride, begin_mask, end_mask, squeeze_mask, new_axis_mask

    begin, end, stride, begin_mask, end_mask, squeeze_mask, new_axis_mask = _pad_mask(
        x,
        begin,
        end,
        stride,
        begin_mask,
        end_mask,
        squeeze_mask,
        ellipsis_mask,
        new_axis_mask,
    )

    # Materialise the requested new axes before slicing so that the padded
    # masks line up with the expanded tensor's dimensions.
    if any(new_axis_mask):
        axes = [i for i, val in enumerate(new_axis_mask) if val]
        x = mb.expand_dims(x=x, axes=axes, name=node.name + "_new_axes")

    x = mb.slice_by_index(
        x=x,
        name=node.name,
        begin=begin,
        end=end,
        stride=stride,
        begin_mask=begin_mask,
        end_mask=end_mask,
        squeeze_mask=squeeze_mask,
    )

    context.add(node.name, x)