in easy_rec/python/layers/uniter.py [0:0]
def __init__(self, model_config, feature_configs, features, uniter_config,
             input_layer):
    """Gather the input feature groups and validate the Uniter settings.

    Args:
      model_config: model config; its `feature_groups` are scanned for the
        `general`, `image` and `text` groups.
      feature_configs: iterable of feature configs, used to look up the
        embedding / raw input dimension and `max_seq_len` of each feature.
      features: input features, forwarded unchanged to `input_layer`.
      uniter_config: Uniter specific config, stored as `self._model_config`;
        `use_position_embeddings` and `max_position_embeddings` are read.
      input_layer: callable that also exposes `has_group(name)`; it is
        invoked once for every supported feature group that is present.

    Raises:
      AssertionError: if none of the supported feature groups exists, a
        group lists a feature twice, the dimensions collected for a group
        are inconsistent or incomplete, or the position embedding settings
        are invalid.
    """
    self._model_config = uniter_config

    # Number of supported feature groups found in the input layer.
    group_count = 0

    self._img_features = None
    if input_layer.has_group('image'):
        image_inputs, _ = input_layer(features, 'image')
        self._img_features = image_inputs
        group_count += 1

    self._general_features = None
    if input_layer.has_group('general'):
        general_inputs, _ = input_layer(features, 'general')
        self._general_features = general_inputs
        group_count += 1

    self._txt_seq_features = None
    if input_layer.has_group('text'):
        text_inputs, _, _ = input_layer(features, 'text', is_combine=False)
        self._txt_seq_features = text_inputs
        group_count += 1

    # Decided before the `other` group is counted, so only the image,
    # general and text groups take part in this first decision.
    self._use_token_type = group_count > 1

    self._other_features = None
    if input_layer.has_group('other'):  # e.g. statistical feature
        other_inputs, _ = input_layer(features, 'other')
        self._other_features = other_inputs
        group_count += 1

    assert group_count > 0, 'there must be one of the feature groups: [image, text, general, other]'

    # Record the size and the member names of every group of interest.
    general_names, image_names, text_names = set(), set(), set()
    self._general_feature_num = 0
    self._txt_feature_num, self._img_feature_num = 0, 0
    for group in model_config.feature_groups:
        group_name = group.group_name
        if group_name == 'general':
            general_names = set(group.feature_names)
            self._general_feature_num = len(group.feature_names)
            assert len(general_names) == self._general_feature_num, (
                'there are duplicate features in `general` feature group')
        elif group_name == 'image':
            image_names = set(group.feature_names)
            self._img_feature_num = len(group.feature_names)
            assert len(image_names) == self._img_feature_num, (
                'there are duplicate features in `image` feature group')
        elif group_name == 'text':
            text_names = set(group.feature_names)
            self._txt_feature_num = len(group.feature_names)
            assert len(text_names) == self._txt_feature_num, (
                'there are duplicate features in `text` feature group')

    # Several text or several image features also switch token types on.
    if self._txt_feature_num > 1 or self._img_feature_num > 1:
        self._use_token_type = True

    # One token type per text feature, plus one for the image group and one
    # for the general group when they are not empty.
    self._token_type_vocab_size = (
        self._txt_feature_num + int(self._img_feature_num > 0) +
        int(self._general_feature_num > 0))

    # Collect the dimensions of the features belonging to each group and the
    # longest `max_seq_len` declared by a text feature.
    longest_seq_len = 0
    image_dims, general_dims, text_dims = [], [], []
    for fea_conf in feature_configs:
        default_name = fea_conf.input_names[0]
        name = (
            fea_conf.feature_name
            if fea_conf.HasField('feature_name') else default_name)
        if name in image_names:
            image_dims.append(fea_conf.raw_input_dim)
        if name in general_names:
            general_dims.append(fea_conf.embedding_dim)
        if name not in text_names:
            continue
        text_dims.append(fea_conf.embedding_dim)
        if not fea_conf.HasField('max_seq_len'):
            continue
        assert fea_conf.max_seq_len > 0, (
            'feature config `max_seq_len` must be greater than 0 for feature: '
            + name)
        longest_seq_len = max(longest_seq_len, fea_conf.max_seq_len)

    # Every group must have exactly one dimension per listed feature, and
    # all of them must be equal.
    assert (len(set(text_dims)) <= 1 and
            len(text_dims) == self._txt_feature_num), (
                'Uniter requires that all `text` feature dimensions must be consistent.'
            )
    assert (len(set(image_dims)) <= 1 and
            len(image_dims) == self._img_feature_num), (
                'Uniter requires that all `image` feature dimensions must be consistent.'
            )
    assert (len(set(general_dims)) <= 1 and
            len(general_dims) == self._general_feature_num), (
                'Uniter requires that all `general` feature dimensions must be consistent.'
            )

    # Position embeddings have to cover the longest text sequence.
    if self._txt_feature_num > 0 and uniter_config.use_position_embeddings:
        max_positions = uniter_config.max_position_embeddings
        assert max_positions > 0, (
            'model config `max_position_embeddings` must be greater than 0. ')
        assert max_positions >= longest_seq_len, (
            'model config `max_position_embeddings` must be greater than or equal to the maximum of all feature config '
            '`max_seq_len`, which is %d' % longest_seq_len)

    # A group without any collected dimension gets size 0.
    self._img_emb_size = next(iter(image_dims), 0)
    self._txt_emb_size = next(iter(text_dims), 0)
    self._general_emb_size = next(iter(general_dims), 0)

    if self._img_features is not None:
        assert self._img_emb_size > 0, '`image` feature dimensions must be greater than 0, set by `raw_input_dim`'