xla::XlaOp ComputeMaxPoolIndices()

in torch_xla/csrc/pooling.cpp [252:347]


// Builds the XLA ops that produce the indices tensor accompanying a max-pool
// result.
//
// Arguments:
//   input_shape:           shape handed to CreatePoolIndicesIota(); its element
//                          type is also used for the select/scatter
//                          computations.
//   padded_input:          the (already padded) pooling input.
//   pool_result:           the max-pool output computed from `padded_input`.
//   padding_config:        padding applied to the index iota so that it lines
//                          up with `padded_input`.
//   pooling_op_attributes: provides the kernel size and stride of the pooling.
//
// Returns an op of element type kIndicesType reshaped to the dimensions of
// `pool_result`.
xla::XlaOp ComputeMaxPoolIndices(
    const xla::Shape& input_shape, xla::XlaOp padded_input,
    xla::XlaOp pool_result, const xla::PaddingConfig& padding_config,
    const PoolingOpAttributes& pooling_op_attributes) {
  // Fast path. ComputeNoOverlapMaxPoolIndices() is only valid when the reduce
  // windows are disjoint: with overlapping windows, the reduce-window it runs
  // over the indices would see several candidate indices per window and could
  // not tell which one to pick. A variadic reduce-window in XLA would lift
  // that restriction.
  const bool windows_overlap = IsOverlapping(pooling_op_attributes);
  if (!windows_overlap) {
    return ComputeNoOverlapMaxPoolIndices(input_shape, padded_input,
                                          pool_result, padding_config,
                                          pooling_op_attributes);
  }

  // Slow path: an XLA while loop visits the pooling windows one at a time and
  // computes a single index per iteration. It is only emitted when the windows
  // overlap (and the resulting ops only matter if the caller consumes the
  // indices).
  xla::XlaBuilder* const outer_builder = padded_input.builder();
  const auto& window_dims = pooling_op_attributes.kernel_size;
  const auto& window_strides = pooling_op_attributes.stride;
  const auto element_type = input_shape.element_type();

  // Index iota, padded with the largest index value so that padding positions
  // can never win the min-reduction performed in the loop body.
  xla::XlaOp index_iota = CreatePoolIndicesIota(input_shape, outer_builder);
  xla::XlaOp padded_index_iota = xla::Pad(
      index_iota, xla::MaxValue(outer_builder, kIndicesType), padding_config);

  const xla::Shape& result_shape = XlaHelpers::ShapeOfXlaOp(pool_result);
  xla::int64_t window_count = xla::ShapeUtil::ElementsIn(result_shape);

  // Loop-carried state. Each append() returns the slot under which the value
  // is later found in the span passed to the condition and body functions.
  InitValues loop_state;
  // Current window number, starting at zero.
  size_t counter_slot =
      loop_state.append(xla::Zero(outer_builder, kIndicesType));
  // Total number of windows; the loop stops once the counter reaches it.
  size_t limit_slot = loop_state.append(
      XlaHelpers::ScalarValue(window_count, kIndicesType, outer_builder));
  size_t input_slot = loop_state.append(padded_input);
  size_t pooled_slot = loop_state.append(pool_result);
  size_t iota_slot = loop_state.append(padded_index_iota);
  // Flat (rank 1) output buffer holding one index per window.
  size_t indices_slot = loop_state.append(xla::Zeros(
      outer_builder, xla::ShapeUtil::MakeShape(kIndicesType, {window_count})));

  // Keep iterating while counter < limit.
  auto keep_going =
      [&](absl::Span<const xla::XlaOp> state,
          xla::XlaBuilder* builder) -> xla::StatusOr<xla::XlaOp> {
    return xla::Lt(state[counter_slot], state[limit_slot]);
  };

  // One iteration: compute the index for window number `counter`.
  auto process_window =
      [&](absl::Span<const xla::XlaOp> state,
          xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
    xla::XlaOp counter = state[counter_slot];
    PoolSliceIndices window_pos = ComputeSliceIndices(
        counter, result_shape.dimensions(), window_strides);

    // Cut the current window out of both the input and the index iota, and
    // fetch the matching single-element slice of the pooling result.
    xla::XlaOp window_values = xla::DynamicSlice(
        state[input_slot], window_pos.input_indices, window_dims);
    xla::XlaOp window_indices = xla::DynamicSlice(
        state[iota_slot], window_pos.input_indices, window_dims);
    std::vector<xla::int64_t> unit_sizes(window_dims.size(), 1);
    xla::XlaOp pooled_value = xla::DynamicSlice(
        state[pooled_slot], window_pos.result_indices, unit_sizes);

    // Scatter the pooled value onto the window position picked by the >=
    // selection; every other position keeps the init (minimum) value.
    xla::XlaComputation select_ge =
        xla::CreateScalarGeComputation(element_type, builder);
    xla::XlaComputation scatter_max =
        xla::CreateScalarMaxComputation(element_type, builder);
    xla::XlaOp lowest = xla::MinValue(builder, element_type);
    xla::XlaOp scattered_values = xla::SelectAndScatter(
        window_values, select_ge, window_dims, window_strides,
        xla::Padding::kValid, pooled_value, lowest, scatter_max);

    // Keep the iota index wherever the scatter wrote something different from
    // the init value; everywhere else use the max index as an "invalid"
    // marker, then take the minimum over the window.
    xla::XlaOp invalid_index = xla::MaxValue(builder, kIndicesType);
    xla::XlaOp invalid_window = xla::Broadcast(invalid_index, window_dims);
    xla::XlaOp was_selected = xla::Ne(scattered_values, lowest);
    xla::XlaOp candidate_indices =
        xla::Select(was_selected, window_indices, invalid_window);
    xla::XlaComputation min_reducer =
        xla::CreateScalarMinComputation(kIndicesType, builder);
    xla::XlaOp window_index = xla::ReduceWindow(
        candidate_indices, invalid_index, min_reducer, window_dims,
        window_strides, xla::Padding::kValid);
    xla::XlaOp flat_index = xla::Reshape(window_index, {1});

    // Pass all state through unchanged except for the incremented counter and
    // the output buffer, which receives the new index at position `counter`.
    std::vector<xla::XlaOp> next_state(state.begin(), state.end());
    xla::XlaOp one = xla::One(builder, kIndicesType);
    next_state[counter_slot] = counter + one;
    next_state[indices_slot] = xla::DynamicUpdateSlice(
        next_state[indices_slot], flat_index, {counter});
    return next_state;
  };

  std::vector<xla::XlaOp> loop_outputs = ConsumeValue(xla::WhileLoopHelper(
      keep_going, process_window, loop_state.values, "ComputeMaxPoolIndices",
      outer_builder));

  // Reshape the flat index buffer to the dimensions of the pooling result.
  return xla::Reshape(loop_outputs[indices_slot], result_shape.dimensions());
}