in runtime/kernels/avgpooling_op.cc [82:172]
// Computes average pooling over a 4-D input tensor and writes the result to
// output 0. Supports NHWC and NCHW layouts (via data_format_) and two padding
// modes: VALID (padding_ == 1) and SAME (padding_ == 2).
//
// NOTE(review): any other padding_ value previously fell through a VLOG-only
// branch and built the output shape from *uninitialized* newNumRows /
// newNumCols; it is now rejected with InvalidArgument up front.
void Compute(OpKernelContext* context) override {
  VLOG(1) << "Starting AvgPool compute.";
  const Tensor& tensor_in = context->input(0);
  if (!context->status().ok()) {
    return;
  }
  // For avgpooling, tensor_in should have 4 dimensions.
  OP_REQUIRES(context, tensor_in.dims() == 4,
              errors::InvalidArgument("tensor_in must be 4-dimensional"));
  // Fail fast on an unsupported padding mode instead of silently computing
  // with uninitialized output dimensions below.
  OP_REQUIRES(context, padding_ == 1 || padding_ == 2,
              errors::InvalidArgument("unsupported padding type for AvgPool"));
  const bool is_nhwc = data_format_ == FORMAT_NHWC;
  int batch_size = tensor_in.shape().dim_size(0);
  // All of the following are out-params filled in by init_basic_info and the
  // padding helpers; zero-initialize so no path can read indeterminate values.
  int rows = 0;
  int cols = 0;
  int channels = 0;
  int newNumRows = 0;
  int newNumCols = 0;
  int rowKernelSize = 0;
  int colKernelSize = 0;
  int rowStrideLen = 0;
  int colStrideLen = 0;
  TensorShape outputshape;
  Tensor paddedTensor;  // populated only for SAME padding
  init_basic_info(tensor_in, ksize_, stride_, is_nhwc,
                  rows, cols, channels, rowKernelSize, colKernelSize,
                  rowStrideLen, colStrideLen);
  if (padding_ == 1) {
    VLOG(1) << "Valid Padding";
    valid_padding_new_num_rows_and_cols(rows, cols, rowKernelSize,
                                        colKernelSize, rowStrideLen,
                                        colStrideLen, newNumRows, newNumCols);
  } else {
    // padding_ == 2 is guaranteed by the OP_REQUIRES above.
    VLOG(1) << "Same Padding";
    same_padding_new_num_rows_and_cols(
        tensor_in, paddedTensor, context, rows, cols, batch_size, channels,
        rowKernelSize, colKernelSize, rowStrideLen, colStrideLen, newNumRows,
        newNumCols, std::numeric_limits<float>::infinity(), is_nhwc);
  }
  VLOG(1) << "paddedTensor shape: " << paddedTensor.DebugString();
  if (is_nhwc) {
    outputshape = TensorShape{batch_size, newNumRows, newNumCols, channels};
  } else {
    outputshape = TensorShape{batch_size, channels, newNumRows, newNumCols};
  }
  Tensor* output = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(0, outputshape, &output));
  VLOG(1) << "Output Shape: " << output->DebugString();
  // Special case: a 1x1 output just averages the entire spatial extent of the
  // (unpadded) input.
  if (newNumRows == 1 && newNumCols == 1) {
    VLOG(1) << "Using the special case!";
    std::function<float(Tensor, int, int, int, int, bool)> special_case_func =
        special_case_function;
    special_case(tensor_in, output, rows, cols, batch_size, special_case_func,
                 channels, is_nhwc);
    // don't need to do any more calculations
    return;
  }
  std::function<float(Tensor, int, int, int, int, int, int, bool)>
      pooling_func = calculate_sum_and_average;
  // SAME padding pools over the padded copy; VALID pools over the raw input.
  const Tensor& pool_input = (padding_ == 2) ? paddedTensor : tensor_in;
  do_pooling(pool_input, output, batch_size, newNumRows, rowStrideLen,
             rowKernelSize, newNumCols, colStrideLen, colKernelSize,
             channels, pooling_func, is_nhwc);
}