def _adaptive_pool2d_helper()

in crypten/common/functions/pooling.py


# `math` and `torch` are module-level imports in pooling.py and are required below.
import math

import torch


def _adaptive_pool2d_helper(input, output_size, reduction="mean"):
    r"""
    Helper that resizes the input and builds the args / kwargs that
    allow pool2d functions to emulate adaptive pool2d functions.

    This function computes the kernel_size and stride for pool2d
    functions and inserts rows along each dimension so that a
    constant stride (equal to the kernel size) can be used.
    """
    import crypten

    input = input.clone()

    if isinstance(output_size, int):
        output_size = (output_size, output_size)

    assert len(output_size) == 2, "output_size must be 2-dimensional."

    output_size = list(output_size)
    for i in range(2):
        if output_size[i] is None:
            output_size[i] = input.size(i - 2)

    # Compute the start_index and end_index for kernels
    def compute_kernels(in_size, out_size):
        step = in_size / out_size

        starts = []
        ends = []
        max_kernel_size = 0
        for j in range(out_size):
            # Compute local kernel size
            start_index = int(j * step)
            end_index = int(math.ceil((j + 1) * step))
            k = end_index - start_index

            # Update global kernel size
            max_kernel_size = max(max_kernel_size, k)

            # Store local kernels
            starts.append(start_index)
            ends.append(end_index)

        return starts, ends, max_kernel_size
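    # For example (hypothetical sizes): in_size=5, out_size=3 gives step=5/3,
    # starts=[0, 1, 3], ends=[2, 4, 5] and max_kernel_size=3; the middle kernel
    # spans three rows while the outer two span only two, which is why rows are
    # extended and repeated below.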

    # Duplicates row `ind - 1` of `tensor` along dimension `dim`, inserting the
    # copy at index `ind`, so overlapping kernels each see their own copy of the row
    def repeat_row(tensor, dim, ind):
        device = tensor.device
        x = tensor.index_select(dim, torch.arange(ind, device=device))
        y = tensor.index_select(dim, torch.arange(ind, tensor.size(dim), device=device))
        repeated_row = tensor.index_select(dim, torch.tensor(ind - 1, device=device))
        return crypten.cat([x, repeated_row, y], dim=dim)
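    # For example (hypothetical value): with ind=2, rows [0, 1] are kept, a copy
    # of row 1 is inserted at index 2, and the remaining rows follow, so two
    # overlapping kernels each get their own copy of the shared row.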

    # Inserts an extra row at `start_ind` when a kernel is smaller than the maximum
    # kernel size: the kernel's mean for mean-pooling, or a duplicate of an existing
    # row for max-pooling, so the pooled value is unchanged
    def extend_row(tensor, dim, start_ind, end_ind):
        device = tensor.device
        if reduction == "mean":
            extended_value = tensor.index_select(
                dim, torch.arange(start_ind, end_ind, device=device)
            )
            extended_value = extended_value.mean(dim, keepdim=True)
        elif reduction == "max":
            extended_value = tensor.index_select(
                dim, torch.tensor(start_ind, device=device)
            )
        else:
            raise ValueError(f"Invalid reduction {reduction} for adaptive pooling.")

        if start_ind == 0:
            return crypten.cat([extended_value, tensor], dim=dim)

        x = tensor.index_select(dim, torch.arange(start_ind, device=device))
        y = tensor.index_select(
            dim, torch.arange(start_ind, tensor.size(dim), device=device)
        )
        return crypten.cat([x, extended_value, y], dim=dim)
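    # For example (hypothetical values): with reduction="mean", start_ind=1 and
    # end_ind=3, the mean of rows [1, 3) is inserted at index 1, growing that
    # kernel by one row without changing its mean; with reduction="max", row 1
    # is duplicated instead, which leaves the maximum unchanged.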

    strides = []
    for i in range(2):
        dim = i - 2 + input.dim()
        in_size = input.size(dim)
        out_size = output_size[i] if output_size[i] is not None else in_size

        # Compute repeats
        if out_size > 1:
            starts, ends, stride = compute_kernels(in_size, out_size)

            added_rows = 0
            for j in range(out_size):
                start_ind = starts[j]
                end_ind = ends[j]

                # Extend kernel so all kernels have the same size
                k = end_ind - start_ind
                for _ in range(k, stride):
                    input = extend_row(
                        input, dim, start_ind + added_rows, end_ind + added_rows
                    )
                    added_rows += 1

                if j == out_size - 1:
                    break

                # Repeat overlapping rows so stride can be equal to the kernel size
                if end_ind > starts[j + 1]:
                    input = repeat_row(input, dim, end_ind + added_rows)
                    added_rows += 1
        else:
            stride = in_size

        strides.append(stride)

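    # After row insertion, every kernel spans exactly `stride` rows, so a plain
    # pool2d with kernel_size equal to the stride tiles the adjusted input and
    # reproduces the adaptive pooling output.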
    strides = tuple(strides)
    kernel_sizes = strides

    args = (kernel_sizes,)
    kwargs = {"stride": strides}

    return input, args, kwargs
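
For context, a minimal sketch of how the returned triple is typically consumed, assuming a CrypTen tensor that exposes an `avg_pool2d` method (as the adaptive pooling wrappers in pooling.py do); this is an illustration under those assumptions, not the exact wrapper code:

def adaptive_avg_pool2d_sketch(input, output_size):
    # Resize the input and derive a constant kernel_size / stride pair.
    resized_input, args, kwargs = _adaptive_pool2d_helper(
        input, output_size, reduction="mean"
    )
    # args == (kernel_size,) and kwargs == {"stride": stride}, so a plain
    # fixed-stride pool now emulates adaptive average pooling.
    return resized_input.avg_pool2d(*args, **kwargs)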