in ma_policy/layers.py [0:0]
def entity_concat(inps):
    '''
    Concat 4D tensors along the third dimension (the entity axis, -2).

    Inputs that are already 4D are used as-is. Any other input (expected to
    be a 3D tensor) is treated as a single entity: a size-1 entity dimension
    is inserted at axis 2 before concatenation.

    Args:
        inps (list of tensors): tensors to concatenate. After the entity
            dimension is added, all of them must share the same last
            dimension.
    Returns:
        Tensor: the inputs concatenated along axis -2.
    Raises:
        ValueError: if `inps` is empty.
        AssertionError: if the last dimensions of the inputs differ.
    '''
    if not inps:
        raise ValueError("entity_concat requires at least one input tensor")
    with tf.variable_scope('concat_entities'):
        shapes = [shape_list(_x) for _x in inps]
        # For inputs that don't have an entity dimension, add one of size 1.
        inps = [_x if len(_shape) == 4 else tf.expand_dims(_x, 2)
                for _x, _shape in zip(inps, shapes)]
        shapes = [shape_list(_x) for _x in inps]
        # Only the last (feature) dimension is compared here; the other
        # dimensions are left for tf.concat itself to validate.
        assert all(_shape[-1] == shapes[0][-1] for _shape in shapes), \
            f"Some entities don't have the same inner dimension {shapes}"
        # Concatenate along entity dimension
        out = tf.concat(inps, -2)
    return out