in torch_xla/csrc/pooling.cpp [252:347]
// Builds the XLA ops producing the indices output of a max pooling.
//
// Arguments:
//   input_shape: shape handed to CreatePoolIndicesIota(); its element type is
//     also used for the select/scatter computations and their init value.
//   padded_input: the values being pooled (already padded, going by the name).
//   pool_result: the max pooling output whose indices are being computed.
//   padding_config: padding applied to the index iota, which is then sliced
//     with the same indices as `padded_input`.
//   pooling_op_attributes: supplies the kernel size and stride of the pooling.
//
// Returns an op of element type kIndicesType with the dimensions of
// `pool_result`.
xla::XlaOp ComputeMaxPoolIndices(
    const xla::Shape& input_shape, xla::XlaOp padded_input,
    xla::XlaOp pool_result, const xla::PaddingConfig& padding_config,
    const PoolingOpAttributes& pooling_op_attributes) {
  if (!IsOverlapping(pooling_op_attributes)) {
    // Fast path. ComputeNoOverlapMaxPoolIndices() is only valid when the
    // reduce windows are disjoint: with overlapping windows, the
    // reduce-window applied to the indices would see several candidate
    // indices per window and could not tell which one to keep. A variadic
    // reduce-window in XLA would lift that restriction.
    return ComputeNoOverlapMaxPoolIndices(input_shape, padded_input,
                                          pool_result, padding_config,
                                          pooling_op_attributes);
  }

  // Slow path: a while loop visits the pooling windows one at a time and
  // resolves the index for each. It is only emitted when the caller actually
  // consumes the indices and the reduce windows overlap.
  xla::XlaBuilder* const graph_builder = padded_input.builder();
  const auto& kernel_size = pooling_op_attributes.kernel_size;
  const auto& stride = pooling_op_attributes.stride;

  // Index grid padded like the input. Padding cells carry the largest index
  // value, which is also the "invalid" marker used inside the loop body.
  xla::XlaOp index_grid = CreatePoolIndicesIota(input_shape, graph_builder);
  xla::XlaOp pad_value = xla::MaxValue(graph_builder, kIndicesType);
  xla::XlaOp padded_index_grid =
      xla::Pad(index_grid, pad_value, padding_config);

  const xla::Shape& output_shape = XlaHelpers::ShapeOfXlaOp(pool_result);
  const xla::int64_t output_count = xla::ShapeUtil::ElementsIn(output_shape);

  // Loop-carried state. Each append() returns the slot of the value within
  // the state vector handed to the condition and body callbacks.
  InitValues loop_state;
  const size_t counter_slot =
      loop_state.append(xla::Zero(graph_builder, kIndicesType));
  const size_t limit_slot = loop_state.append(
      XlaHelpers::ScalarValue(output_count, kIndicesType, graph_builder));
  const size_t input_slot = loop_state.append(padded_input);
  const size_t pool_slot = loop_state.append(pool_result);
  const size_t grid_slot = loop_state.append(padded_index_grid);
  // Flat (rank 1) accumulator holding one index per pooling output element.
  const size_t output_slot = loop_state.append(xla::Zeros(
      graph_builder,
      xla::ShapeUtil::MakeShape(kIndicesType, {output_count})));

  // Keep iterating while counter < number of pooling output elements.
  auto keep_going =
      [&](absl::Span<const xla::XlaOp> state,
          xla::XlaBuilder* loop_builder) -> xla::StatusOr<xla::XlaOp> {
    return xla::Lt(state[counter_slot], state[limit_slot]);
  };

  // Resolves the index of the pooling output element addressed by the counter
  // and stores it at position `counter` of the flat accumulator.
  auto resolve_one =
      [&](absl::Span<const xla::XlaOp> state, xla::XlaBuilder* loop_builder)
      -> xla::StatusOr<std::vector<xla::XlaOp>> {
    const auto value_type = input_shape.element_type();

    // Cut the current kernel-sized window out of both the input and the index
    // grid, plus the single matching element of the pooling output.
    PoolSliceIndices window = ComputeSliceIndices(
        state[counter_slot], output_shape.dimensions(), stride);
    xla::XlaOp input_window =
        xla::DynamicSlice(state[input_slot], window.input_indices, kernel_size);
    xla::XlaOp grid_window =
        xla::DynamicSlice(state[grid_slot], window.input_indices, kernel_size);
    std::vector<xla::int64_t> unit_sizes(kernel_size.size(), 1);
    xla::XlaOp pooled_value =
        xla::DynamicSlice(state[pool_slot], window.result_indices, unit_sizes);

    // Scatter the pooled value onto the window position chosen by the >=
    // select; every other position keeps `fill_value`.
    xla::XlaComputation pick_ge =
        xla::CreateScalarGeComputation(value_type, loop_builder);
    xla::XlaComputation combine_max =
        xla::CreateScalarMaxComputation(value_type, loop_builder);
    xla::XlaOp fill_value = xla::MinValue(loop_builder, value_type);
    xla::XlaOp scattered = xla::SelectAndScatter(
        input_window, pick_ge, kernel_size, stride, xla::Padding::kValid,
        pooled_value, fill_value, combine_max);

    // Keep the grid index wherever the scattered window differs from
    // `fill_value`; mark every other position with the largest index value so
    // the min reduction below ignores it.
    xla::XlaOp invalid_index = xla::MaxValue(loop_builder, kIndicesType);
    xla::XlaOp invalid_window = xla::Broadcast(invalid_index, kernel_size);
    xla::XlaOp was_selected = xla::Ne(scattered, fill_value);
    xla::XlaOp candidate_indices =
        xla::Select(was_selected, grid_window, invalid_window);

    // Collapse the window to its smallest candidate index, as a rank 1 value
    // of one element so it can be written into the flat accumulator.
    xla::XlaComputation reduce_min =
        xla::CreateScalarMinComputation(kIndicesType, loop_builder);
    xla::XlaOp window_index =
        xla::ReduceWindow(candidate_indices, invalid_index, reduce_min,
                          kernel_size, stride, xla::Padding::kValid);
    xla::XlaOp flat_index = xla::Reshape(window_index, {1});

    // Next state: everything is carried over unchanged except the advanced
    // counter and the updated accumulator.
    std::vector<xla::XlaOp> next_state(state.begin(), state.end());
    xla::XlaOp one = xla::One(loop_builder, kIndicesType);
    next_state[counter_slot] = state[counter_slot] + one;
    next_state[output_slot] = xla::DynamicUpdateSlice(
        next_state[output_slot], flat_index, {state[counter_slot]});
    return next_state;
  };

  std::vector<xla::XlaOp> final_state = ConsumeValue(
      xla::WhileLoopHelper(keep_going, resolve_one, loop_state.values,
                           "ComputeMaxPoolIndices", graph_builder));
  // Give the flat accumulator the dimensions of the pooling output.
  return xla::Reshape(final_state[output_slot], output_shape.dimensions());
}