def concat_entity_masks()

in ma_policy/layers.py [0:0]


def concat_entity_masks(inps, masks):
    '''
        Concats masks together. If mask is None, then it creates
            a tensor of 1's with shape (BS, T, NE).
        Args:
            inps (list of tensors): tensors that masks apply to
            masks (list of tensors): corresponding masks (entries may be None)
        Returns:
            A single tensor: the per-input masks concatenated along the last axis,
            in the same order as `inps`.
        Raises:
            ValueError: if a mask is None and its input is neither rank 3 nor rank 4,
                since no default all-ones mask can be built for it.
    '''
    assert len(inps) == len(masks), "There should be the same number of inputs as masks"
    with tf.variable_scope('concat_masks'):
        new_masks = []
        for inp, mask in zip(inps, masks):
            if mask is not None:
                new_masks.append(mask)
                continue
            inp_shape = shape_list(inp)
            # NOTE(review): tf.ones defaults to float32; tf.concat requires the
            # provided masks to share that dtype -- confirm callers pass float masks.
            if len(inp_shape) == 4:
                # Entity tensor, assumed (BS, T, NE, features): one mask entry per entity.
                new_masks.append(tf.ones(inp_shape[:3]))
            elif len(inp_shape) == 3:
                # Pooled or main tensor: treat it as a single entity, so NE (outer dimension) is 1.
                new_masks.append(tf.ones(inp_shape[:2] + [1]))
            else:
                # Skipping this input would leave the concatenated mask misaligned
                # with the inputs, so fail loudly instead.
                raise ValueError(
                    "Cannot build a default mask for an input of rank {} "
                    "(expected rank 3 or 4)".format(len(inp_shape)))
        new_mask = tf.concat(new_masks, -1)
    return new_mask