def __init__()

in easy_rec/python/layers/uniter.py [0:0]


  def __init__(self, model_config, feature_configs, features, uniter_config,
               input_layer):
    """Collect the UNITER input features and validate their configuration.

    Args:
      model_config: model config whose `feature_groups` describe the
        `image`, `text` and `general` groups (each group's `group_name` and
        `feature_names` are read here).
      feature_configs: iterable of feature configs. A feature is identified
        by `feature_name` when set, otherwise by `input_names[0]`.
        `raw_input_dim` is used as the dimension of `image` features and
        `embedding_dim` as the dimension of `text` / `general` features.
      features: input features, passed through to `input_layer`.
      uniter_config: UNITER layer config, stored as `self._model_config`;
        `use_position_embeddings` and `max_position_embeddings` are
        validated here.
      input_layer: input layer queried with `has_group` and called once per
        present group, in the order image, general, text, other.

    Raises:
      AssertionError: if none of the supported feature groups is present,
        a group lists duplicate features, the number of feature configs
        matching a group differs from the group size, feature dimensions
        within a group differ, a `text` feature has a non-positive
        `max_seq_len`, the position-embedding settings cannot cover the
        longest configured `text` sequence, or image features are present
        with a non-positive `raw_input_dim`.
    """
    self._model_config = uniter_config
    tower_num = 0
    self._img_features = None
    if input_layer.has_group('image'):
      self._img_features, _ = input_layer(features, 'image')
      tower_num += 1
    self._general_features = None
    if input_layer.has_group('general'):
      self._general_features, _ = input_layer(features, 'general')
      tower_num += 1
    self._txt_seq_features = None
    if input_layer.has_group('text'):
      # Text features are requested uncombined (`is_combine=False`), in which
      # case the input layer returns three values.
      self._txt_seq_features, _, _ = input_layer(
          features, 'text', is_combine=False)
      tower_num += 1
    # Decided before the `other` group is counted: only the image / general /
    # text towers make token types necessary.
    self._use_token_type = tower_num > 1
    self._other_features = None
    if input_layer.has_group('other'):  # e.g. statistical feature
      self._other_features, _ = input_layer(features, 'other')
      tower_num += 1
    assert tower_num > 0, 'there must be one of the feature groups: [image, text, general, other]'

    # Feature names and sizes of the groups Uniter cares about; any other
    # group in the model config is ignored here.
    group_feature_names = {'general': set(), 'image': set(), 'text': set()}
    group_feature_nums = {'general': 0, 'image': 0, 'text': 0}
    for fea_group in model_config.feature_groups:
      group_name = fea_group.group_name
      if group_name not in group_feature_names:
        continue
      names = set(fea_group.feature_names)
      assert len(fea_group.feature_names) == len(names), (
          'there are duplicate features in `%s` feature group' % group_name)
      group_feature_nums[group_name] = len(fea_group.feature_names)
      group_feature_names[group_name] = names

    self._general_feature_num = group_feature_nums['general']
    self._img_feature_num = group_feature_nums['image']
    self._txt_feature_num = group_feature_nums['text']
    general_feature_names = group_feature_names['general']
    img_feature_names = group_feature_names['image']
    txt_feature_names = group_feature_names['text']

    if self._txt_feature_num > 1 or self._img_feature_num > 1:
      self._use_token_type = True
    # One token type per text feature, plus one shared by all image features
    # and one shared by all general features.
    self._token_type_vocab_size = self._txt_feature_num
    if self._img_feature_num > 0:
      self._token_type_vocab_size += 1
    if self._general_feature_num > 0:
      self._token_type_vocab_size += 1

    max_seq_len = 0
    txt_fea_emb_dim_list = []
    general_emb_dim_list = []
    img_fea_emb_dim_list = []
    for feature_config in feature_configs:
      # Only fall back to `input_names[0]` when needed, so a config with an
      # explicit `feature_name` never depends on `input_names` being set.
      if feature_config.HasField('feature_name'):
        fea_name = feature_config.feature_name
      else:
        fea_name = feature_config.input_names[0]
      if fea_name in img_feature_names:
        img_fea_emb_dim_list.append(feature_config.raw_input_dim)
      if fea_name in general_feature_names:
        general_emb_dim_list.append(feature_config.embedding_dim)
      if fea_name in txt_feature_names:
        txt_fea_emb_dim_list.append(feature_config.embedding_dim)
        if feature_config.HasField('max_seq_len'):
          assert feature_config.max_seq_len > 0, (
              'feature config `max_seq_len` must be greater than 0 for feature: '
              + fea_name)
          # Longest explicitly configured text sequence; checked against
          # `max_position_embeddings` below.
          max_seq_len = max(max_seq_len, feature_config.max_seq_len)

    def check_group_dims(group_name, dim_list, feature_num):
      """Assert one config per group feature and a single shared dimension."""
      # Checked separately from dimension consistency so that a missing or
      # duplicated feature config is not reported as a dimension mismatch.
      assert len(dim_list) == feature_num, (
          'Uniter requires that every feature in the `%s` feature group has '
          'exactly one feature config: expected %d matching feature configs '
          'but found %d.' % (group_name, feature_num, len(dim_list)))
      assert len(set(dim_list)) <= 1, (
          'Uniter requires that all `%s` feature dimensions must be '
          'consistent.' % group_name)

    check_group_dims('text', txt_fea_emb_dim_list, self._txt_feature_num)
    check_group_dims('image', img_fea_emb_dim_list, self._img_feature_num)
    check_group_dims('general', general_emb_dim_list,
                     self._general_feature_num)

    if self._txt_feature_num > 0 and uniter_config.use_position_embeddings:
      assert uniter_config.max_position_embeddings > 0, (
          'model config `max_position_embeddings` must be greater than 0. ')
      assert uniter_config.max_position_embeddings >= max_seq_len, (
          'model config `max_position_embeddings` must be greater than or equal to the maximum of all feature config '
          '`max_seq_len`, which is %d' % max_seq_len)

    # All dimensions inside a group are equal at this point, so the first
    # entry is representative; 0 marks an absent group.
    self._img_emb_size = img_fea_emb_dim_list[0] if img_fea_emb_dim_list else 0
    self._txt_emb_size = txt_fea_emb_dim_list[0] if txt_fea_emb_dim_list else 0
    self._general_emb_size = (
        general_emb_dim_list[0] if general_emb_dim_list else 0)
    if self._img_features is not None:
      assert self._img_emb_size > 0, '`image` feature dimensions must be greater than 0, set by `raw_input_dim`'