in egg/core/interaction.py [0:0]
def from_iterable(interactions: Iterable["Interaction"]) -> "Interaction":
    """Concatenate a collection of ``Interaction`` logs along dimension 0.

    Every tensor field (``sender_input``, ``receiver_input``, ``labels``,
    ``message``, ``message_length``, ``receiver_output``) and every entry of
    ``aux`` / ``aux_input`` is joined with ``torch.cat(..., dim=0)``. A field
    that is ``None`` in all interactions stays ``None`` in the result.

    Args:
        interactions: the interactions to merge. Any iterable is accepted,
            including a generator; it is consumed exactly once.

    Returns:
        A new ``Interaction`` holding the concatenated fields.

    Raises:
        AssertionError: if ``interactions`` is empty, or the interactions
            disagree on the keys of ``aux`` / ``aux_input``.
        RuntimeError: if a field is ``None`` in some interactions but not in
            others.

    >>> a = Interaction(torch.ones(1), None, None, {}, torch.ones(1), torch.ones(1), None, {})
    >>> a.size
    1
    >>> b = Interaction(torch.ones(1), None, None, {}, torch.ones(1), torch.ones(1), None, {})
    >>> c = Interaction.from_iterable((a, b))
    >>> c.size
    2
    >>> c
    Interaction(sender_input=tensor([1., 1.]), ..., receiver_output=tensor([1., 1.]), message_length=None, aux={})
    >>> d = Interaction(torch.ones(1), torch.ones(1), None, {}, torch.ones(1), torch.ones(1), None, {})
    >>> _ = Interaction.from_iterable((a, d)) # misshaped, should throw an exception
    Traceback (most recent call last):
    ...
    RuntimeError: Appending empty and non-empty interactions logs. Normally this shouldn't happen!
    """

    def _check_cat(lst):
        # A field that was never logged in any interaction stays absent.
        if all(x is None for x in lst):
            return None
        # If some but not all are None, the logs are inconsistent.
        if any(x is None for x in lst):
            raise RuntimeError(
                "Appending empty and non-empty interactions logs. "
                "Normally this shouldn't happen!"
            )
        return torch.cat(lst, dim=0)

    # Materialize once. The code below indexes the first element and walks
    # the collection several times, which a generator/iterator cannot support.
    interactions = list(interactions)
    assert interactions, "interaction list must not be empty"

    first = interactions[0]
    has_aux_input = first.aux_input is not None
    for x in interactions:
        # Compare key sets, not just sizes. Same-sized dicts with different
        # keys would otherwise fail later with an opaque KeyError.
        assert set(x.aux) == set(
            first.aux
        ), "found two interactions with different aux keys"
        if has_aux_input:
            assert x.aux_input is not None and set(x.aux_input) == set(
                first.aux_input
            ), "found two interactions with different aux_input keys"
        else:
            assert (
                not x.aux_input
            ), "some aux_input are defined some are not, this should not happen"

    aux_input = None
    if has_aux_input:
        aux_input = {
            k: _check_cat([x.aux_input[k] for x in interactions])
            for k in first.aux_input
        }
    aux = {k: _check_cat([x.aux[k] for x in interactions]) for k in first.aux}

    return Interaction(
        sender_input=_check_cat([x.sender_input for x in interactions]),
        receiver_input=_check_cat([x.receiver_input for x in interactions]),
        labels=_check_cat([x.labels for x in interactions]),
        aux_input=aux_input,
        message=_check_cat([x.message for x in interactions]),
        message_length=_check_cat([x.message_length for x in interactions]),
        receiver_output=_check_cat([x.receiver_output for x in interactions]),
        aux=aux,
    )