def _nn_functional_embedding_bag()

in nestedtensor/nested/nested.py [0:0]


def _nn_functional_embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
                                 scale_grad_by_freq=False, mode='mean', sparse=False,
                                 per_sample_weights=None, include_last_offset=False,
                                 padding_idx=None):
    """NestedTensor implementation of ``torch.nn.functional.embedding_bag``.

    Only nested inputs of dimension 2 (as reported by
    ``torch.ops.nestedtensor.get_dim``) are accepted. Each entry of ``input``
    forms one bag: bag offsets are derived from the per-entry lengths given by
    ``nested_size()``, so callers must not pass ``offsets`` themselves.

    The legacy argument order ``embedding_bag(weight, input, ...)`` is still
    accepted (detected when ``weight`` is ``torch.long`` and ``input`` is
    floating point); a warning is emitted and the two arguments are swapped.

    Args:
        input: nested tensor of indices with nested dimension 2.
        weight: embedding matrix, forwarded to ``torch.embedding_bag``.
        offsets: must be None; it is computed from ``input``.
        max_norm: passed to ``_not_impl_raise`` (presumably unsupported when
            set -- helper is defined elsewhere, confirm).
        norm_type: unused by this implementation.
        scale_grad_by_freq: forwarded; rejected when ``mode == 'max'``.
        mode: one of ``'sum'``, ``'mean'`` or ``'max'``.
        sparse: forwarded; rejected when ``mode == 'max'``.
        per_sample_weights: passed to ``_not_impl_raise`` (presumably
            unsupported when set -- confirm).
        include_last_offset: forwarded to ``torch.embedding_bag``.
        padding_idx: must be None.

    Returns:
        The first output of ``torch.embedding_bag``.

    Raises:
        ValueError: if ``per_sample_weights`` does not match ``input``'s
            shape, if ``input`` is not 2-D, if ``offsets`` is given, if
            ``mode`` is invalid, or if ``mode == 'max'`` is combined with
            ``scale_grad_by_freq`` or ``sparse``.
        NotImplementedError: if ``per_sample_weights`` is used with a mode
            other than ``'sum'``, or if ``padding_idx`` is given.
    """
    # Check for backward compatibility.
    # Used to be embedding_bag(weight, input, ...)
    # Now is     embedding_bag(input, weight, ...)
    if weight.dtype == torch.long and input.is_floating_point():
        warnings.warn("Argument order of nn.functional.embedding_bag was changed. "
                      "Usage `embedding_bag(weight, input, ...)` is deprecated, "
                      "and should now be `embedding_bag(input, weight, ...)`.")
        weight, input = input, weight

    if per_sample_weights is not None and input.size() != per_sample_weights.size():
        raise ValueError("embedding_bag: If per_sample_weights ({}) is not None, "
                         "then it must have the same shape as the input ({})"
                         .format(per_sample_weights.shape, input.shape))

    # NOTE(review): _not_impl_raise is defined outside this function;
    # presumably it raises when the given option is not None.
    _not_impl_raise(max_norm, "max_norm")
    _not_impl_raise(per_sample_weights, "per_sample_weights")

    input_dim = torch.ops.nestedtensor.get_dim(input)
    # Reject every non-2-D input here: previously only dimension 1 was
    # rejected, and other dimensions reached torch.embedding_bag with
    # offsets=None and failed there with an unrelated error.
    if input_dim != 2:
        raise ValueError("input has to be 2D NestedTensor,"
                         " but got NestedTensor of dimension {}".format(input_dim))
    if offsets is not None:
        type_str = "<unknown>"
        # TODO: Remove this once script supports type() calls
        if not torch.jit.is_scripting():
            type_str = str(type(offsets))
        raise ValueError("if input is 2D, then offsets has to be None"
                         ", as input is treated as a mini-batch of"
                         " fixed length sequences. However, found "
                         "offsets of type {}".format(type_str))

    # One bag per nested entry: offsets[i] is the running sum of the lengths
    # of entries 0..i-1, i.e. where entry i starts in the concatenation of all
    # entries. The sum is accumulated in Python ints and materialized as a
    # single int64 tensor on the input's device, instead of one tensor element
    # write per entry. Indexed access (len + __getitem__) is used because
    # those are the only operations this code relies on from nested_size().
    # NOTE(review): with include_last_offset=True the last generated start
    # offset would be interpreted as an end offset -- confirm intended.
    nested_sizes = NestedTensor(input).nested_size()
    bag_starts = []
    running_total = 0
    for i in range(len(nested_sizes)):
        bag_starts.append(running_total)
        running_total += int(nested_sizes[i][0])
    offsets = torch.tensor(bag_starts, dtype=torch.int64, device=input.device)

    if mode == 'sum':
        mode_enum = 0
    elif mode == 'mean':
        mode_enum = 1
    elif mode == 'max':
        mode_enum = 2

        if scale_grad_by_freq:
            raise ValueError(
                "max mode does not support scaling the gradient by the frequency")

        if sparse:
            raise ValueError("max mode does not support sparse weights")

    else:
        raise ValueError("mode has to be one of sum, mean or max")

    if per_sample_weights is not None and mode != 'sum':
        raise NotImplementedError("embedding_bag: per_sample_weights was not None. "
                                  "per_sample_weights is only supported for mode='sum' "
                                  "(got mode='{}'). Please open a feature request on GitHub."
                                  .format(mode))
    if padding_idx is not None:
        raise NotImplementedError(
            "padding_idx is not supported for NestedTensor embedding_bag")

    # torch.embedding_bag returns a 4-tuple; only the bag results are exposed.
    ret, _, _, _ = torch.embedding_bag(
        weight,
        input,
        offsets,
        scale_grad_by_freq,
        mode_enum,
        sparse,
        per_sample_weights,
        include_last_offset)
    return ret