in tensorflow/tensorflow/core/util/strided_slice_op.cc [152:373]
// Validates a strided-slice specification against `input_shape` and expands
// it into fully-explicit, bounds-checked per-dimension begin/end/stride
// vectors plus the resulting shapes.
//
// Inputs:
//   begin_tensor / end_tensor  - optional 1-D index vectors; either may be
//                                nullptr when the corresponding values are
//                                not yet known (shape-inference time).
//   strides_tensor             - required 1-D stride vector; its length
//                                defines the number of sparse slice entries.
//   input_shape                - possibly-partial shape of the input; -1
//                                dims are propagated as unknown.
//   *_mask_spec                - 32-bit masks in the sparse (user-written)
//                                index space.
// Outputs:
//   processing_shape - shape of the intermediate result the kernel produces
//                      (shrink dims kept as size 1, unknown dims as -1).
//   final_shape      - processing_shape with new_axis dims (size 1) inserted
//                      and shrink dims removed.
//   is_identity      - true if the slice is a no-op over the whole input.
//   is_simple_slice  - true if every stride is exactly 1.
//   slice_dim0       - true if the slice only restricts dimension 0 (allows
//                      an aliased dim-0 slice fast path).
//   begin/end/strides - dense, canonicalized per-input-dimension values;
//                      only meaningful where begin/end were known.
// Returns InvalidArgument on malformed specs or out-of-range indices.
Status ValidateStridedSliceOp(
    const Tensor* begin_tensor, const Tensor* end_tensor,
    const Tensor& strides_tensor, const PartialTensorShape& input_shape,
    int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
    int32 new_axis_mask, int32 shrink_axis_mask,
    PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
    bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
    gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
    gtl::InlinedVector<int64, 4>* strides) {
  // begin/end must each be a 1-D vector the same length as strides. The
  // `< 32` bound exists because the per-dimension masks below are stored in
  // 32-bit integers, one bit per sparse entry.
  const bool begin_is_wrong =
      begin_tensor != nullptr &&
      !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
        begin_tensor->NumElements() == strides_tensor.NumElements() &&
        begin_tensor->NumElements() < 32 /* using 32 bit masks */);
  // NOTE(review): the `< 32` guard is only applied when begin_tensor is
  // non-null; if begin_tensor == nullptr and strides has >= 32 elements,
  // the `1 << i` / `1 << sparse_spec.dims` shifts below can reach or exceed
  // the width of a 32-bit int (undefined behavior). Confirm whether callers
  // bound the strides length independently.
  const bool end_is_wrong =
      end_tensor != nullptr &&
      !(TensorShapeUtils::IsVector(end_tensor->shape()) &&
        end_tensor->NumElements() == strides_tensor.NumElements());
  if (begin_is_wrong || end_is_wrong ||
      !TensorShapeUtils::IsVector(strides_tensor.shape())) {
    if (begin_tensor != nullptr && end_tensor != nullptr) {
      return errors::InvalidArgument(
          "Expected begin, end, and strides to be 1D equal size tensors, ",
          "but got shapes ", begin_tensor->shape().DebugString(), ", ",
          end_tensor->shape().DebugString(), ", and ",
          strides_tensor.shape().DebugString(), " instead.");
    } else {
      return errors::InvalidArgument(
          "Expected begin, end, and strides to be 1D equal size tensors, ",
          "but got shape ", strides_tensor.shape().DebugString(),
          " for strides.");
    }
  }
  // Use bit compares to ensure ellipsis_mask is 0 or a power of 2
  // i.e. there exists only no more than one ellipsis
  if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) {
    return errors::InvalidArgument(
        "Multiple ellipses in slice spec not allowed");
  }
  // Step 1: Account for ellipsis and new axis
  //
  // Check for ellipses and count how many non-newaxis' there are after
  // TODO(aselle): Convert this to do a fast log2 followed by iteration
  // counting ones in next guys
  bool ellipsis_seen = false;
  // The sparse spec mirrors the user-written index expression: one entry per
  // element of the strides vector, before expansion to input rank.
  StridedSliceSparseSpec sparse_spec = {strides_tensor.NumElements(),
                                        0,
                                        begin_tensor,
                                        end_tensor,
                                        strides_tensor,
                                        begin_mask_spec,
                                        end_mask_spec,
                                        ellipsis_mask,
                                        new_axis_mask,
                                        shrink_axis_mask};
  // Count new-axis entries appearing after the ellipsis; BuildDenseSpec
  // presumably needs this to know how many input dims the ellipsis spans.
  for (int32 i = 0; i < sparse_spec.dims; i++) {
    if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
      sparse_spec.num_add_axis_after_ellipsis++;
    }
    if ((1 << i) & ellipsis_mask) {
      ellipsis_seen = true;
    }
  }
  // If no ellipsis insert one at the end
  if (!ellipsis_seen) {
    sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
    sparse_spec.dims++;  // this effects loop iteration below
  }
  // Step 2: Make a sparse spec into a full index spec
  //
  // The sparse spec does not correspond to the number of dimensions
  // Make a dense spec that corresponds to the number of dimensions
  //
  // For example suppose foo[...,3:] on foo.shape=(2,2,3) then
  // we need to produce the missing begin_mask for the first two
  // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2
  // we achieve begin_mask=6, end_mask=7
  StridedSliceDenseSpec dense_spec = {input_shape.dims(),
                                      0 /* begin_mask */,
                                      0 /* end_mask */,
                                      false /* begin_valid */,
                                      false /* end_valid */,
                                      *begin,
                                      *end,
                                      *strides};
  // BuildDenseSpec is templated on the index dtype of begin/end/strides.
  if (strides_tensor.dtype() == DT_INT32) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
  } else if (strides_tensor.dtype() == DT_INT64) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int64>(sparse_spec, &dense_spec));
  } else {
    // NOTE(review): LOG(FATAL) aborts the process on an unexpected dtype
    // instead of returning InvalidArgument — confirm dtype is validated
    // upstream (e.g. by op registration constraints) before this point.
    LOG(FATAL) << "begin must be either int32 or int64";
  }
  // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
  // and bounds check!
  *is_identity = true;
  *slice_dim0 = true;
  *is_simple_slice = true;
  processing_shape->Clear();
  for (int i = 0; i < input_shape.dims(); ++i) {
    int64& begin_i = (*begin)[i];
    int64& end_i = (*end)[i];
    int64& stride_i = (*strides)[i];
    int64 dim_i = input_shape.dim_size(i);
    if (stride_i == 0) {
      return errors::InvalidArgument("strides[", i, "] must be non-zero");
    }
    bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i));
    // Unknown dimension size: we cannot canonicalize or bounds-check, so
    // emit the (possibly unknown) processing dim and skip the rest.
    if (dim_i == -1) {
      processing_shape->AddDim(shrink_i ? 1 : -1);
      continue;
    }
    // masks[0]/masks[1]: nonzero iff begin/end (respectively) is implicit
    // for this dimension and should snap to the extreme of valid_range.
    const std::array<int64, 2> masks = {
        {dense_spec.begin_mask & (1 << i), dense_spec.end_mask & (1 << i)}};
    // Valid canonical interval endpoints: [0, dim] for positive strides,
    // [-1, dim-1] for negative strides (end is exclusive in the stride
    // direction, hence the -1 sentinel below zero).
    const std::array<int64, 2> valid_range = {
        {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}};
    // Canonicalize index x for endpoint c (0=begin, 1=end): masked endpoints
    // snap to the appropriate extreme (swapped for negative strides);
    // explicit ones are wrapped from negative and clamped into valid_range.
    auto canonical = [stride_i, dim_i, masks, valid_range](int64 x, int c) {
      if (masks[c]) {
        return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
      } else {
        int64 x_fwd = x < 0 ? dim_i + x : x;  // make negative indices positive
        return x_fwd < valid_range[0]
                   ? valid_range[0]
                   : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
      }
    };
    if (shrink_i && stride_i <= 0) {
      return errors::InvalidArgument(
          "only stride 1 allowed on non-range indexing.");
    }
    (*is_simple_slice) &= stride_i == 1;
    const bool begin_and_end_masked =
        (dense_spec.begin_mask & (1 << i)) && (dense_spec.end_mask & (1 << i));
    if (dense_spec.begin_valid && dense_spec.end_valid) {
      if (shrink_i) {
        // If we are shrinking, the end index is now possibly incorrect. In
        // particular foo[-1] produces sparse_begin = -1, sparse_end = 0.
        // and canonical puts these to n-1 and 0, which implies a degenerate
        // interval. Fortunately, it is now safe to re-create end as begin+1.
        int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
        begin_i = x_fwd;
        end_i = begin_i + 1;
        if (x_fwd < 0 || x_fwd >= dim_i) {
          return errors::InvalidArgument(
              "slice index ", begin_i, " of dimension ", i, " out of bounds.");
        }
      } else {
        begin_i = canonical(begin_i, 0);
        end_i = canonical(end_i, 1);
      }
      // Update optimization values
      bool take_all_in_dimension =
          stride_i == 1 && begin_i == 0 && end_i == dim_i;
      (*is_identity) &= take_all_in_dimension;
      (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension;
    } else {
      // begin/end values unknown: the optimization flags can only be kept
      // if this dimension is fully masked (i.e. takes the whole range).
      (*is_identity) &= stride_i == 1 && begin_and_end_masked;
      (*slice_dim0) &= (i == 0 && stride_i == 1) || begin_and_end_masked;
    }
    // Compute the processing shape (the intermediate Eigen will produce)
    int64 interval_length;
    bool known_interval = false;
    if (dense_spec.begin_valid && dense_spec.end_valid) {
      interval_length = end_i - begin_i;
      known_interval = true;
    } else if (shrink_i) {
      // The dimension is still known as 1 for the processing_shape, but will be
      // discarded for the final shape.
      interval_length = 1;
      known_interval = true;
    } else if (begin_and_end_masked) {
      // Even if we don't have values for begin or end, we do know that this
      // dimension covers the whole interval. If we have shape information for
      // this dimension, that tells us the interval length.
      if (dim_i >= 0) {
        if (stride_i < 0) {
          // Negative interval length so the sign matches stride_i in the
          // size computation below.
          interval_length = -dim_i;
        } else {
          interval_length = dim_i;
        }
        known_interval = true;
      }
    }
    if (known_interval) {
      int64 size_i;
      // Hold zero if the interval is degenerate, otherwise account for
      // remainder
      if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) {
        size_i = 0;
      } else {
        // Ceiling division: a partial final step still yields one element.
        size_i = interval_length / stride_i +
                 (interval_length % stride_i != 0 ? 1 : 0);
      }
      processing_shape->AddDim(size_i);
    } else {
      processing_shape->AddDim(-1);
    }
  }
  // Step 4: Compute the final shape
  //
  // new_axis will increase dimension by 1 (with a one-size dimension)
  // slices like foo[3,...] will reduce dimension by 1.
  // This cannot be done earlier, because it depends on Step 3.
  // Non-negative gather indices select processing dims to keep; kNewAxis
  // inserts a size-1 dim; shrunk dims simply have no gather entry.
  final_shape->Clear();
  for (auto gather_index : dense_spec.final_shape_gather_indices) {
    if (gather_index >= 0) {
      final_shape->AddDim(processing_shape->dim_size(gather_index));
    } else if (gather_index == kNewAxis) {
      final_shape->AddDim(1);
    }
  }
  return Status::OK();
}