def _parameter_control_dependencies()

in tensorflow_probability/python/distributions/batch_concat.py

Validates the parameters of `BatchConcat`: the concatenation axis and the
event and batch shapes of the component distributions. Checks that can be
resolved statically raise a `ValueError` at construction time; everything
else is returned as runtime assertions, which are only attached when
`validate_args=True`.

  def _parameter_control_dependencies(self, is_init):
    assertions = []

    if is_init:
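      # Prefer static checks at construction time: raise immediately when
      # the offending value is known statically.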
      axis_ = tf.get_static_value(self._axis)
      if axis_ is not None and axis_ < 0:
        raise ValueError('Axis should be non-negative, %d was given' % axis_)
      if axis_ is None:
        # The axis is not known statically (`axis_` itself is None here);
        # defer the check to a runtime assertion on the underlying tensor.
        assertions.append(assert_util.assert_greater_equal(self._axis, 0))

      all_event_shapes = [d.event_shape for d in self._distributions]
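      # When every component's event shape is statically known, compare
      # them eagerly and fail fast on a mismatch.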
      if all(tensorshape_util.is_fully_defined(event_shape)
             for event_shape in all_event_shapes):
        if all_event_shapes[1:] != all_event_shapes[:-1]:
          raise ValueError('Distributions must have the same `event_shape`; '
                           'found: {}'.format(all_event_shapes))

      all_batch_shapes = [d.batch_shape for d in self._distributions]
      # The static comparison below indexes into Python lists, so it also
      # requires the axis itself to be statically known.
      if axis_ is not None and all(
          tensorshape_util.is_fully_defined(batch_shape)
          for batch_shape in all_batch_shapes):
        batch_shape = all_batch_shapes[0].as_list()
        batch_shape[axis_] = 1
        for b in all_batch_shapes[1:]:
          b = b.as_list()
          if len(batch_shape) != len(b):
            raise ValueError('Incompatible batch shape %s with %s' %
                             (batch_shape, b))
          b[axis_] = 1
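          # `broadcast_static_shape` raises a ValueError on incompatible
          # shapes; its return value is intentionally unused.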
          tf.broadcast_static_shape(
              tensorshape_util.constant_value_as_shape(batch_shape),
              tensorshape_util.constant_value_as_shape(b))

    if not self.validate_args:
      return []

    # From here on `validate_args` is True; collect runtime assertions for
    # whatever could not be verified statically during `is_init`.

    # Validate that event shapes all match.
    all_event_shapes = [d.event_shape for d in self._distributions]
    if not all(tensorshape_util.is_fully_defined(event_shape)
               for event_shape in all_event_shapes):
      all_event_shape_tensors = [d.event_shape_tensor() for
                                 d in self._distributions]

      def _get_shapes(static_shape, dynamic_shape):
        # Prefer the static shape when it is fully defined; otherwise fall
        # back to the dynamic shape tensor.
        if tensorshape_util.is_fully_defined(static_shape):
          return static_shape
        else:
          return dynamic_shape

      event_shapes = tf.nest.map_structure(_get_shapes,
                                           all_event_shapes,
                                           all_event_shape_tensors)
      event_shapes = tf.nest.flatten(event_shapes)
      assertions.extend(
          assert_util.assert_equal(
              e1, e2, message='Distributions should have same event shapes.')
          for e1, e2 in zip(event_shapes[1:], event_shapes[:-1]))

    # Validate that batch shapes are broadcastable and concatenable along
    # the specified axis.
    if not all(tensorshape_util.is_fully_defined(d.batch_shape)
               for d in self._distributions):
      for i, d in enumerate(self._distributions[:-1]):
        assertions.append(assert_util.assert_equal(
            tf.size(d.batch_shape_tensor()),
            tf.size(self._distributions[i+1].batch_shape_tensor()),
            message='Distributions must have equal batch ranks.'))

      # Mask out the concat-axis entry (set it to 1) so that the remaining
      # dimensions can be compared by broadcasting.
      batch_shape_tensors = [
          ps.tensor_scatter_nd_update(d.batch_shape_tensor(), updates=[1],
                                      indices=[[self._axis]])
          for d in self._distributions
      ]
      # `broadcast_dynamic_shape` fails at run time on incompatible shapes;
      # the reduction must be seeded with the first shape tensor, not with
      # the list slice itself.
      assertions.append(
          functools.reduce(tf.broadcast_dynamic_shape,
                           batch_shape_tensors[1:],
                           batch_shape_tensors[0]))
    return assertions
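
The dynamic batch-shape check above leans on two TensorFlow primitives:
`tensor_scatter_nd_update` overwrites the concat-axis entry with 1, and
`broadcast_dynamic_shape` raises when two shapes cannot be broadcast
together. A minimal standalone sketch of that logic follows; the shapes are
hypothetical and not taken from the source.

import functools

import tensorflow as tf

axis = 0
# Hypothetical batch shapes of two component distributions.
batch_shape_tensors = [tf.constant([2, 3]), tf.constant([4, 3])]

# Mask out the concat axis: [2, 3] -> [1, 3] and [4, 3] -> [1, 3].
masked = [
    tf.tensor_scatter_nd_update(s, indices=[[axis]], updates=[1])
    for s in batch_shape_tensors
]

# Fold with `broadcast_dynamic_shape`, seeding the reduction with the first
# masked shape; incompatible shapes raise an error at run time.
common = functools.reduce(tf.broadcast_dynamic_shape, masked[1:], masked[0])
print(common.numpy())  # [1 3]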
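
For context, here is a usage sketch of the distribution this method belongs
to. It assumes the class is exported as `tfp.distributions.BatchConcat` with
`distributions`, `axis`, and `validate_args` arguments; because
`_parameter_control_dependencies(is_init=True)` runs during construction, a
statically negative axis should fail immediately.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Batch shapes [2, 3] and [4, 3] concatenate along axis 0 into (6, 3).
concat = tfd.BatchConcat(
    distributions=[tfd.Normal(loc=tf.zeros([2, 3]), scale=1.),
                   tfd.Normal(loc=tf.zeros([4, 3]), scale=1.)],
    axis=0,
    validate_args=True)
print(concat.batch_shape)  # (6, 3)

# A statically known negative axis trips the `is_init` check above.
try:
  tfd.BatchConcat(
      distributions=[tfd.Normal(loc=tf.zeros([2, 3]), scale=1.),
                     tfd.Normal(loc=tf.zeros([4, 3]), scale=1.)],
      axis=-1,
      validate_args=True)
except ValueError as e:
  print('Rejected:', e)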