graphlearn_torch/python/utils/common.py
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import pickle
import random
import socket
from typing import Any, Dict, Literal, Optional

import numpy
import torch

from ..typing import reverse_edge_type
from .tensor import id2idx
def ensure_dir(dir_path: str):
  r"""Creates the directory at :obj:`dir_path` if it does not already exist."""
  os.makedirs(dir_path, exist_ok=True)
def seed_everything(seed: int):
r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`,
:obj:`numpy` and :python:`Python`.
Args:
seed (int): The desired seed.
"""
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def merge_dict(in_dict: Dict[Any, Any], out_dict: Dict[Any, Any]):
  r"""Appends each value in :obj:`in_dict` to the list stored under the same
  key in :obj:`out_dict`, creating the list if the key is new.
  """
  for k, v in in_dict.items():
    vals = out_dict.get(k, [])
    vals.append(v)
    out_dict[k] = vals

def count_dict(in_dict: Dict[Any, Any], out_dict: Dict[Any, Any],
               target_len: int):
  r"""Appends the length of each value in :obj:`in_dict` to the list stored
  under the same key in :obj:`out_dict`, zero-padding entries missed in
  earlier rounds so that the list length reaches :obj:`target_len`.
  """
  for k, v in in_dict.items():
    vals = out_dict.get(k, [])
    vals += [0] * (target_len - len(vals) - 1)
    vals.append(len(v))
    out_dict[k] = vals
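# Usage sketch for `merge_dict` / `count_dict` (values are hypothetical):
#
#   out = {}
#   merge_dict({'a': 1}, out)           # out == {'a': [1]}
#   merge_dict({'a': 2, 'b': 3}, out)   # out == {'a': [1, 2], 'b': [3]}
#
#   counts = {}
#   count_dict({'a': [10, 20]}, counts, target_len=1)
#   count_dict({'a': [30], 'b': [40]}, counts, target_len=2)
#   # counts == {'a': [2, 1], 'b': [0, 1]}; keys that were missing in
#   # earlier rounds are zero-padded up to `target_len`.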
def get_free_port(host: str = 'localhost') -> int:
  r"""Asks the OS for a currently free port on :obj:`host` by binding a
  temporary socket to port 0.
  """
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind((host, 0))
    return s.getsockname()[1]
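# Usage sketch: reserve an ephemeral port, e.g. for a distributed master
# endpoint. Note the socket is closed before returning, so another process
# could in principle grab the port before it is reused.
#
#   port = get_free_port('localhost')   # e.g. 43215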
def index_select(data, index):
  r"""Recursively applies :obj:`index` to :obj:`data`, descending into dict
  values, list items and tuple items. A tuple :obj:`index` is interpreted as
  a ``(start, end)`` slice.
  """
  if data is None:
    return None
if isinstance(data, dict):
new_data = {}
for k, v in data.items():
new_data[k] = index_select(v, index)
return new_data
if isinstance(data, list):
new_data = []
for v in data:
new_data.append(index_select(v, index))
return new_data
if isinstance(data, tuple):
return tuple(index_select(list(data), index))
if isinstance(index, tuple):
start, end = index
return data[start:end]
return data[index]
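# Usage sketch (hypothetical tensors): selection is applied recursively to
# nested containers, and a tuple index acts as a (start, end) slice.
#
#   feats = {'paper': [torch.arange(5), torch.arange(5) * 2]}
#   index_select(feats, torch.tensor([0, 2]))
#   # -> {'paper': [tensor([0, 2]), tensor([0, 4])]}
#   index_select(torch.arange(5), (1, 3))
#   # -> tensor([1, 2])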
def merge_hetero_sampler_output(
    in_sample: Any, out_sample: Any, device,
    edge_dir: Literal['in', 'out'] = 'out'):
  r"""Merges the heterogeneous sampler output :obj:`in_sample` into
  :obj:`out_sample`: local row/col indices are lifted to global node ids,
  node ids are deduplicated per type, and row/col indices are rebuilt
  against the merged node sets.
  """
def subid2gid(sample):
for k, v in sample.row.items():
sample.row[k] = sample.node[k[0]][v]
for k, v in sample.col.items():
sample.col[k] = sample.node[k[-1]][v]
def merge_tensor_dict(in_dict, out_dict, unique=False):
for k, v in in_dict.items():
vals = out_dict.get(k, torch.tensor([], device=device))
out_dict[k] = torch.cat((vals, v)).unique() if unique \
else torch.cat((vals, v))
subid2gid(in_sample)
subid2gid(out_sample)
merge_tensor_dict(in_sample.node, out_sample.node, unique=True)
merge_tensor_dict(in_sample.row, out_sample.row)
merge_tensor_dict(in_sample.col, out_sample.col)
for k, v in out_sample.row.items():
out_sample.row[k] = id2idx(out_sample.node[k[0]])[v.to(torch.int64)]
for k, v in out_sample.col.items():
out_sample.col[k] = id2idx(out_sample.node[k[-1]])[v.to(torch.int64)]
# if in_sample.batch is not None and out_sample.batch is not None:
# merge_tensor_dict(in_sample.batch, out_sample.batch)
if in_sample.edge is not None and out_sample.edge is not None:
merge_tensor_dict(in_sample.edge, out_sample.edge, unique=False)
if out_sample.edge_types is not None and in_sample.edge_types is not None:
out_sample.edge_types = list(set(out_sample.edge_types) | set(in_sample.edge_types))
if edge_dir == 'out':
out_sample.edge_types = [
reverse_edge_type(etype) if etype[0] != etype[-1] else etype
for etype in out_sample.edge_types
]
return out_sample
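# Usage sketch (assumes `partial_sample` and `final_sample` are
# HeteroSamplerOutput-like objects with per-type `node`, `row`, `col`,
# `edge` and `edge_types` fields, e.g. partial results from two hops):
#
#   merged = merge_hetero_sampler_output(
#     partial_sample, final_sample, device=torch.device('cpu'),
#     edge_dir='out')
#
# Row/col indices are first lifted to global node ids, the per-type dicts
# are concatenated (node ids deduplicated), and row/col are re-indexed
# against the merged node sets via `id2idx`.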
def format_hetero_sampler_output(
    in_sample: Any, edge_dir: Literal['in', 'out'] = 'out'):
  r"""Deduplicates the node ids in :obj:`in_sample` and, for ``'out'``
  sampling, restores reversed edge types to their original direction.
  """
for k in in_sample.node.keys():
in_sample.node[k] = in_sample.node[k].unique()
if in_sample.edge_types is not None:
if edge_dir == 'out':
in_sample.edge_types = [
reverse_edge_type(etype) if etype[0] != etype[-1] else etype
for etype in in_sample.edge_types
]
return in_sample
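# Usage sketch (hypothetical `sample_out`): node ids are deduplicated in
# place and, for 'out' sampling, edge types whose source and destination
# node types differ are flipped back via `reverse_edge_type`:
#
#   sample_out = format_hetero_sampler_output(sample_out, edge_dir='out')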
def append_tensor_to_file(filename, tensor):
  r"""Appends :obj:`tensor` to :obj:`filename` as a new pickle record;
  errors are printed rather than raised.
  """
  try:
    with open(filename, 'ab') as f:
      pickle.dump(tensor, f)
  except Exception as e:
    print('Error:', e)
def load_and_concatenate_tensors(filename, device):
  r"""Loads every pickled tensor in :obj:`filename` and concatenates them
  along dim 0 into a single tensor on :obj:`device`.
  """
  tensor_list = []
  with open(filename, 'rb') as f:
    while True:
      try:
        tensor_list.append(pickle.load(f))
      except EOFError:
        break
  if not tensor_list:
    raise ValueError(f"No tensors found in '{filename}'")
  # Pre-allocate the combined tensor, then copy each chunk into place.
  combined_tensor = torch.empty(
    (sum(t.shape[0] for t in tensor_list), *tensor_list[0].shape[1:]),
    dtype=tensor_list[0].dtype, device=device)
  start_idx = 0
  for tensor in tensor_list:
    end_idx = start_idx + tensor.shape[0]
    combined_tensor[start_idx:end_idx] = tensor.to(device)
    start_idx = end_idx
  return combined_tensor
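# Usage sketch (hypothetical path): tensors appended one at a time
# round-trip back into a single concatenated tensor.
#
#   append_tensor_to_file('/tmp/feats.pkl', torch.zeros(2, 4))
#   append_tensor_to_file('/tmp/feats.pkl', torch.ones(3, 4))
#   combined = load_and_concatenate_tensors('/tmp/feats.pkl',
#                                           torch.device('cpu'))
#   # combined.shape == torch.Size([5, 4])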
def default_id_select(srcs, p_mask, node_pb=None):
  r"""Default function to select the ids in :obj:`srcs` covered by the
  boolean partition mask :obj:`p_mask` (:obj:`node_pb` is unused here).
  """
  return torch.masked_select(srcs, p_mask)

def default_id_filter(node_pb, partition_idx):
  r"""Default function to filter the node ids assigned to partition
  :obj:`partition_idx` in the partition book :obj:`node_pb`.
  """
  return torch.where(node_pb == partition_idx)[0]
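# Usage sketch (hypothetical partition book): `node_pb[i]` holds the id of
# the partition owning node `i`.
#
#   node_pb = torch.tensor([0, 1, 0, 1])
#   default_id_filter(node_pb, partition_idx=1)        # -> tensor([1, 3])
#   default_id_select(torch.arange(4), node_pb == 0)   # -> tensor([0, 2])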
def save_ckpt(
ckpt_seq: int,
ckpt_dir: str,
model: torch.nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
epoch: float = 0,
):
"""
Saves a checkpoint of the model's state.
Parameters:
ckpt_seq (int): The sequence number of the checkpoint.
ckpt_dir (str): The directory where the checkpoint will be saved.
model (torch.nn.Module): The model to be saved.
optimizer (Optional[torch.optim.Optimizer]): The optimizer, if any.
epoch (float): The current epoch. Default is 0.
"""
  ensure_dir(ckpt_dir)
ckpt_path = os.path.join(ckpt_dir, f"model_seq_{ckpt_seq}.ckpt")
ckpt = {
'seq': ckpt_seq,
'epoch': epoch,
'model_state_dict': model.state_dict()
}
if optimizer:
ckpt['optimizer_state_dict'] = optimizer.state_dict()
torch.save(ckpt, ckpt_path)
def load_ckpt(
ckpt_seq: int,
ckpt_dir: str,
model: torch.nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
) -> float:
"""
Loads a checkpoint of the model's state, returns the epoch of the checkpoint.
Parameters:
ckpt_seq (int): The sequence number of the checkpoint.
ckpt_dir (str): The directory where the checkpoint will be saved.
model (torch.nn.Module): The model to be saved.
optimizer (Optional[torch.optim.Optimizer]): The optimizer, if any.
"""
ckpt_path = os.path.join(ckpt_dir, f"model_seq_{ckpt_seq}.ckpt")
try:
ckpt = torch.load(ckpt_path)
except FileNotFoundError:
return -1
  model.load_state_dict(ckpt['model_state_dict'])
  epoch = ckpt.get('epoch', 0)
  if optimizer and 'optimizer_state_dict' in ckpt:
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
  return epoch
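# Usage sketch (hypothetical model/optimizer): checkpoints are keyed by
# sequence number; `load_ckpt` returns -1 when no matching file exists.
#
#   model = torch.nn.Linear(4, 2)
#   opt = torch.optim.SGD(model.parameters(), lr=0.1)
#   save_ckpt(0, '/tmp/ckpts', model, opt, epoch=3)
#   load_ckpt(0, '/tmp/ckpts', model, opt)   # -> 3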