in multiple_futures_prediction/my_utils.py [0:0]
def logsumexp(inputs: torch.Tensor, dim: Optional[int] = None, keepdim: Optional[bool] = False) -> torch.Tensor:
    """Compute ``log(sum(exp(inputs)))`` along ``dim`` in a numerically stable way.

    Uses the max-subtraction trick ``m + log(sum(exp(x - m)))``, where ``m`` is
    the maximum along ``dim``, so large inputs do not overflow ``exp``.

    Args:
        inputs: Tensor to reduce.
        dim: Dimension to reduce over. If ``None``, the tensor is flattened and
            reduced over all of its elements.
        keepdim: If truthy, keep the reduced dimension with size 1. When ``dim``
            is ``None`` this yields a 1-element, 1-D tensor; otherwise a 0-d
            tensor is returned for the full reduction.

    Returns:
        Tensor of log-sum-exp values.
    """
    if dim is None:
        # reshape (not view) so non-contiguous tensors can be flattened too.
        inputs = inputs.reshape(-1)
        dim = 0
    s, _ = torch.max(inputs, dim=dim, keepdim=True)
    # If the max is +/-inf, ``inputs - s`` would evaluate inf - inf = nan.
    # Shift by 0 in that case instead: an all -inf slice then yields -inf and a
    # slice containing +inf yields +inf, matching torch.logsumexp.
    s = s.masked_fill(s.abs() == float('inf'), 0.0)
    outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
    if not keepdim:
        outputs = outputs.squeeze(dim)
    return outputs