def entity_concat()

in ma_policy/layers.py [0:0]


def entity_concat(inps):
    '''
        Join a list of tensors along the second-to-last axis (the entity axis
            of a 4D tensor, i.e. its third dimension).
        Any input whose shape does not have exactly 4 entries gets a new axis
            inserted at index 2 first, so it counts as a single entity.
            (Presumably such inputs are 3D; other ranks are not checked here.)
        Args:
            inps (list of tensors): tensors to concatenate
        Returns:
            the result of tf.concat over the (possibly expanded) inputs on axis -2
        Raises:
            AssertionError: if an input's last dimension differs from that of
                the first input. Only the last dimension is compared.
    '''
    with tf.variable_scope('concat_entities'):
        # Look up every input's rank up front, before any tensor is modified.
        ranks = [len(shape_list(tensor)) for tensor in inps]

        # Give rank != 4 inputs an explicit entity axis at position 2.
        expanded = []
        for tensor, rank in zip(inps, ranks):
            if rank != 4:
                tensor = tf.expand_dims(tensor, 2)
            expanded.append(tensor)

        # Re-read the shapes now that every input carries an entity axis.
        shapes = [shape_list(tensor) for tensor in expanded]

        # The reference is shapes[0]; it is indexed inside the comprehension so
        # that an empty input list does not fail at this point.
        last_dim_matches = [dims[-1] == shapes[0][-1] for dims in shapes]
        assert np.all(last_dim_matches),\
            f"Some entities don't have the same outer or inner dimensions {shapes}"

        # Stack all entities together along the entity axis.
        out = tf.concat(expanded, -2)
    return out