def _verify_batch_sizes()

in egg/core/reinforce_wrappers.py [0:0]


def _verify_batch_sizes(loss, sender_probs, receiver_probs):
    """Raise an exception if tensors are not appropriately sized.

    Checks that the per-sample ``loss`` and Sender's log-probabilities share the
    same first (batch) dimension. Also checks that Receiver's log-probabilities
    either share that batch dimension or are a single-element tensor equal to 0.0
    (which is how a deterministic Receiver is represented).

    :param loss: loss tensor returned by the user's loss function.
    :param sender_probs: log-probabilities of Sender's actions.
    :param receiver_probs: log-probabilities of Receiver's actions.
    :raises RuntimeError: if any of the size constraints above is violated.
    """
    loss_size, sender_size, receiver_size = (
        loss.size(),
        sender_probs.size(),
        receiver_probs.size(),
    )

    # Most likely you shouldn't have batch size 1, as Reinforce wouldn't work too well
    # but it is not incorrect either
    if loss.numel() == sender_probs.numel() == receiver_probs.numel() == 1:
        return

    # A 0-dim tensor (e.g. an already-aggregated loss) has no batch dimension to compare.
    is_ok = (
        loss.dim() > 0
        and sender_probs.dim() > 0
        and loss_size[0] == sender_size[0]
    )

    if not is_ok:
        raise RuntimeError(
            "Does your loss function return an aggregated loss? When training with Reinforce, "
            "the loss returned by your loss function must have the same batch (first) dimension as "
            "action log-probabilities returned by Sender. However, currently shapes are "
            f"{loss_size} and {sender_size}."
        )

    # As Receiver can be deterministic (and have constant zero log-probs for all its actions)
    # we allow them to be a scalar tensor. Otherwise the batch dimensions must match;
    # `dim() > 0` (rather than `numel() > 1`) guards the indexing so that a batch of
    # size 1 with shape (1,) and a non-zero log-prob is still accepted.
    is_receiver_ok = (receiver_probs.numel() == 1 and receiver_probs.item() == 0.0) or (
        receiver_probs.dim() > 0 and receiver_size[0] == loss_size[0]
    )
    if not is_receiver_ok:
        raise RuntimeError(
            "The log-probabilities returned by Receiver must have either the same first dimension "
            "as the loss or be a scalar tensor with value 0.0. "
            f"Current shapes are {receiver_size} and {loss_size}."
        )