in egg/core/reinforce_wrappers.py [0:0]
def _verify_batch_sizes(loss, sender_probs, receiver_probs):
    """Raise an exception if the tensors are not appropriately sized.

    The first (batch) dimension of ``loss`` must equal the first dimension of
    ``sender_probs``. ``receiver_probs`` must either share that first
    dimension or be a single-element tensor whose value is 0.0.

    Args:
        loss: loss tensor returned by the loss function.
        sender_probs: action log-probabilities returned by Sender.
        receiver_probs: action log-probabilities returned by Receiver.

    Raises:
        RuntimeError: if ``loss`` and ``sender_probs`` disagree on the first
            dimension (or either has no dimensions at all), or if
            ``receiver_probs`` satisfies neither of the accepted layouts.
    """
    loss_size, sender_size, receiver_size = (
        loss.size(),
        sender_probs.size(),
        receiver_probs.size(),
    )

    # Most likely you shouldn't have batch size 1, as Reinforce wouldn't work too well
    # but it is not incorrect either
    if loss.numel() == sender_probs.numel() == receiver_probs.numel() == 1:
        return

    # An empty size means a 0-dim tensor: there is no batch dimension to
    # compare, which is what an aggregated loss looks like.
    is_ok = loss_size and sender_size and loss_size[0] == sender_size[0]
    if not is_ok:
        raise RuntimeError(
            "Does your loss function returns aggregateed loss? When training with Reinforce, "
            "the loss returned by your loss function must have the same batch (first) dimension as "
            "action log-probabilities returned by Sender. However, currently shapes are "
            f"{loss_size} and {sender_size}."
        )

    # A matching first dimension is acceptable for any non-empty tensor,
    # including a single-element one (batch size 1 with non-zero log-probs).
    # The length check keeps 0-dim tensors from being indexed.
    receiver_matches_batch = (
        receiver_probs.numel() > 0
        and len(receiver_size) > 0
        and receiver_size[0] == loss_size[0]
    )
    # As Receiver can be deterministic (and have constant zero log-probs for all its actions)
    # we allow them to be a scalar tensor
    is_receiver_ok = receiver_matches_batch or (
        receiver_probs.numel() == 1 and receiver_probs.item() == 0.0
    )
    if not is_receiver_ok:
        raise RuntimeError(
            "The log-probabilites returned by Receiver must have either the same first dimenstion "
            "as the loss or be a scalar tensor with value 0.0. "
            f"Current shapes are {receiver_size} and {loss_size}."
        )