# Excerpt: __init__ method
#
# Source file: easy_rec/python/layers/cmbf.py

  def __init__(self, model_config, feature_configs, features, cmbf_config,
               input_layer):
    """Collect and validate the inputs/hyper-parameters of the CMBF layer.

    Extracts up to four feature groups (`image`, `general`, `text`, `other`)
    via `input_layer`, checks that embedding dimensions are consistent within
    each group, and caches the attention hyper-parameters from `cmbf_config`.

    Args:
      model_config: model-level config; only `feature_groups` is read here.
      feature_configs: iterable of per-feature configs; each exposes
        `input_names`, optional `feature_name`, `embedding_dim`,
        `raw_input_dim` and optional `max_seq_len` (protobuf-style
        `HasField` access).
      features: raw input features, passed through to `input_layer`.
      cmbf_config: CMBF-specific config (multi-head counts/sizes, layer
        counts, position-embedding settings).
      input_layer: callable feature extractor with a `has_group(name)`
        predicate; called as `input_layer(features, group_name)`.
    """
    # NOTE(review): `_model_config` is set to cmbf_config (the sub-config),
    # not the `model_config` argument — looks intentional, but confirm.
    self._model_config = cmbf_config

    # At least one of the four recognized feature groups must be configured.
    has_feature = False
    self._img_features = None
    if input_layer.has_group('image'):
      self._img_features, _ = input_layer(features, 'image')
      has_feature = True
    self._general_features = None
    if input_layer.has_group('general'):
      self._general_features, _ = input_layer(features, 'general')
      has_feature = True
    self._txt_seq_features = None
    if input_layer.has_group('text'):
      # Sequence features are kept un-combined (is_combine=False); the call
      # returns three values here vs. two for the other groups.
      self._txt_seq_features, _, _ = input_layer(
          features, 'text', is_combine=False)
      has_feature = True
    self._other_features = None
    if input_layer.has_group('other'):  # e.g. statistical feature
      self._other_features, _ = input_layer(features, 'other')
      has_feature = True
    assert has_feature, 'there must be one of the feature groups: [image, text, general, other]'

    # Count the features declared in each group and reject duplicates
    # (set size must equal list length).
    self._general_feature_num, self._img_feature_num = 0, 0
    self._txt_feature_num = 0
    general_feature_names, txt_seq_feature_names = set(), set()
    img_feature_names = set()
    for fea_group in model_config.feature_groups:
      if fea_group.group_name == 'general':
        self._general_feature_num = len(fea_group.feature_names)
        general_feature_names = set(fea_group.feature_names)
        assert self._general_feature_num == len(general_feature_names), (
            'there are duplicate features in `general` feature group')
      elif fea_group.group_name == 'image':
        self._img_feature_num = len(fea_group.feature_names)
        img_feature_names = set(fea_group.feature_names)
        assert self._img_feature_num == len(img_feature_names), (
            'there are duplicate features in `image` feature group')
      elif fea_group.group_name == 'text':
        txt_seq_feature_names = set(fea_group.feature_names)
        self._txt_feature_num = len(fea_group.feature_names)
        assert self._txt_feature_num == len(txt_seq_feature_names), (
            'there are duplicate features in `text` feature group')

    # Collect the embedding dimension of every feature, per group, and the
    # largest `max_seq_len` among text-sequence features.  Image features
    # use `raw_input_dim` (pre-extracted embeddings); the others use
    # `embedding_dim`.
    max_seq_len = 0
    txt_fea_emb_dim_list = []
    general_emb_dim_list = []
    img_fea_emb_dim_list = []
    for feature_config in feature_configs:
      # `feature_name` overrides the first input name when present.
      fea_name = feature_config.input_names[0]
      if feature_config.HasField('feature_name'):
        fea_name = feature_config.feature_name
      if fea_name in img_feature_names:
        img_fea_emb_dim_list.append(feature_config.raw_input_dim)
      if fea_name in general_feature_names:
        general_emb_dim_list.append(feature_config.embedding_dim)
      if fea_name in txt_seq_feature_names:
        txt_fea_emb_dim_list.append(feature_config.embedding_dim)
        if feature_config.HasField('max_seq_len'):
          assert feature_config.max_seq_len > 0, (
              'feature config `max_seq_len` must be greater than 0 for feature: '
              + fea_name)
          if feature_config.max_seq_len > max_seq_len:
            max_seq_len = feature_config.max_seq_len

    # Each group must (a) use a single embedding dimension and (b) have a
    # feature_config entry for every name declared in the group (the length
    # check catches group names with no matching feature config).
    unique_dim_num = len(set(txt_fea_emb_dim_list))
    assert unique_dim_num <= 1 and len(
        txt_fea_emb_dim_list
    ) == self._txt_feature_num, (
        'CMBF requires that all `text` feature dimensions must be consistent.')
    unique_dim_num = len(set(general_emb_dim_list))
    assert unique_dim_num <= 1 and len(
        general_emb_dim_list
    ) == self._general_feature_num, (
        'CMBF requires that all `general` feature dimensions must be consistent.'
    )
    unique_dim_num = len(set(img_fea_emb_dim_list))
    assert unique_dim_num <= 1 and len(
        img_fea_emb_dim_list
    ) == self._img_feature_num, (
        'CMBF requires that all `image` feature dimensions must be consistent.')

    # Position embeddings need a table large enough for the longest text
    # sequence.
    if cmbf_config.use_position_embeddings:
      assert cmbf_config.max_position_embeddings > 0, (
          'model config `max_position_embeddings` must be greater than 0. '
          'It must be set when `use_position_embeddings` is true (default)')
      assert cmbf_config.max_position_embeddings >= max_seq_len, (
          'model config `max_position_embeddings` must be greater than or equal to the maximum of all feature config '
          '`max_seq_len`, which is %d' % max_seq_len)

    # Per-group embedding sizes (0 when the group is absent).  All dims in a
    # list are equal at this point, so element [0] is representative.
    self._img_emb_size = img_fea_emb_dim_list[0] if img_fea_emb_dim_list else 0
    self._txt_emb_size = txt_fea_emb_dim_list[0] if txt_fea_emb_dim_list else 0
    self._general_emb_size = general_emb_dim_list[
        0] if general_emb_dim_list else 0
    # Cache attention hyper-parameters from the CMBF config.
    self._head_num = cmbf_config.multi_head_num
    self._img_head_num = cmbf_config.image_multi_head_num
    self._txt_head_num = cmbf_config.text_multi_head_num
    self._txt_head_size = cmbf_config.text_head_size
    self._img_head_size = cmbf_config.image_head_size
    self._img_patch_num = cmbf_config.image_feature_patch_num
    self._img_self_attention_layer_num = cmbf_config.image_self_attention_layer_num
    self._txt_self_attention_layer_num = cmbf_config.text_self_attention_layer_num
    self._cross_modal_layer_num = cmbf_config.cross_modal_layer_num
    # NOTE(review): the 'txt_feature_num' label is formatted with
    # `_general_feature_num` — label and value appear mismatched; confirm
    # whether the label or the argument should change before touching it.
    print('txt_feature_num: {0}, img_feature_num: {1}, txt_seq_feature_num: {2}'
          .format(self._general_feature_num, self._img_feature_num,
                  len(self._txt_seq_features) if self._txt_seq_features else 0))
    print('txt_embedding_size: {0}, img_embedding_size: {1}'.format(
        self._txt_emb_size, self._img_emb_size))
    if self._img_features is not None:
      assert self._img_emb_size > 0, '`image` feature dimensions must be greater than 0, set by `raw_input_dim`'