hugegraph-ml/src/hugegraph_ml/tasks/hetero_sample_embed_gatne.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import random

import dgl
import torch
from torch import nn
from tqdm.auto import tqdm

from hugegraph_ml.models.gatne import (
    construct_typenodes_from_graph,
    generate_pairs,
    NSLoss,
    NeighborSampler,
)


class HeteroSampleEmbedGATNE:
    def __init__(self, graph, model: nn.Module):
        self.graph = graph
        self._model = model
        self._device = ""

    def train_and_embed(
        self,
        lr: float = 1e-3,
        n_epochs: int = 200,
        gpu: int = -1,
    ):
        self._device = (
            f"cuda:{gpu}" if gpu != -1 and torch.cuda.is_available() else "cpu"
        )
        self._model = self._model.to(self._device)
        self.graph = self.graph.to(self._device)

        # Group node IDs by type and set the sampling / training hyperparameters.
        type_nodes = construct_typenodes_from_graph(self.graph)
        edge_type_count = len(self.graph.etypes)
        neighbor_samples = 10
        num_walks = 20
        num_workers = 4
        window_size = 5
        batch_size = 64
        num_sampled = 5
        embedding_size = 200

        # Run metapath-constrained random walks separately for each edge type.
        all_walks = []
        for i in range(edge_type_count):
            nodes = torch.LongTensor(type_nodes[i] * num_walks).to(self._device)
            traces, _ = dgl.sampling.random_walk(
                self.graph,
                nodes,
                metapath=[self.graph.etypes[i]] * (neighbor_samples - 1),
            )
            all_walks.append(traces)

        # Build skip-gram style training pairs from the walks and wrap them in a
        # DataLoader whose collate function samples neighborhood blocks.
        train_pairs = generate_pairs(all_walks, window_size, num_workers)
        neighbor_sampler = NeighborSampler(self.graph, [neighbor_samples])
        train_dataloader = torch.utils.data.DataLoader(
            train_pairs,
            batch_size=batch_size,
            collate_fn=neighbor_sampler.sample,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
        )

        # Negative-sampling loss over all nodes; model and loss are optimized jointly.
        nsloss = NSLoss(self.graph.number_of_nodes(), num_sampled, embedding_size)
        self._model.to(self._device)
        nsloss.to(self._device)

        optimizer = torch.optim.Adam(
            [{"params": self._model.parameters()}, {"params": nsloss.parameters()}],
            lr=lr,
        )

        for epoch in range(n_epochs):
            self._model.train()
            random.shuffle(train_pairs)
            data_iter = tqdm(
                train_dataloader,
                desc=f"epoch {epoch}",
                total=(len(train_pairs) + (batch_size - 1)) // batch_size,
            )
            avg_loss = 0.0

            for i, (block, head_invmap, tails, block_types) in enumerate(data_iter):
                optimizer.zero_grad()
                # embs: [batch_size, edge_type_count, embedding_size]
                block_types = block_types.to(self._device)
                embs = self._model(block[0].to(self._device))[head_invmap]
                # Keep only the embedding channel matching each pair's edge type.
                embs = embs.gather(
                    1,
                    block_types.view(-1, 1, 1).expand(embs.shape[0], 1, embs.shape[2]),
                )[:, 0]
                loss = nsloss(
                    block[0].dstdata[dgl.NID][head_invmap].to(self._device),
                    embs,
                    tails.to(self._device),
                )
                loss.backward()
                optimizer.step()

                avg_loss += loss.item()
                post_fix = {
                    "epoch": epoch,
                    "iter": i,
                    "avg_loss": avg_loss / (i + 1),
                    "loss": loss.item(),
                }
                data_iter.set_postfix(post_fix)
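For context, a minimal usage sketch follows. It is not taken from the repository: the model class name (`GATNE`) and its constructor arguments are assumptions and must be matched against the real model in `hugegraph_ml/models/gatne.py`; the toy graph simply stands in for a heterogeneous graph loaded from HugeGraph. Note that the model's output embedding dimension has to agree with the `embedding_size = 200` hard-coded in `train_and_embed`, since that value is what `NSLoss` is built with.

```python
# Hedged usage sketch. GATNE and its constructor arguments are ASSUMED names,
# not confirmed by the source file; adjust to the actual model before running.
import dgl
import torch

from hugegraph_ml.models.gatne import GATNE  # assumed class name
from hugegraph_ml.tasks.hetero_sample_embed_gatne import HeteroSampleEmbedGATNE

# Toy heterogeneous graph with one node type and two edge types; each edge
# type connects "user" to "user", so single-etype metapath walks are valid.
graph = dgl.heterograph(
    {
        ("user", "follows", "user"): (torch.tensor([0, 1]), torch.tensor([1, 2])),
        ("user", "likes", "user"): (torch.tensor([0, 2]), torch.tensor([2, 1])),
    }
)

model = GATNE(  # hypothetical constructor; check the real signature
    num_nodes=graph.number_of_nodes(),
    embedding_size=200,  # must match the embedding_size used by train_and_embed
    edge_types=graph.etypes,
)

task = HeteroSampleEmbedGATNE(graph=graph, model=model)
task.train_and_embed(lr=1e-3, n_epochs=10, gpu=-1)  # gpu=-1 runs on CPU
```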