# hugegraph-ml/src/hugegraph_ml/tasks/node_embed.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
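
"""Node embedding task.

Wraps a DGL graph and a self-supervised model, trains the model on the full
graph, and replaces the node features with the learned embeddings. As the code
below shows, the wrapped model is expected to return a scalar training loss
from ``forward(graph, feat)`` and to expose ``get_embedding(graph, feat)``.
"""
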
import dgl
import torch
from dgl import DGLGraph
from torch import nn
from tqdm import trange

from hugegraph_ml.utils.early_stopping import EarlyStopping


class NodeEmbed:
def __init__(self, graph: DGLGraph, model: nn.Module):
self.graph = graph
self._model = model
self._device = ""
self._early_stopping = None
self._check_graph()

    def _check_graph(self):
required_node_attrs = ["feat"]
for attr in required_node_attrs:
if attr not in self.graph.ndata:
raise ValueError(f"Graph is missing required node attribute '{attr}' in ndata.")

    def train_and_embed(
self,
add_self_loop: bool = True,
lr: float = 1e-3,
weight_decay: float = 0,
n_epochs: int = 200,
        patience: float = float("inf"),
gpu: int = -1,
) -> DGLGraph:
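        """Train the model and write the learned embeddings back to the graph.

        Args:
            add_self_loop: Whether to add self-loops to the graph before training.
            lr: Learning rate for the Adam optimizer.
            weight_decay: Weight decay (L2 penalty) for the Adam optimizer.
            n_epochs: Maximum number of training epochs.
            patience: Number of epochs without improvement before early stopping;
                defaults to infinity (no early stopping).
            gpu: CUDA device index, or -1 to train on the CPU.

        Returns:
            The graph (with self-loops if requested) whose ndata["feat"] has
            been replaced by the learned node embeddings.
        """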
# Set device for training
self._device = f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu"
self._early_stopping = EarlyStopping(patience=patience)
self._model = self._model.to(self._device)
self.graph = self.graph.to(self._device)
# Add self-loop if required
if add_self_loop:
self.graph = dgl.add_self_loop(self.graph)
# Get node features and move to device
feat = self.graph.ndata["feat"].to(self._device)
optimizer = torch.optim.Adam(self._model.parameters(), lr=lr, weight_decay=weight_decay)
        # Training loop
epochs = trange(n_epochs)
for epoch in epochs:
self._model.train()
optimizer.zero_grad()
            # Forward pass; the model computes and returns its training loss
            loss = self._model(self.graph, feat)
loss.backward()
optimizer.step()
# Log
epochs.set_description(f"epoch {epoch} | train loss {loss.item():.4f}")
            # Early stopping on the training loss
self._early_stopping(loss.item(), self._model)
torch.cuda.empty_cache()
if self._early_stopping.early_stop:
break
        self._early_stopping.load_best_model(self._model)
        # Extract embeddings with the best model, without tracking gradients
        self._model.eval()
        with torch.no_grad():
            embed_feat = self._model.get_embedding(self.graph, feat)
        self.graph.ndata["feat"] = embed_feat
return self.graph
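

# A minimal usage sketch (an illustration, not part of the original module). It
# shows the model contract NodeEmbed relies on: forward(graph, feat) returns a
# scalar loss, and get_embedding(graph, feat) returns node embeddings. The
# DummyAutoEncoder, its dimensions, and the random graph are hypothetical
# stand-ins, not hugegraph-ml models or data.
if __name__ == "__main__":
    class DummyAutoEncoder(nn.Module):
        """Hypothetical stand-in; real embedding models also use the graph."""

        def __init__(self, in_dim: int, hidden_dim: int):
            super().__init__()
            self.encoder = nn.Linear(in_dim, hidden_dim)
            self.decoder = nn.Linear(hidden_dim, in_dim)

        def forward(self, graph, feat):
            # Feature-reconstruction loss; ignores the graph for brevity
            return nn.functional.mse_loss(self.decoder(self.encoder(feat)), feat)

        def get_embedding(self, graph, feat):
            return self.encoder(feat)

    g = dgl.rand_graph(num_nodes=100, num_edges=500)
    g.ndata["feat"] = torch.randn(100, 16)
    task = NodeEmbed(graph=g, model=DummyAutoEncoder(in_dim=16, hidden_dim=8))
    embedded = task.train_and_embed(n_epochs=10, gpu=-1)
    print(embedded.ndata["feat"].shape)  # expected: torch.Size([100, 8])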