in lib/action_head.py [0:0]
def forward(self, input_data: torch.Tensor, **kwargs) -> Any:
    """
    Run every sub-head on ``input_data``, routing per-head keyword arguments.

    :param input_data: passed unchanged as the first positional argument to
        every sub-head.
    :param kwargs: each kwarg should be ``None`` or a dict whose keys are
        sub-head names (keys of ``self``). A sub-head receives a kwarg only
        if that kwarg's dict contains the sub-head's name. ``None`` kwargs
        are ignored for every sub-head.
        e.g. if this ModuleDict has submodules keyed by 'A', 'B', and 'C',
        we could call:
            forward(input_data, foo={'A': True, 'C': False}, bar={'A': 7})
        Then children will be called with:
            A: forward(input_data, foo=True, bar=7)
            B: forward(input_data)
            C: forward(input_data, foo=False)
    :return: dict mapping each sub-head name to that sub-head's output, in
        the iteration order of ``self.items()``.
    """
    outputs = {}
    for name, head in self.items():
        routed = {}
        for arg_name, per_head in kwargs.items():
            # Skip kwargs that are absent (None) or not addressed to this head.
            if per_head is None or name not in per_head:
                continue
            routed[arg_name] = per_head[name]
        outputs[name] = head(input_data, **routed)
    return outputs