in crypten/common/functions/pooling.py [0:0]
def _adaptive_pool2d_helper(input, output_size, reduction="mean"):
    r"""
    Provides a helper that adapts the input size and provides input
    args / kwargs to allow pool2d functions to emulate adaptive pool2d
    functions.

    This function computes the kernel_size and stride for pool2d
    functions and inserts rows along each of the last two dimensions so
    that a single constant stride (equal to the kernel size) can be used.

    Args:
        input: tensor whose last two dimensions are pooled. It is cloned
            before any rows are inserted, so the caller's tensor is not
            modified. Rows are joined with ``crypten.cat``.
        output_size (int or tuple): target size of the last two
            dimensions. An ``int`` is used for both; a ``None`` entry keeps
            that dimension's input size.
        reduction (str): ``"mean"`` or ``"max"``. Determines how kernels
            shorter than the largest kernel are padded so that the
            reduction's result is unchanged.

    Returns:
        ``(input, args, kwargs)`` where ``input`` is the row-augmented copy,
        ``args == (kernel_sizes,)`` and ``kwargs == {"stride": strides}``
        (``kernel_sizes`` and ``strides`` are the same 2-tuple).

    Raises:
        ValueError: if ``reduction`` is neither ``"mean"`` nor ``"max"``.
    """
    import crypten

    # Validate up front so an invalid reduction is rejected regardless of
    # whether any kernel actually needs padding.
    if reduction not in ("mean", "max"):
        raise ValueError(f"Invalid reduction {reduction} for adaptive pooling.")

    input = input.clone()
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
    assert len(output_size) == 2, "output_size must be 2-dimensional."
    output_size = list(output_size)
    for i in range(2):
        if output_size[i] is None:
            output_size[i] = input.size(i - 2)

    def compute_kernels(in_size, out_size):
        """Return (starts, ends, max_kernel_size) for each output position.

        Kernel ``j`` covers ``[floor(j * in / out), ceil((j + 1) * in / out))``.
        Exact integer division is used: float arithmetic can land just
        below/above an integer (e.g. ``49 * (2 / 98) == 0.9999999999999999``)
        and shift a boundary by one row.
        """
        starts = [(j * in_size) // out_size for j in range(out_size)]
        ends = [-(-((j + 1) * in_size) // out_size) for j in range(out_size)]
        max_kernel_size = max((e - s for s, e in zip(starts, ends)), default=0)
        return starts, ends, max_kernel_size

    def repeat_row(tensor, dim, ind):
        """Duplicate row ``ind - 1`` of ``tensor`` along ``dim``, inserting the
        copy at position ``ind``."""
        device = tensor.device
        before = tensor.index_select(dim, torch.arange(ind, device=device))
        after = tensor.index_select(
            dim, torch.arange(ind, tensor.size(dim), device=device)
        )
        repeated_row = tensor.index_select(
            dim, torch.tensor([ind - 1], device=device)
        )
        return crypten.cat([before, repeated_row, after], dim=dim)

    def extend_row(tensor, dim, start_ind, end_ind):
        """Insert one row at ``start_ind`` along ``dim`` that leaves the
        reduction over rows ``[start_ind, end_ind)`` unchanged."""
        device = tensor.device
        if reduction == "mean":
            # Adding a copy of the kernel's mean keeps the mean identical.
            extended_value = tensor.index_select(
                dim, torch.arange(start_ind, end_ind, device=device)
            ).mean(dim, keepdim=True)
        else:  # reduction == "max" (validated above)
            # Duplicating an element of the kernel keeps the max identical.
            extended_value = tensor.index_select(
                dim, torch.tensor([start_ind], device=device)
            )
        if start_ind == 0:
            # Avoid concatenating an empty leading slice.
            return crypten.cat([extended_value, tensor], dim=dim)
        before = tensor.index_select(dim, torch.arange(start_ind, device=device))
        after = tensor.index_select(
            dim, torch.arange(start_ind, tensor.size(dim), device=device)
        )
        return crypten.cat([before, extended_value, after], dim=dim)

    strides = []
    for i in range(2):
        dim = i - 2 + input.dim()
        in_size = input.size(dim)
        out_size = output_size[i]
        if out_size > 1:
            starts, ends, stride = compute_kernels(in_size, out_size)
            # Number of rows inserted so far along `dim`; original row
            # indices must be shifted by this amount.
            added_rows = 0
            for j in range(out_size):
                start_ind = starts[j]
                end_ind = ends[j]
                # Pad kernel j so every kernel has exactly `stride` rows.
                for _ in range(end_ind - start_ind, stride):
                    input = extend_row(
                        input, dim, start_ind + added_rows, end_ind + added_rows
                    )
                    added_rows += 1
                if j == out_size - 1:
                    break
                # Given the floor/ceil bounds, kernels j and j + 1 overlap
                # by at most one row. Duplicate that row so consecutive
                # kernels become disjoint and stride == kernel size.
                if end_ind > starts[j + 1]:
                    input = repeat_row(input, dim, end_ind + added_rows)
                    added_rows += 1
        else:
            # A single output covers the whole dimension.
            stride = in_size
        strides.append(stride)

    strides = tuple(strides)
    kernel_sizes = strides
    args = (kernel_sizes,)
    kwargs = {"stride": strides}
    return input, args, kwargs