def logsumexp()

in multiple_futures_prediction/my_utils.py [0:0]


def logsumexp(inputs: torch.Tensor, dim: Optional[int] = None, keepdim: Optional[bool] = False) -> torch.Tensor:
  """Compute a numerically stable ``log(sum(exp(inputs)))`` along ``dim``.

  Args:
    inputs: tensor to reduce.
    dim: dimension to reduce over. If None, the tensor is flattened and the
      reduction runs over all of its elements.
    keepdim: if truthy, keep the reduced dimension with size 1. With
      ``dim=None`` the result then has shape ``(1,)``; otherwise the
      flattened reduction yields a 0-d tensor.

  Returns:
    Tensor holding the log-sum-exp values.
  """
  if dim is None:
    # reshape (not view) so that non-contiguous inputs, e.g. transposes,
    # can be flattened as well.
    inputs = inputs.reshape(-1)
    dim = 0
  # The hand-rolled max-shift trick produced NaN when the max was infinite
  # (-inf - -inf, inf - inf). torch.logsumexp zeroes infinite maxima before
  # shifting, so an all -inf slice gives -inf and a slice containing +inf
  # gives +inf. bool() keeps the old behaviour of treating keepdim=None as
  # False.
  return torch.logsumexp(inputs, dim=dim, keepdim=bool(keepdim))