in ma_policy/layers.py [0:0]
def concat_entity_masks(inps, masks):
    '''
        Concats masks together along the last axis. If a mask is None, then it
            creates a tensor of 1's in its place: shape (BS, T, NE) for a rank-4
            input, or (BS, T, 1) for a rank-3 input.
        Args:
            inps (list of tensors): tensors that masks apply to
            masks (list of tensors): corresponding masks; an entry may be None
        Returns:
            The result of tf.concat over all (given or generated) masks on axis -1.
        Raises:
            AssertionError: if inps and masks have different lengths.
            ValueError: if a mask is None and its input is neither rank 3 nor
                rank 4, since no default mask can be built for it.
    '''
    assert len(inps) == len(masks), "There should be the same number of inputs as masks"
    with tf.variable_scope('concat_masks'):
        new_masks = []
        for idx, (inp, mask) in enumerate(zip(inps, masks)):
            if mask is not None:
                new_masks.append(mask)
                continue

            # No mask supplied: build an all-ones mask from the input's shape.
            # list(...) so the slicing/concatenation below works whether
            # shape_list returns a list or a tuple.
            inp_shape = list(shape_list(inp))
            rank = len(inp_shape)
            if rank == 4:  # this is an entity tensor
                new_masks.append(tf.ones(inp_shape[:3]))
            elif rank == 3:  # this is a pooled or main tensor. Set NE (outer dimension) to 1
                new_masks.append(tf.ones(inp_shape[:2] + [1]))
            else:
                # Skipping the input here would leave the concatenated mask
                # covering fewer entities than the inputs, so fail loudly instead.
                raise ValueError(
                    "Cannot build a default mask for input {} with rank {}; "
                    "expected rank 3 or 4".format(idx, rank))
        new_mask = tf.concat(new_masks, -1)
    return new_mask