pyro/distributions/score_parts.py (8 lines of code) (raw):
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
from collections import namedtuple
from pyro.distributions.util import scale_and_mask
class ScoreParts(namedtuple('ScoreParts', ['log_prob', 'score_function', 'entropy_term'])):
    """
    A namedtuple holding the terms used by stochastic gradient estimators that
    combine the pathwise estimator and the score function estimator.

    Fields, in order: ``log_prob``, ``score_function``, ``entropy_term``.
    """

    def scale_and_mask(self, scale=1.0, mask=None):
        """
        Return a new :class:`ScoreParts` whose ``log_prob`` and ``entropy_term``
        terms are scaled by a data multiplicity factor and optionally masked.

        The ``score_function`` term is carried over unchanged; it should not be
        scaled or masked.

        :param scale: a positive scale
        :type scale: torch.Tensor or number
        :param mask: an optional masking tensor
        :type mask: torch.BoolTensor or None
        :returns: a freshly constructed :class:`ScoreParts`; ``self`` is not modified.
        """
        log_prob, score_function, entropy_term = self
        # Both adjusted terms go through pyro.distributions.util.scale_and_mask;
        # log_prob is processed first, then entropy_term.
        scaled_log_prob = scale_and_mask(log_prob, scale, mask)
        scaled_entropy_term = scale_and_mask(entropy_term, scale, mask)
        return ScoreParts(scaled_log_prob, score_function, scaled_entropy_term)