Status ValidateStridedSliceOp()

in tensorflow/tensorflow/core/util/strided_slice_op.cc [152:373]


// Validates a (sparse) strided-slice specification against `input_shape` and
// expands it into dense, per-input-dimension begin/end/stride values.
//
// Arguments:
//   begin_tensor, end_tensor: optional (may be nullptr). When present each
//     must be a vector with as many elements as `strides_tensor`.
//   strides_tensor: vector of strides; its dtype must be DT_INT32 or DT_INT64
//     and it must have fewer than 32 elements because the masks are 32-bit.
//   input_shape: shape being sliced; individual dimensions may be unknown
//     (-1) but the rank must be known.
//   begin_mask_spec, end_mask_spec, ellipsis_mask, new_axis_mask,
//   shrink_axis_mask: bit i of each mask refers to entry i of the sparse
//     spec. At most one bit may be set in `ellipsis_mask`.
//   processing_shape: cleared, then receives one dimension per input
//     dimension (the slice result before new/shrunk axes are applied); -1 is
//     used where the size cannot be determined.
//   final_shape: cleared, then assembled from `processing_shape` via the
//     gather indices recorded in the dense spec (new axes contribute size 1).
//   is_identity, is_simple_slice, slice_dim0: optimization hints, see Step 3.
//   begin, end, strides: handed to the dense spec and read/updated in place
//     in Step 3. NOTE(review): they are presumably sized to
//     input_shape.dims() by BuildDenseSpec, which is not visible here.
//
// Returns Status::OK() on success or an InvalidArgument error describing the
// first problem found.
Status ValidateStridedSliceOp(
    const Tensor* begin_tensor, const Tensor* end_tensor,
    const Tensor& strides_tensor, const PartialTensorShape& input_shape,
    int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
    int32 new_axis_mask, int32 shrink_axis_mask,
    PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
    bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
    gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
    gtl::InlinedVector<int64, 4>* strides) {
  const bool begin_is_wrong =
      begin_tensor != nullptr &&
      !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
        begin_tensor->NumElements() == strides_tensor.NumElements() &&
        begin_tensor->NumElements() < 32 /* using 32 bit masks */);
  const bool end_is_wrong =
      end_tensor != nullptr &&
      !(TensorShapeUtils::IsVector(end_tensor->shape()) &&
        end_tensor->NumElements() == strides_tensor.NumElements());
  if (begin_is_wrong || end_is_wrong ||
      !TensorShapeUtils::IsVector(strides_tensor.shape())) {
    if (begin_tensor != nullptr && end_tensor != nullptr) {
      return errors::InvalidArgument(
          "Expected begin, end, and strides to be 1D equal size tensors, ",
          "but got shapes ", begin_tensor->shape().DebugString(), ", ",
          end_tensor->shape().DebugString(), ", and ",
          strides_tensor.shape().DebugString(), " instead.");
    } else {
      return errors::InvalidArgument(
          "Expected begin, end, and strides to be 1D equal size tensors, ",
          "but got shape ", strides_tensor.shape().DebugString(),
          " for strides.");
    }
  }
  // The size limit above is only enforced through begin_tensor. When begin is
  // absent the strides vector must still fit the 32-bit masks, otherwise the
  // `1 << i` shifts below would have undefined behavior.
  if (strides_tensor.NumElements() >= 32) {
    return errors::InvalidArgument(
        "Expected strides to have fewer than 32 elements (masks are 32 bit), ",
        "but got ", strides_tensor.NumElements(), " elements.");
  }
  // Use bit compares to ensure ellipsis_mask is 0 or a power of 2,
  // i.e. there is at most one ellipsis.
  if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) {
    return errors::InvalidArgument(
        "Multiple ellipses in slice spec not allowed");
  }

  // Step 1: Account for ellipsis and new axis
  //
  // Check for ellipses and count how many non-newaxis' there are after
  // TODO(aselle): Convert this to do a fast log2 followed by iteration
  //               counting ones in next guys
  bool ellipsis_seen = false;

  StridedSliceSparseSpec sparse_spec = {strides_tensor.NumElements(),
                                        0,
                                        begin_tensor,
                                        end_tensor,
                                        strides_tensor,
                                        begin_mask_spec,
                                        end_mask_spec,
                                        ellipsis_mask,
                                        new_axis_mask,
                                        shrink_axis_mask};

  for (int32 i = 0; i < sparse_spec.dims; i++) {
    // Only new-axis entries that come strictly after the ellipsis are counted.
    if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
      sparse_spec.num_add_axis_after_ellipsis++;
    }
    if ((1 << i) & ellipsis_mask) {
      ellipsis_seen = true;
    }
  }
  // If there is no ellipsis, insert one at the end.
  if (!ellipsis_seen) {
    sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
    sparse_spec.dims++;  // this affects the sparse-spec iteration below
  }

  // Step 2: Make a sparse spec into a full index spec
  //
  // The sparse spec does not correspond to the number of dimensions
  // Make a dense spec that corresponds to the number of dimensions
  //
  // For example suppose foo[...,3:] on foo.shape=(2,2,3) then
  // we need to produce the missing begin_mask for the first two
  // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2
  // we achieve begin_mask=6, end_mask=7

  // An unknown-rank shape reports dims() == -1, which cannot be used as the
  // number of dense dimensions. Reject it explicitly instead of forwarding a
  // negative dimension count to BuildDenseSpec.
  if (input_shape.dims() < 0) {
    return errors::InvalidArgument("Unexpected input_shape with unknown rank");
  }
  // NOTE(review): the dense masks are tested with `1 << i` for every input
  // dimension below; ranks of 32 or more are not guarded against here.
  StridedSliceDenseSpec dense_spec = {input_shape.dims(),
                                      0 /* begin_mask */,
                                      0 /* end_mask */,
                                      false /* begin_valid */,
                                      false /* end_valid */,
                                      *begin,
                                      *end,
                                      *strides};

  if (strides_tensor.dtype() == DT_INT32) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
  } else if (strides_tensor.dtype() == DT_INT64) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int64>(sparse_spec, &dense_spec));
  } else {
    // Any other index dtype is treated as an unrecoverable error.
    LOG(FATAL) << "begin must be either int32 or int64";
  }

  // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
  //         and bounds check!
  //
  // The three flags start out true and are and-ed with a per-dimension
  // condition:
  //   is_identity     - every dimension is taken in full with stride 1.
  //   slice_dim0      - every dimension is taken in full, except that
  //                     dimension 0 only needs stride 1.
  //   is_simple_slice - every stride is 1.
  // Dimensions of unknown size (-1) are skipped and do not affect the flags.
  *is_identity = true;
  *slice_dim0 = true;
  *is_simple_slice = true;
  processing_shape->Clear();
  for (int i = 0; i < input_shape.dims(); ++i) {
    int64& begin_i = (*begin)[i];
    int64& end_i = (*end)[i];
    int64& stride_i = (*strides)[i];
    int64 dim_i = input_shape.dim_size(i);
    if (stride_i == 0) {
      return errors::InvalidArgument("strides[", i, "] must be non-zero");
    }
    bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i));
    if (dim_i == -1) {
      // Unknown dimension size: a shrunk axis still has size 1, anything else
      // stays unknown.
      processing_shape->AddDim(shrink_i ? 1 : -1);
      continue;
    }

    const std::array<int64, 2> masks = {
        {dense_spec.begin_mask & (1 << i), dense_spec.end_mask & (1 << i)}};
    // Inclusive clamp range for indices; it shifts down by one for negative
    // strides.
    const std::array<int64, 2> valid_range = {
        {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}};

    // Canonicalizes index `x` for c == 0 (begin) or c == 1 (end): a masked
    // index becomes the matching extreme of valid_range (swapped for negative
    // strides); otherwise negative indices are wrapped once by dim_i and the
    // result is clamped into valid_range.
    auto canonical = [stride_i, dim_i, masks, valid_range](int64 x, int c) {
      if (masks[c]) {
        return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
      } else {
        int64 x_fwd = x < 0 ? dim_i + x : x;  // make negative indices positive
        return x_fwd < valid_range[0]
                   ? valid_range[0]
                   : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
      }
    };
    if (shrink_i && stride_i <= 0) {
      return errors::InvalidArgument(
          "only stride 1 allowed on non-range indexing.");
    }
    (*is_simple_slice) &= stride_i == 1;

    const bool begin_and_end_masked =
        (dense_spec.begin_mask & (1 << i)) && (dense_spec.end_mask & (1 << i));
    if (dense_spec.begin_valid && dense_spec.end_valid) {
      if (shrink_i) {
        // If we are shrinking, the end index is now possibly incorrect. In
        // particular foo[-1] produces sparse_begin = -1, sparse_end = 0.
        // and canonical puts these to n-1 and 0, which implies a degenerate
        // interval. Fortunately, it is now safe to re-create end as begin+1.
        int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
        begin_i = x_fwd;
        end_i = begin_i + 1;
        if (x_fwd < 0 || x_fwd >= dim_i) {
          return errors::InvalidArgument(
              "slice index ", begin_i, " of dimension ", i, " out of bounds.");
        }
      } else {
        begin_i = canonical(begin_i, 0);
        end_i = canonical(end_i, 1);
      }
      // Update optimization values
      bool take_all_in_dimension =
          stride_i == 1 && begin_i == 0 && end_i == dim_i;
      (*is_identity) &= take_all_in_dimension;
      (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension;
    } else {
      // Without concrete begin/end values only fully masked dimensions are
      // known to be taken in full.
      (*is_identity) &= stride_i == 1 && begin_and_end_masked;
      (*slice_dim0) &= (i == 0 && stride_i == 1) || begin_and_end_masked;
    }
    // Compute the processing shape (the intermediate Eigen will produce)
    int64 interval_length;
    bool known_interval = false;
    if (dense_spec.begin_valid && dense_spec.end_valid) {
      interval_length = end_i - begin_i;
      known_interval = true;
    } else if (shrink_i) {
      // The dimension is still known as 1 for the processing_shape, but will be
      // discarded for the final shape.
      interval_length = 1;
      known_interval = true;
    } else if (begin_and_end_masked) {
      // Even if we don't have values for begin or end, we do know that this
      // dimension covers the whole interval. If we have shape information for
      // this dimension, that tells us the interval length.
      if (dim_i >= 0) {
        if (stride_i < 0) {
          interval_length = -dim_i;
        } else {
          interval_length = dim_i;
        }
        known_interval = true;
      }
    }
    if (known_interval) {
      int64 size_i;
      // Hold zero if the interval is degenerate (empty, or running against
      // the stride direction), otherwise divide rounding up to account for a
      // partial last step.
      if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) {
        size_i = 0;
      } else {
        size_i = interval_length / stride_i +
                 (interval_length % stride_i != 0 ? 1 : 0);
      }
      processing_shape->AddDim(size_i);
    } else {
      processing_shape->AddDim(-1);
    }
  }

  // Step 4: Compute the final shape
  //
  // new_axis will increase dimension by 1 (with a one-size dimension)
  // slices like foo[3,...] will reduce dimension by 1.
  // This cannot be done earlier, because it depends on Step 3.
  final_shape->Clear();
  for (auto gather_index : dense_spec.final_shape_gather_indices) {
    if (gather_index >= 0) {
      final_shape->AddDim(processing_shape->dim_size(gather_index));
    } else if (gather_index == kNewAxis) {
      final_shape->AddDim(1);
    }
    // Any other negative index contributes no dimension to the final shape.
  }
  return Status::OK();
}