in tensorflow_gnn/graph/keras/layers/map_features.py [0:0]
def call(self, graph: gt.GraphTensor) -> gt.GraphTensor:
    """Returns `graph` with features replaced by the per-piece model outputs.

    On the first call, `self._init_from_spec(graph.spec)` is run inside
    `tf.init_scope()`. Afterwards the `_context_fn`, `_node_sets_fn` and
    `_edge_sets_fn` attributes are reset to None.

    On every call:
      * `self._context_model` is applied to `graph.context`, unless that
        model is None.
      * For each node set and each edge set of `graph`, the model stored
        under its name in `self._node_set_models` / `self._edge_set_models`
        is applied. A stored model of None means the piece is skipped.

    All results are handed to `graph.replace_features(...)`.

    NOTE(review): a skipped piece is simply absent from the dict, and the
    context is passed as None when there is no context model. Presumably
    `replace_features` then keeps those features unchanged; confirm against
    `GraphTensor.replace_features`.

    Args:
      graph: The input GraphTensor.

    Returns:
      The GraphTensor returned by `graph.replace_features(...)`.

    Raises:
      KeyError: if `graph` contains a node set or edge set whose name has no
        entry in the corresponding model mapping. The message calls it
        "not seen in first call"; presumably that mapping is filled by
        `_init_from_spec` from the first input (confirm).
    """
    if not self._is_initialized:
        # One-time setup from the spec of the first graph seen.
        # NOTE(review): tf.init_scope() presumably lifts the model/variable
        # creation out of any function trace in progress - confirm.
        with tf.init_scope():
            self._init_from_spec(graph.spec)
        # Drop the factory callbacks once initialization has run.
        self._context_fn = self._node_sets_fn = self._edge_sets_fn = None
    assert self._is_initialized

    def apply_models(pieces, models_by_name, error_label, log_label):
        """Maps each named piece through its model; None models are skipped."""
        outputs = {}
        for piece_name, piece in pieces.items():
            try:
                piece_model = models_by_name[piece_name]
            except KeyError as err:
                raise KeyError(f"Unexpected {error_label} '{piece_name}' "
                               "not seen in first call") from err
            # A None entry marks a piece that was set up to be ignored.
            if piece_model is not None:
                outputs[piece_name] = _call_model(
                    piece_model, piece,
                    logging_name=f"{log_label} '{piece_name}'")
        return outputs

    if self._context_model is None:
        context_features = None
    else:
        context_features = _call_model(self._context_model, graph.context,
                                       logging_name="context")
    # Note the labels: "node set"/"edge set" in error messages, but
    # "node_set"/"edge_set" in logging names.
    node_set_features = apply_models(graph.node_sets, self._node_set_models,
                                     "node set", "node_set")
    edge_set_features = apply_models(graph.edge_sets, self._edge_set_models,
                                     "edge set", "edge_set")
    return graph.replace_features(context=context_features,
                                  node_sets=node_set_features,
                                  edge_sets=edge_set_features)