in easy_rec/python/layers/cmbf.py [0:0]
def __init__(self, model_config, feature_configs, features, cmbf_config,
input_layer):
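# Set up the CMBF layer inputs: fetch the optional feature groups (image /
# general / text / other) from the input layer, validate them against
# feature_configs, and cache the attention hyper-parameters from cmbf_config.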
self._model_config = cmbf_config
has_feature = False
self._img_features = None
if input_layer.has_group('image'):
self._img_features, _ = input_layer(features, 'image')
has_feature = True
self._general_features = None
if input_layer.has_group('general'):
self._general_features, _ = input_layer(features, 'general')
has_feature = True
self._txt_seq_features = None
if input_layer.has_group('text'):
self._txt_seq_features, _, _ = input_layer(
features, 'text', is_combine=False)
has_feature = True
self._other_features = None
if input_layer.has_group('other'): # e.g. statistical feature
self._other_features, _ = input_layer(features, 'other')
has_feature = True
assert has_feature, 'there must be at least one of the feature groups: [image, text, general, other]'
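# Count the features declared in each group and remember their names so that
# the matching feature configs can be looked up below.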
self._general_feature_num, self._img_feature_num = 0, 0
self._txt_feature_num = 0
general_feature_names, txt_seq_feature_names = set(), set()
img_feature_names = set()
for fea_group in model_config.feature_groups:
if fea_group.group_name == 'general':
self._general_feature_num = len(fea_group.feature_names)
general_feature_names = set(fea_group.feature_names)
assert self._general_feature_num == len(general_feature_names), (
'there are duplicate features in `general` feature group')
elif fea_group.group_name == 'image':
self._img_feature_num = len(fea_group.feature_names)
img_feature_names = set(fea_group.feature_names)
assert self._img_feature_num == len(img_feature_names), (
'there are duplicate features in `image` feature group')
elif fea_group.group_name == 'text':
txt_seq_feature_names = set(fea_group.feature_names)
self._txt_feature_num = len(fea_group.feature_names)
assert self._txt_feature_num == len(txt_seq_feature_names), (
'there are duplicate features in `text` feature group')
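# Collect the embedding dimension of every grouped feature (raw_input_dim for
# image features, embedding_dim otherwise) and track the largest configured
# max_seq_len, which is validated against max_position_embeddings below.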
max_seq_len = 0
txt_fea_emb_dim_list = []
general_emb_dim_list = []
img_fea_emb_dim_list = []
for feature_config in feature_configs:
fea_name = feature_config.input_names[0]
if feature_config.HasField('feature_name'):
fea_name = feature_config.feature_name
if fea_name in img_feature_names:
img_fea_emb_dim_list.append(feature_config.raw_input_dim)
if fea_name in general_feature_names:
general_emb_dim_list.append(feature_config.embedding_dim)
if fea_name in txt_seq_feature_names:
txt_fea_emb_dim_list.append(feature_config.embedding_dim)
if feature_config.HasField('max_seq_len'):
assert feature_config.max_seq_len > 0, (
'feature config `max_seq_len` must be greater than 0 for feature: '
+ fea_name)
if feature_config.max_seq_len > max_seq_len:
max_seq_len = feature_config.max_seq_len
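# Per-modality consistency checks: every feature in a group must have a
# feature config, and all features within one group must share one dimension.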
unique_dim_num = len(set(txt_fea_emb_dim_list))
assert unique_dim_num <= 1 and (
len(txt_fea_emb_dim_list) == self._txt_feature_num), (
'CMBF requires every `text` feature to have a feature config and all of them to share one embedding_dim')
unique_dim_num = len(set(general_emb_dim_list))
assert unique_dim_num <= 1 and (
len(general_emb_dim_list) == self._general_feature_num), (
'CMBF requires every `general` feature to have a feature config and all of them to share one embedding_dim')
unique_dim_num = len(set(img_fea_emb_dim_list))
assert unique_dim_num <= 1 and (
len(img_fea_emb_dim_list) == self._img_feature_num), (
'CMBF requires every `image` feature to have a feature config and all of them to share one raw_input_dim')
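# When position embeddings are enabled (the default), max_position_embeddings
# must be positive and cover the longest configured max_seq_len.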
if cmbf_config.use_position_embeddings:
assert cmbf_config.max_position_embeddings > 0, (
'model config `max_position_embeddings` must be greater than 0. '
'It must be set when `use_position_embeddings` is true (default)')
assert cmbf_config.max_position_embeddings >= max_seq_len, (
'model config `max_position_embeddings` must be greater than or equal to the maximum of all feature config '
'`max_seq_len`, which is %d' % max_seq_len)
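# Per-modality embedding sizes; 0 means the modality is not configured.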
self._img_emb_size = img_fea_emb_dim_list[0] if img_fea_emb_dim_list else 0
self._txt_emb_size = txt_fea_emb_dim_list[0] if txt_fea_emb_dim_list else 0
self._general_emb_size = (
general_emb_dim_list[0] if general_emb_dim_list else 0)
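# Attention hyper-parameters for the image / text self-attention blocks and
# the cross-modal fusion layers.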
self._head_num = cmbf_config.multi_head_num
self._img_head_num = cmbf_config.image_multi_head_num
self._txt_head_num = cmbf_config.text_multi_head_num
self._txt_head_size = cmbf_config.text_head_size
self._img_head_size = cmbf_config.image_head_size
self._img_patch_num = cmbf_config.image_feature_patch_num
self._img_self_attention_layer_num = cmbf_config.image_self_attention_layer_num
self._txt_self_attention_layer_num = cmbf_config.text_self_attention_layer_num
self._cross_modal_layer_num = cmbf_config.cross_modal_layer_num
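# Log the resolved feature counts and embedding sizes for debugging.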
print('general_feature_num: {0}, img_feature_num: {1}, txt_seq_feature_num: {2}'
.format(self._general_feature_num, self._img_feature_num,
len(self._txt_seq_features) if self._txt_seq_features else 0))
print('txt_embedding_size: {0}, img_embedding_size: {1}'.format(
self._txt_emb_size, self._img_emb_size))
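# Image features are fed as raw embedding vectors, so their dimension has to
# come from raw_input_dim in the feature config.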
if self._img_features is not None:
assert self._img_emb_size > 0, '`image` feature dimensions must be greater than 0, set by `raw_input_dim`'
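# Illustrative sketch only (hypothetical helper, not part of easy_rec): the
# per-group checks above boil down to one rule -- every feature in a group
# needs a matching feature config and all features in a group must share a
# single embedding dimension.
def _check_group_dims(group_name, dim_list, expected_feature_num):
  # dim_list holds the dimension collected for each feature of the group.
  assert len(set(dim_list)) <= 1, (
      'all `%s` features must share one embedding dimension' % group_name)
  assert len(dim_list) == expected_feature_num, (
      'every `%s` feature needs a matching feature config' % group_name)
  return dim_list[0] if dim_list else 0

# e.g. three `general` features, each configured with embedding_dim 16:
assert _check_group_dims('general', [16, 16, 16], 3) == 16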