pyro/ops/einsum/torch_map.py (36 lines of code) (raw):

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import operator
from functools import reduce

from pyro.ops import packed
from pyro.ops.einsum.adjoint import Backward, einsum_backward_sample, transpose, unflatten
from pyro.ops.einsum.util import Tensordot


class _EinsumBackward(Backward):
    """
    Backward hook attached to the result of :func:`einsum`.

    Holds the forward-pass operands together with the argmax recorded during
    the forward pass (``None`` when nothing was contracted).
    """

    def __init__(self, operands, argmax):
        self.operands = operands
        self.argmax = argmax

    def process(self, message):
        """
        Combine the stored argmax with the incoming ``message`` by delegating
        to ``einsum_backward_sample`` and return whatever it returns.
        """
        return einsum_backward_sample(self.operands, self.argmax, message)


def einsum(equation, *operands):
    """
    Forward-max-sum backward-argmax implementation of einsum.

    Operands are broadcast to a common set of dims and added together; every
    dim that does not appear in the output is then reduced with ``max``.
    This assumes all operands have a ``._pyro_dims`` attribute set.

    :param str equation: an einsum equation of the form ``inputs->output``.
    :param operands: the packed operands named by ``equation``.
    :return: the contracted result, with ``._pyro_dims`` set to the output
        dims and, if any operand has a ``_pyro_backward`` attribute, a
        ``._pyro_backward`` attribute set as well.
    """
    equation = packed.rename_equation(equation, *operands)
    # Only the output side is needed below; unpacking into exactly two names
    # still rejects an equation without a single '->'.
    _, output = equation.split('->')
    needs_backward = any(hasattr(operand, '_pyro_backward') for operand in operands)

    # Dims carried by some operand but absent from the output get contracted.
    seen_dims = set()
    for operand in operands:
        seen_dims.update(operand._pyro_dims)
    contract_dims = ''.join(sorted(seen_dims - set(output)))

    # Lay the result out as (output dims..., contracted dims...).
    broadcasted = packed.broadcast_all(*operands, dims=output + contract_dims)
    result = reduce(operator.add, broadcasted)

    # work around lack of pytorch support for zero-sized tensors
    argmax = None
    if not contract_dims:
        if result is operands[0]:
            # create a new object so the attributes set below do not land on
            # the caller's operand
            result = result[...]
    else:
        num_output_dims = len(output)
        full_shape = result.shape
        output_shape = full_shape[:num_output_dims]
        contract_shape = full_shape[num_output_dims:]
        # Flatten all contracted dims into one trailing dim and max over it.
        flattened = result.reshape(output_shape + (-1,))
        result, argmax = flattened.max(-1)
        if needs_backward:
            argmax = unflatten(argmax, output, contract_dims, contract_shape)

    result._pyro_dims = output
    assert result.dim() == len(result._pyro_dims)

    if needs_backward:
        result._pyro_backward = _EinsumBackward(operands, argmax)
    return result


tensordot = Tensordot(einsum)

__all__ = ["transpose", "einsum", "tensordot"]