def call()

in tensorflow_gnn/graph/keras/layers/map_features.py [0:0]


  def call(self, graph: gt.GraphTensor) -> gt.GraphTensor:
    """Maps the features of `graph` through the per-piece models.

    On the first call, the per-piece models are set up from `graph.spec` by
    `self._init_from_spec()` under `tf.init_scope()`. Afterwards the stored
    `_context_fn`, `_node_sets_fn` and `_edge_sets_fn` are reset to None.

    The context is mapped with `self._context_model` if that is not None.
    Every node set and edge set of `graph` is looked up in the models recorded
    at initialization. A None entry means the piece is skipped; otherwise its
    model is applied via `_call_model()`. Skipped pieces are left out of the
    dicts handed to `graph.replace_features()` (presumably keeping their
    existing features -- confirm against `GraphTensor.replace_features`).

    Args:
      graph: The input GraphTensor.

    Returns:
      The result of `graph.replace_features()` with the mapped context,
      node set and edge set features.

    Raises:
      KeyError: If `graph` contains a node set or edge set that has no entry
        among the models recorded on the first call.
    """
    if not self._is_initialized:
      with tf.init_scope():
        self._init_from_spec(graph.spec)
        # Drop the construction-time functions; after initialization this
        # method only uses the models set up by _init_from_spec().
        self._context_fn = None
        self._node_sets_fn = None
        self._edge_sets_fn = None
    assert self._is_initialized

    def apply_models(models, pieces, kind_for_error, kind_for_log):
      """Returns {name: mapped features} for each piece with a non-None model."""
      mapped = {}
      for name, piece in pieces.items():
        try:
          piece_model = models[name]
        except KeyError as e:
          raise KeyError(f"Unexpected {kind_for_error} '{name}' "
                         "not seen in first call") from e
        if piece_model is not None:  # A None model was set up to be ignored.
          mapped[name] = _call_model(
              piece_model, piece, logging_name=f"{kind_for_log} '{name}'")
      return mapped

    context_model = self._context_model
    context_features = (
        None if context_model is None else
        _call_model(context_model, graph.context, logging_name="context"))

    # Node sets are processed before edge sets, as in a sequential pass.
    node_set_features = apply_models(
        self._node_set_models, graph.node_sets, "node set", "node_set")
    edge_set_features = apply_models(
        self._edge_set_models, graph.edge_sets, "edge set", "edge_set")

    return graph.replace_features(context=context_features,
                                  node_sets=node_set_features,
                                  edge_sets=edge_set_features)