in tensorflow_probability/python/distributions/batch_concat.py [0:0]
def _parameter_control_dependencies(self, is_init):
  """Validates `axis` and the component distributions' shapes.

  Static checks run when `is_init` is `True` and raise immediately. Runtime
  checks are only built when `self.validate_args` is `True`, and only for the
  quantities whose values are not known statically.

  Args:
    is_init: Python `bool`; whether this is being called from the
      constructor.

  Returns:
    assertions: Python `list` of ops/`Tensor`s to use as control
      dependencies. Always empty when `self.validate_args` is `False`.

  Raises:
    ValueError: if `axis` is statically negative, if the fully-defined event
      shapes are not all equal, or if the fully-defined batch shapes differ
      in rank or are not broadcastable once the `axis` dimension is ignored.
  """
  assertions = []
  axis_ = tf.get_static_value(self._axis)
  all_event_shapes = [d.event_shape for d in self._distributions]
  event_shapes_are_static = all(
      tensorshape_util.is_fully_defined(s) for s in all_event_shapes)
  all_batch_shapes = [d.batch_shape for d in self._distributions]
  batch_shapes_are_static = all(
      tensorshape_util.is_fully_defined(s) for s in all_batch_shapes)

  if is_init:
    if axis_ is not None and axis_ < 0:
      raise ValueError('Axis should be positive, %d was given' % axis_)
    if axis_ is None:
      # `axis` is only known at runtime: assert on the tensor itself (its
      # static value is `None`, which cannot be converted to a tensor).
      assertions.append(tf.assert_greater_equal(
          self._axis, 0, message='Axis should be non-negative.'))
    if (event_shapes_are_static and
        all_event_shapes[1:] != all_event_shapes[:-1]):
      raise ValueError('Distributions must have the same `event_shape`; '
                       'found: {}'.format(all_event_shapes))
    # Indexing a Python list requires a concrete axis; when `axis` is dynamic
    # the runtime batch-shape check below covers this instead.
    if axis_ is not None and batch_shapes_are_static:
      reference = all_batch_shapes[0].as_list()
      reference[axis_] = 1
      for other in all_batch_shapes[1:]:
        other = other.as_list()
        if len(reference) != len(other):
          raise ValueError('Incompatible batch shape % s with %s' %
                           (reference, other))
        other[axis_] = 1
        # Called for its check only: raises if the shapes (with the
        # concatenation axis neutralized) are not broadcastable.
        tf.broadcast_static_shape(
            tensorshape_util.constant_value_as_shape(reference),
            tensorshape_util.constant_value_as_shape(other))

  if not self.validate_args:
    return []

  if not event_shapes_are_static:
    # Prefer the static shape where it is fully known, else the dynamic one.
    event_shapes = [
        s if tensorshape_util.is_fully_defined(s) else d.event_shape_tensor()
        for s, d in zip(all_event_shapes, self._distributions)]
    # Equality is transitive, so comparing neighbours is sufficient.
    assertions.extend(
        assert_util.assert_equal(
            e1, e2, message='Distributions should have same event shapes.')
        for e1, e2 in zip(event_shapes[1:], event_shapes[:-1]))

  # Validate that batch shapes have equal rank and are broadcastable outside
  # of the concatenation axis. Also needed when `axis` is dynamic, since the
  # static check above is skipped in that case.
  if axis_ is None or not batch_shapes_are_static:
    batch_shape_tensors = [d.batch_shape_tensor() for d in self._distributions]
    for s1, s2 in zip(batch_shape_tensors[:-1], batch_shape_tensors[1:]):
      assertions.append(assert_util.assert_equal(
          tf.size(s1), tf.size(s2),
          message='Distributions should have batch shapes of the same rank.'))
    # Set the concatenation axis to 1 so only the other dims must broadcast.
    masked_shapes = [
        ps.tensor_scatter_nd_update(s, updates=1, indices=[self._axis])
        for s in batch_shape_tensors
    ]
    # `tf.broadcast_dynamic_shape` fails on incompatible shapes, so adding
    # its result as a control dependency enforces broadcastability.
    assertions.append(
        functools.reduce(tf.broadcast_dynamic_shape, masked_shapes))
  return assertions