tzrec/datasets/sampler.py (705 lines of code) (raw):

# Copyright (c) 2024-2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import random
import socket
import time
from typing import Dict, List, Optional, Tuple, Union

import graphlearn as gl
import numpy as np
import numpy.typing as npt
import pyarrow as pa
import torch
from graphlearn.python.data.values import Values
from graphlearn.python.nn.pytorch.data.utils import launch_server
from torch import distributed as dist
from torch.utils.data import get_worker_info

from tzrec.protos import sampler_pb2
from tzrec.utils.env_util import use_hash_node_id
from tzrec.utils.load_class import get_register_class_meta
from tzrec.utils.logging_util import logger
from tzrec.utils.misc_util import get_free_port


# patch graph-learn string_attrs for utf-8
@property
def string_attrs(self):  # NOQA
    self._init()
    return self._string_attrs


# pyre-ignore [56]
@string_attrs.setter
# pyre-ignore [2, 3]
def string_attrs(self, string_attrs):  # NOQA
    self._string_attrs = self._reshape(string_attrs, expand_shape=True)
    self._inited = True


# Replace the property on graph-learn's Values with the patched version above.
Values.string_attrs = string_attrs


def _get_gl_type(field_type: pa.DataType) -> str:
    """Map an arrow type to a graph-learn attr type name.

    Args:
        field_type (pa.DataType): arrow type of the attribute field.

    Returns:
        "int" for int32/int64, "float" for float32/float64, otherwise "string".
    """
    type_map = {
        pa.int32(): "int",
        pa.int64(): "int",
        pa.float32(): "float",
        pa.float64(): "float",
    }
    return type_map.get(field_type, "string")


def _get_np_type(field_type: pa.DataType) -> npt.DTypeLike:
    """Map an arrow type to the numpy dtype used when parsing sampled attrs.

    Args:
        field_type (pa.DataType): arrow type of the attribute field.

    Returns:
        the matching numpy int/float dtype, or np.str_ for any other type.
    """
    type_map = {
        pa.int32(): np.int32,
        pa.int64(): np.int64,
        pa.float32(): np.float32,
        pa.float64(): np.double,
    }
    return type_map.get(field_type, np.str_)


def _bootstrap(group_size: int, local_rank: int, group_rank: int) -> str:
    """Build the graph-learn server address list for the cluster.

    Args:
        group_size (int): number of groups, i.e. rows in the exchanged
            address table.
        local_rank (int): rank of this process inside its group.
        group_rank (int): index of this process's group.

    Returns:
        "ip:port" when torch.distributed is not initialized, otherwise a
        comma separated list of "ip:port", one entry per group.
    """

    def addr_to_tensor(ip: str, port: str) -> torch.Tensor:
        # dotted ip parts followed by the port, as int32 values.
        addr_array = [int(i) for i in (ip.split("."))] + [int(port)]
        addr_tensor = torch.tensor(addr_array, dtype=torch.int)
        return addr_tensor

    def tensor_to_addr(tensor: torch.Tensor) -> str:
        # inverse of addr_to_tensor: last element is the port.
        addr_array = tensor.tolist()
        addr = ".".join([str(i) for i in addr_array[:-1]]) + ":" + str(addr_array[-1])
        return addr

    def exchange_gl_server_info(
        addr_tensor: torch.Tensor, group_size: int, group_rank: int
    ) -> str:
        # one row per group; 5 columns = four ip parts + port.
        comm_tensor = torch.zeros([group_size, 5], dtype=torch.int32)
        comm_tensor[group_rank] = addr_tensor
        if dist.get_backend() == dist.Backend.NCCL:
            comm_tensor = comm_tensor.cuda()
        # processes with local_rank != 0 contribute all-zero rows, so MAX keeps
        # the address published by the local_rank 0 process of each group.
        dist.all_reduce(comm_tensor, op=dist.ReduceOp.MAX)
        cluster_server_info = ",".join([tensor_to_addr(t) for t in comm_tensor])
        return cluster_server_info

    if local_rank == 0:
        local_ip = socket.gethostbyname(socket.gethostname())
        port = str(get_free_port(local_ip))
    else:
        # placeholder address, overridden by the MAX all_reduce.
        local_ip = "0.0.0.0"
        port = "0"

    if not dist.is_initialized():
        # stand-alone
        return local_ip + ":" + port

    gl_server_info = exchange_gl_server_info(
        addr_to_tensor(local_ip, port), group_size, group_rank
    )
    return gl_server_info


def _get_cluster_spec(num_client_per_rank: int = 1) -> Dict[str, Union[int, str]]:
    """Build the graph-learn cluster spec from torch launcher env variables.

    Reads WORLD_SIZE, LOCAL_WORLD_SIZE, LOCAL_RANK and GROUP_RANK from the
    environment (defaults 1, 1, 0, 0).

    Args:
        num_client_per_rank (int): number of sampler clients per rank.

    Returns:
        dict with "server" (server address string) and "client_count"
        (world_size * num_client_per_rank).
    """
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    group_size = world_size // int(os.environ.get("LOCAL_WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    group_rank = int(os.environ.get("GROUP_RANK", 0))
    gl_server_info = _bootstrap(group_size, local_rank, group_rank)
    return {
        "server": gl_server_info,
        "client_count": world_size * num_client_per_rank,
    }


# registry filled by the metaclass below: sampler classes that use _meta_cls
# as metaclass are recorded here.
_SAMPLER_CLASS_MAP = {}
_meta_cls = get_register_class_meta(_SAMPLER_CLASS_MAP)

SAMPLER_CFG_TYPES = Union[
    sampler_pb2.NegativeSampler,
    sampler_pb2.NegativeSamplerV2,
    sampler_pb2.HardNegativeSampler,
    sampler_pb2.HardNegativeSamplerV2,
    sampler_pb2.TDMSampler,
]
def _to_arrow_array(
    x: npt.NDArray, field_type: pa.DataType, multival_sep: str = chr(29)
) -> pa.Array:
    """Convert a numpy array of sampled attrs to an arrow array of field_type.

    Args:
        x (NDArray): flat array of attribute values; strings for list/map types.
        field_type (pa.DataType): target arrow type.
        multival_sep (str): separator between values of a list/map string.

    Returns:
        arrow array of type field_type (never a ChunkedArray).
    """
    if pa.types.is_list(field_type) or pa.types.is_map(field_type):
        x = pa.array(x, type=pa.string())
        # empty strings become nulls instead of a single empty element.
        is_empty = pa.compute.equal(x, pa.scalar(""))
        x = pa.compute.if_else(is_empty, pa.nulls(len(x)), x)
        if pa.types.is_list(field_type):
            result = pa.compute.split_pattern(x, pattern=multival_sep).cast(
                field_type, safe=False
            )
        else:
            # NOTE(review): the map branch treats multival_sep as a regex while
            # the list branch treats it as a literal -- confirm this is intended.
            kv = pa.compute.split_pattern_regex(x, pattern=multival_sep)
            offsets = kv.offsets
            # each entry is "key:value"; after splitting, keys sit at even
            # positions and values at odd positions of the flattened list.
            kv_list = pa.compute.split_pattern(kv.values, ":").values
            keys = kv_list.take(list(range(0, len(kv_list), 2))).cast(
                field_type.key_type, safe=False
            )
            items = kv_list.take(list(range(1, len(kv_list), 2))).cast(
                field_type.item_type, safe=False
            )
            result = pa.MapArray.from_arrays(offsets, keys, items)
    else:
        result = pa.array(x, type=field_type)
    if isinstance(result, pa.ChunkedArray):
        result = result.combine_chunks()
    return result


def _pa_ids_to_npy(ids: pa.Array) -> npt.NDArray:
    """Convert pyarrow id array to numpy array.

    Ids are cast to string when hash node ids are enabled, otherwise to
    int64 with nulls filled by 0.
    """
    if use_hash_node_id():
        ids = ids.cast(pa.string()).to_numpy(zero_copy_only=False)
    else:
        ids = ids.cast(pa.int64()).fill_null(0).to_numpy()
    return ids


class BaseSampler(metaclass=_meta_cls):
    """Negative Sampler base class.

    Args:
        config: sampler config.
        fields (list): item input fields.
        batch_size (int): mini-batch size.
        is_training (bool): train or eval.
        multival_sep (str): multi value separator.
    """

    def __init__(
        self,
        config: SAMPLER_CFG_TYPES,
        fields: List[pa.Field],
        batch_size: int,
        is_training: bool = True,
        multival_sep: str = chr(29),
    ) -> None:
        self._batch_size = batch_size
        self._multival_sep = multival_sep
        self._g = None
        if hasattr(config, "num_sample"):
            # pyre-ignore [16]
            self._num_sample = config.num_sample
        else:
            self._num_sample = None
        if not is_training and config.HasField("num_eval_sample"):
            self._num_sample = config.num_eval_sample
        self._cluster = None

        input_fields = {f.name: f for f in fields}
        self._attr_names = []
        self._attr_types = []
        self._attr_gl_types = []
        self._attr_np_types = []
        self._valid_attr_names = []
        self._ignore_attr_names = set()
        for field_name in config.attr_fields:
            if field_name in input_fields:
                field = input_fields[field_name]
                self._valid_attr_names.append(field.name)
            else:
                # attr not present in the input fields: still declared to the
                # decoder (as string) but skipped when parsing sampled nodes.
                field = pa.field(name=field_name, type=pa.string())
                self._ignore_attr_names.add(field_name)
            self._attr_names.append(field.name)
            self._attr_types.append(field.type)
            self._attr_gl_types.append(_get_gl_type(field.type))
            self._attr_np_types.append(_get_np_type(field.type))
        if len(self._ignore_attr_names) > 0:
            logger.warning(
                f"Features {self._ignore_attr_names} in "
                # pyre-ignore [16]
                f"{self.__class__.__name__} will be ignored."
            )

        if config.HasField("field_delimiter"):
            gl.set_field_delimiter(config.field_delimiter)

        if use_hash_node_id():
            gl.set_use_string_hash_id(1)

        self._num_client_per_rank = 1
        self._client_id_bias = 0

    def init_cluster(
        self,
        num_client_per_rank: int = 1,
        client_id_bias: int = 0,
        cluster: Optional[Dict[str, Union[int, str]]] = None,
    ) -> None:
        """Set client in cluster info.

        Args:
            num_client_per_rank (int): number of sampler clients per rank.
            client_id_bias (int): offset added to the client id in `init`.
            cluster (dict): cluster spec; built from env variables when empty.
        """
        gl.set_load_graph_thread_num(max(num_client_per_rank // 2, 1))
        self._num_client_per_rank = num_client_per_rank
        self._client_id_bias = client_id_bias
        if cluster:
            self._cluster = cluster
        else:
            self._cluster = _get_cluster_spec(self._num_client_per_rank)

    def launch_server(self) -> None:
        """Launch sampler server.

        The server is only launched by the process whose LOCAL_RANK is 0.
        """
        assert self._cluster, "should init cluster first."
        gl.set_tracker_mode(0)
        if int(os.environ.get("LOCAL_RANK", 0)) == 0:
            launch_server(self._g, self._cluster, int(os.environ.get("GROUP_RANK", 0)))

    def init(self, client_id: int = -1) -> None:
        """Init sampler client and samplers.

        Args:
            client_id (int): client index inside this rank; when negative,
                the dataloader worker id is used (0 outside of a worker).
        """
        gl.set_tracker_mode(0)
        assert self._cluster, "should init cluster first."
        if client_id < 0:
            worker_info = get_worker_info()
            if worker_info is None:
                client_id = 0
            else:
                client_id = worker_info.id
        client_id += self._client_id_bias
        task_index = (
            self._num_client_per_rank * int(os.environ.get("RANK", 0)) + client_id
        )
        self._g.init(task_index=task_index, job_name="client", cluster=self._cluster)

    def __del__(self) -> None:
        # getattr: __del__ may run on an instance whose __init__ never ran.
        graph = getattr(self, "_g", None)
        if graph is not None:
            graph.close()

    def _extract_features(self, nodes: gl.Nodes, dense: bool) -> List[pa.Array]:
        """Convert attrs of sampled nodes to arrow arrays, one per valid attr.

        Args:
            nodes (gl.Nodes): sampled nodes.
            dense (bool): when True, attrs are indexed as [:, :, idx], then
                flattened and truncated to num_sample; when False, attrs are
                indexed as [:, idx] and kept as is.
        """
        lead = (slice(None), slice(None)) if dense else (slice(None),)
        features = []
        # running column index inside each of the int/float/string attr blocks.
        int_idx = 0
        float_idx = 0
        string_idx = 0
        for attr_name, attr_type, attr_gl_type, attr_np_type in zip(
            self._attr_names, self._attr_types, self._attr_gl_types, self._attr_np_types
        ):
            if attr_name in self._ignore_attr_names:
                # ignored attrs are declared as string attrs, so they still
                # occupy one string column.
                string_idx += 1
                continue
            if attr_gl_type == "int":
                feature = nodes.int_attrs[lead + (int_idx,)]
                int_idx += 1
            elif attr_gl_type == "float":
                feature = nodes.float_attrs[lead + (float_idx,)]
                float_idx += 1
            elif attr_gl_type == "string":
                # np.bytes_ instead of np.string_: the latter alias was
                # removed in NumPy 2.0.
                feature = nodes.string_attrs[lead + (string_idx,)].astype(np.bytes_)
                feature = np.char.decode(feature, "utf-8")
                string_idx += 1
            else:
                raise ValueError("Unknown attr type %s" % attr_gl_type)
            if dense:
                feature = np.reshape(feature, [-1])[: self._num_sample]
            feature = feature.astype(attr_np_type)
            features.append(_to_arrow_array(feature, attr_type, self._multival_sep))
        return features

    def _parse_nodes(self, nodes: gl.Nodes) -> List[pa.Array]:
        """Parse attrs of dense sampled nodes, truncated to num_sample."""
        return self._extract_features(nodes, dense=True)

    def _parse_sparse_nodes(
        self, nodes: gl.Nodes
    ) -> Tuple[List[pa.Array], npt.NDArray]:
        """Parse attrs of sparse sampled nodes.

        Returns:
            the feature arrays and `nodes.indices`.
        """
        features = self._extract_features(nodes, dense=False)
        # pyre-ignore [16]
        return features, nodes.indices

    @property
    def estimated_sample_num(self) -> int:
        """Max number of sampled num examples."""
        raise NotImplementedError


class NegativeSampler(BaseSampler):
    """Negative Sampler.

    Weighted random sampling items not in batch.

    Args:
        config (NegativeSampler): negative sampler config.
        fields (list): item input fields.
        batch_size (int): mini-batch size.
        is_training (bool): train or eval.
        multival_sep (str): multi value separator.
    """

    def __init__(
        self,
        config: sampler_pb2.NegativeSampler,
        fields: List[pa.Field],
        batch_size: int,
        is_training: bool = True,
        multival_sep: str = chr(29),
    ) -> None:
        super(NegativeSampler, self).__init__(
            config, fields, batch_size, is_training, multival_sep
        )
        self._g = gl.Graph().node(
            config.input_path,
            node_type="item",
            decoder=gl.Decoder(
                attr_types=self._attr_gl_types,
                weighted=True,
                attr_delimiter=config.attr_delimiter,
            ),
        )
        self._item_id_field = config.item_id_field
        self._sampler = None
        self.item_id_delim = config.item_id_delim

    def init(self, client_id: int = -1) -> None:
        """Init sampler client and samplers."""
        super().init(client_id)
        # samples per input id so that batch_size ids yield >= num_sample.
        expand_factor = int(math.ceil(self._num_sample / self._batch_size))
        self._sampler = self._g.negative_sampler(
            "item", expand_factor, strategy="node_weight"
        )

    def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
        """Sampling method.

        Args:
            input_data (dict): input data with item_id.

        Returns:
            Negative sampled feature dict.
        """
        ids = _pa_ids_to_npy(input_data[self._item_id_field])
        # pad to a full batch by repeating the last id.
        ids = np.pad(ids, (0, self._batch_size - len(ids)), "edge")
        nodes = self._sampler.get(ids)
        features = self._parse_nodes(nodes)
        result_dict = dict(zip(self._valid_attr_names, features))
        return result_dict

    @property
    def estimated_sample_num(self) -> int:
        """Estimated number of sampled num examples."""
        return self._num_sample


class NegativeSamplerV2(BaseSampler):
    """Negative Sampler V2.

    Weighted random sampling items which do not have positive edge with the user.

    Args:
        config (NegativeSamplerV2): negative sampler config.
        fields (list): item input fields.
        batch_size (int): mini-batch size.
        is_training (bool): train or eval.
        multival_sep (str): multi value separator.
    """

    def __init__(
        self,
        config: sampler_pb2.NegativeSamplerV2,
        fields: List[pa.Field],
        batch_size: int,
        is_training: bool = True,
        multival_sep: str = chr(29),
    ) -> None:
        super(NegativeSamplerV2, self).__init__(
            config, fields, batch_size, is_training, multival_sep
        )
        self._g = (
            gl.Graph()
            .node(
                config.user_input_path,
                node_type="user",
                decoder=gl.Decoder(weighted=True),
            )
            .node(
                config.item_input_path,
                node_type="item",
                decoder=gl.Decoder(
                    attr_types=self._attr_gl_types,
                    weighted=True,
                    attr_delimiter=config.attr_delimiter,
                ),
            )
            .edge(
                config.pos_edge_input_path,
                edge_type=("user", "item", "edge"),
                decoder=gl.Decoder(weighted=True),
            )
        )
        self._item_id_field = config.item_id_field
        self._user_id_field = config.user_id_field
        self._sampler = None

    def init(self, client_id: int = -1) -> None:
        """Init sampler client and samplers."""
        super().init(client_id)
        expand_factor = int(math.ceil(self._num_sample / self._batch_size))
        self._sampler = self._g.negative_sampler(
            "edge", expand_factor, strategy="random", conditional=True
        )

        # prevent gl timeout: sleep a random number of seconds, then issue
        # one warm-up request.
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info else 1
        local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
        time.sleep(random.randint(0, num_workers * local_world_size))
        self.get(
            {self._user_id_field: pa.array([0]), self._item_id_field: pa.array([0])}
        )

    def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
        """Sampling method.

        Args:
            input_data (dict): input data with user_id and item_id.

        Returns:
            Negative sampled feature dict.
        """
        src_ids = _pa_ids_to_npy(input_data[self._user_id_field])
        dst_ids = _pa_ids_to_npy(input_data[self._item_id_field])
        # pad both id arrays to a full batch by repeating the last id.
        src_ids = np.pad(src_ids, (0, self._batch_size - len(src_ids)), "edge")
        dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), "edge")
        nodes = self._sampler.get(src_ids, dst_ids)
        features = self._parse_nodes(nodes)
        result_dict = dict(zip(self._valid_attr_names, features))
        return result_dict

    @property
    def estimated_sample_num(self) -> int:
        """Estimated number of sampled num examples."""
        return self._num_sample


class HardNegativeSampler(BaseSampler):
    """HardNegativeSampler.

    Weighted random sampling items not in batch as negative samples, and sampling
    destination nodes in hard_neg_edge as hard negative samples

    Args:
        config (HardNegativeSampler): negative sampler config.
        fields (list): item input fields.
        batch_size (int): mini-batch size.
        is_training (bool): train or eval.
        multival_sep (str): multi value separator.
    """

    def __init__(
        self,
        config: sampler_pb2.HardNegativeSampler,
        fields: List[pa.Field],
        batch_size: int,
        is_training: bool = True,
        multival_sep: str = chr(29),
    ) -> None:
        super(HardNegativeSampler, self).__init__(
            config, fields, batch_size, is_training, multival_sep
        )
        self._num_hard_sample = config.num_hard_sample
        self._g = (
            gl.Graph()
            .node(
                config.user_input_path,
                node_type="user",
                decoder=gl.Decoder(weighted=True),
            )
            .node(
                config.item_input_path,
                node_type="item",
                decoder=gl.Decoder(
                    attr_types=self._attr_gl_types,
                    weighted=True,
                    attr_delimiter=config.attr_delimiter,
                ),
            )
            .edge(
                config.hard_neg_edge_input_path,
                edge_type=("user", "item", "hard_neg_edge"),
                decoder=gl.Decoder(weighted=True),
            )
        )
        self._item_id_field = config.item_id_field
        self._user_id_field = config.user_id_field
        self._neg_sampler = None
        self._hard_neg_sampler = None

    def init(self, client_id: int = -1) -> None:
        """Init sampler client and samplers."""
        super().init(client_id)
        expand_factor = int(math.ceil(self._num_sample / self._batch_size))
        self._neg_sampler = self._g.negative_sampler(
            "item", expand_factor, strategy="node_weight"
        )
        self._hard_neg_sampler = self._g.neighbor_sampler(
            ["hard_neg_edge"], self._num_hard_sample, strategy="full"
        )

    def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
        """Sampling method.

        Args:
            input_data (dict): input data with user_id and item_id.

        Returns:
            Negative sampled feature dict. For each feature the negative
                samples come first, followed by the hard negative samples;
                "hard_neg_indices" holds the indices of the hard negatives.
        """
        src_ids = _pa_ids_to_npy(input_data[self._user_id_field])
        dst_ids = _pa_ids_to_npy(input_data[self._item_id_field])
        dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), "edge")
        nodes = self._neg_sampler.get(dst_ids)
        neg_features = self._parse_nodes(nodes)
        # hard negatives are sampled from the unpadded user ids.
        sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1)
        hard_neg_features, hard_neg_indices = self._parse_sparse_nodes(sparse_nodes)

        results = [
            pa.concat_arrays([neg_features[i], hard_fea])
            for i, hard_fea in enumerate(hard_neg_features)
        ]
        result_dict = dict(zip(self._valid_attr_names, results))
        result_dict["hard_neg_indices"] = pa.array(hard_neg_indices)
        return result_dict

    @property
    def estimated_sample_num(self) -> int:
        """Estimated number of sampled num examples."""
        return self._num_sample + min(self._num_hard_sample, 8) * self._batch_size


class HardNegativeSamplerV2(BaseSampler):
    """HardNegativeSamplerV2.

    Weighted random sampling items which do not have positive edge with the user,
    and sampling destination nodes in hard_neg_edge as hard negative samples.

    Args:
        config (HardNegativeSamplerV2): negative sampler config.
        fields (list): item input fields.
        batch_size (int): mini-batch size.
        is_training (bool): train or eval.
        multival_sep (str): multi value separator.
    """

    def __init__(
        self,
        config: sampler_pb2.HardNegativeSamplerV2,
        fields: List[pa.Field],
        batch_size: int,
        is_training: bool = True,
        multival_sep: str = chr(29),
    ) -> None:
        super(HardNegativeSamplerV2, self).__init__(
            config, fields, batch_size, is_training, multival_sep
        )
        self._num_hard_sample = config.num_hard_sample
        self._g = (
            gl.Graph()
            .node(
                config.user_input_path,
                node_type="user",
                decoder=gl.Decoder(weighted=True),
            )
            .node(
                config.item_input_path,
                node_type="item",
                decoder=gl.Decoder(
                    attr_types=self._attr_gl_types,
                    weighted=True,
                    attr_delimiter=config.attr_delimiter,
                ),
            )
            .edge(
                config.pos_edge_input_path,
                edge_type=("user", "item", "edge"),
                decoder=gl.Decoder(weighted=True),
            )
            .edge(
                config.hard_neg_edge_input_path,
                edge_type=("user", "item", "hard_neg_edge"),
                decoder=gl.Decoder(weighted=True),
            )
        )
        self._item_id_field = config.item_id_field
        self._user_id_field = config.user_id_field
        self._neg_sampler = None
        self._hard_neg_sampler = None

    def init(self, client_id: int = -1) -> None:
        """Init sampler client and samplers."""
        super().init(client_id)
        expand_factor = int(math.ceil(self._num_sample / self._batch_size))
        self._neg_sampler = self._g.negative_sampler(
            "edge", expand_factor, strategy="random", conditional=True
        )
        self._hard_neg_sampler = self._g.neighbor_sampler(
            ["hard_neg_edge"], self._num_hard_sample, strategy="full"
        )

    def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
        """Sampling method.

        Args:
            input_data (dict): input data with user_id and item_id.

        Returns:
            Negative sampled feature dict. For each feature the negative
                samples come first, followed by the hard negative samples;
                "hard_neg_indices" holds the indices of the hard negatives.
        """
        src_ids = _pa_ids_to_npy(input_data[self._user_id_field])
        dst_ids = _pa_ids_to_npy(input_data[self._item_id_field])
        # the negative sampler gets padded ids, the hard negative sampler the
        # original (unpadded) user ids.
        padded_src_ids = np.pad(src_ids, (0, self._batch_size - len(src_ids)), "edge")
        dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), "edge")
        nodes = self._neg_sampler.get(padded_src_ids, dst_ids)
        neg_features = self._parse_nodes(nodes)
        sparse_nodes = self._hard_neg_sampler.get(src_ids).layer_nodes(1)
        hard_neg_features, hard_neg_indices = self._parse_sparse_nodes(sparse_nodes)

        results = [
            pa.concat_arrays([neg_features[i], hard_fea])
            for i, hard_fea in enumerate(hard_neg_features)
        ]
        result_dict = dict(zip(self._valid_attr_names, results))
        result_dict["hard_neg_indices"] = pa.array(hard_neg_indices)
        return result_dict

    @property
    def estimated_sample_num(self) -> int:
        """Estimated number of sampled num examples."""
        return self._num_sample + min(self._num_hard_sample, 8) * self._batch_size


class TDMSampler(BaseSampler):
    """TDM training sampler.

    According to the leaf nodes corresponding to the items clicked by the user,
    sample all ancestor nodes as positive samples, and then sample negative
    samples layer by layer.

    Args:
        config (TDMSampler): TDM sampler config.
        fields (list): item input fields.
        batch_size (int): mini-batch size.
        is_training (bool): train or eval.
        multival_sep (str): multi value separator.
    """

    def __init__(
        self,
        config: sampler_pb2.TDMSampler,
        fields: List[pa.Field],
        batch_size: int,
        is_training: bool = True,
        multival_sep: str = chr(29),
    ) -> None:
        # tree_level is prepended to the input fields; the first parsed
        # feature is dropped again in `get` ([1:]).
        fields = [pa.field("tree_level", pa.int64())] + fields
        super().__init__(config, fields, batch_size, is_training, multival_sep)
        self._g = (
            gl.Graph()
            .node(
                config.item_input_path,
                node_type="item",
                decoder=gl.Decoder(
                    attr_types=self._attr_gl_types,
                    weighted=True,
                    attr_delimiter=config.attr_delimiter,
                ),
            )
            .edge(
                config.edge_input_path,
                edge_type=("item", "item", "ancestor"),
                decoder=gl.Decoder(weighted=True),
            )
        )
        self._item_id_field = config.item_id_field
        self._max_level = len(config.layer_num_sample)
        self._layer_num_sample = config.layer_num_sample
        assert self._layer_num_sample[0] == 0, "sample num of tree root must be 0"
        self._last_layer_num_sample = config.layer_num_sample[-1]
        self._pos_sampler = None
        self._neg_sampler_list = []

        self._remain_ratio = config.remain_ratio
        # probability of keeping each non-leaf layer; only used (and only
        # computed) when remain_ratio < 1.0.
        self._remain_p = None
        if self._remain_ratio < 1.0:
            if config.probability_type == "UNIFORM":
                p = np.array([1 / (self._max_level - 2)] * (self._max_level - 2))
            elif config.probability_type == "ARITHMETIC":
                p = np.arange(1, self._max_level - 1) / sum(
                    np.arange(1, self._max_level - 1)
                )
            elif config.probability_type == "RECIPROCAL":
                p = 1 / np.arange(self._max_level - 2, 0, -1)
                p = p / sum(p)
            else:
                raise ValueError(
                    f"probability_type: [{config.probability_type}] "
                    "is not supported now."
                )
            self._remain_p = p

    def init(self, client_id: int = -1) -> None:
        """Init sampler client and samplers."""
        super().init(client_id)
        self._pos_sampler = self._g.neighbor_sampler(
            meta_path=["ancestor"],
            expand_factor=self._max_level - 2,
            strategy="random_without_replacement",
        )

        # TODO: only use one conditional sampler
        # one negative sampler per layer 1..max_level-1; rebuilt from scratch
        # so that calling init again does not duplicate samplers.
        self._neg_sampler_list = [
            self._g.negative_sampler(
                "item",
                expand_factor=self._layer_num_sample[i],
                strategy="node_weight",
                conditional=True,
                int_cols=[0],
                int_props=[1],
                samplewise_unique=True,
            )
            for i in range(1, self._max_level)
        ]

        # prevent gl timeout: sleep a random number of seconds, then issue
        # one warm-up request.
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info else 1
        local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
        time.sleep(random.randint(0, num_workers * local_world_size))
        if use_hash_node_id():
            # pa.array requires a pyarrow DataType for `type`.
            self.get({self._item_id_field: pa.array(["0"], type=pa.string())})
        else:
            self.get({self._item_id_field: pa.array([0])})

    def get(
        self, input_data: Dict[str, pa.Array]
    ) -> Tuple[Dict[str, pa.Array], Dict[str, pa.Array]]:
        """Sampling method.

        Args:
            input_data (dict): input data with item_id.

        Returns:
            Positive and negative sampled feature dict.
        """
        ids = _pa_ids_to_npy(input_data[self._item_id_field]).reshape(-1, 1)
        batch_size = len(ids)
        num_fea = len(self._valid_attr_names[1:])

        # positive node.
        pos_nodes = self._pos_sampler.get(ids).layer_nodes(1)

        # the ids of non-leaf nodes is arranged in ascending order.
        pos_non_leaf_ids = np.sort(pos_nodes.ids, axis=1)
        pos_ids = np.concatenate((pos_non_leaf_ids, ids), axis=1)
        pos_fea_result = self._parse_nodes(pos_nodes)[1:]

        # randomly select layers to keep
        if self._remain_ratio < 1.0:
            remain_layer = np.random.choice(
                range(1, self._max_level - 1),
                int(round(self._remain_ratio * (self._max_level - 2))),
                replace=False,
                p=self._remain_p,
            )
        else:
            remain_layer = np.array(range(1, self._max_level - 1))
        remain_layer.sort()

        if self._remain_ratio < 1.0:
            # positive features are laid out as (max_level - 2) consecutive
            # entries per input id; keep only the remaining layers.
            pos_fea_index = np.concatenate(
                [
                    remain_layer - 1 + j * (self._max_level - 2)
                    for j in range(batch_size)
                ]
            )
            pos_fea_result = [
                pos_fea_result[i].take(pos_fea_index) for i in range(num_fea)
            ]

        # remaining non-leaf layers plus the last layer, which is always sampled.
        sampled_layers = np.append(remain_layer, self._max_level - 1)

        # negative sample layer by layer.
        neg_fea_layer = []
        for i in sampled_layers:
            neg_nodes = self._neg_sampler_list[i - 1].get(
                pos_ids[:, i - 1], pos_ids[:, i - 1]
            )
            features = self._parse_nodes(neg_nodes)[1:]
            neg_fea_layer.append(features)

        # concatenate the features of each layer and
        # ensure that the negative sample features of the same user are adjacent.
        cum_layer_num = np.cumsum(
            [0]
            + [
                self._layer_num_sample[i] if i in remain_layer else 0
                for i in range(self._max_level - 1)
            ]
        )
        neg_fea_index = np.concatenate(
            [
                np.concatenate(
                    [
                        np.arange(self._layer_num_sample[i])
                        + j * self._layer_num_sample[i]
                        + batch_size * cum_layer_num[i]
                        for i in sampled_layers
                    ]
                )
                for j in range(batch_size)
            ]
        )
        neg_fea_result = [
            pa.concat_arrays([array[i] for array in neg_fea_layer]).take(neg_fea_index)
            for i in range(num_fea)
        ]

        pos_result_dict = dict(zip(self._valid_attr_names[1:], pos_fea_result))
        neg_result_dict = dict(zip(self._valid_attr_names[1:], neg_fea_result))

        return pos_result_dict, neg_result_dict

    @property
    def estimated_sample_num(self) -> int:
        """Estimated number of sampled num examples."""
        return (
            sum(self._layer_num_sample) + len(self._layer_num_sample) - 2
        ) * self._batch_size


class TDMPredictSampler(BaseSampler):
    """TDM predict sampler.

    Args:
        config (TDMSampler): TDM sampler config.
        fields (list): item input fields.
        batch_size (int): mini-batch size.
        is_training (bool): train or eval.
        multival_sep (str): multi value separator.
    """

    def __init__(
        self,
        config: sampler_pb2.TDMSampler,
        fields: List[pa.Field],
        batch_size: int,
        is_training: bool = True,
        multival_sep: str = chr(29),
    ) -> None:
        # tree_level is prepended to the input fields; the first parsed
        # feature is dropped again in `get` ([1:]).
        fields = [pa.field("tree_level", pa.int64())] + fields
        super().__init__(config, fields, batch_size, is_training, multival_sep)
        self._g = (
            gl.Graph()
            .node(
                config.item_input_path,
                node_type="item",
                decoder=gl.Decoder(
                    attr_types=self._attr_gl_types,
                    weighted=True,
                    attr_delimiter=config.attr_delimiter,
                ),
            )
            .edge(
                config.predict_edge_input_path,
                edge_type=("item", "item", "children"),
                decoder=gl.Decoder(weighted=True),
            )
        )
        self._item_id_field = config.item_id_field
        self._max_level = len(config.layer_num_sample)
        self._pos_sampler = None

    def init_sampler(self, expand_factor: int) -> None:
        """Init samplers with different expand_factor.

        During prediction, the first sampling selects all nodes from the
        first layer larger than the recall number, starting from the root node.
        Then, for each node, all its child nodes are sampled. The expand_factor
        is different in the two rounds of sampling.

        Args:
            expand_factor (int): number of children sampled per node.
        """
        self._pos_sampler = self._g.neighbor_sampler(
            meta_path=["children"],
            expand_factor=expand_factor,
            strategy="random_without_replacement",
        )

    def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
        """Sampling method.

        Args:
            input_data (dict): input data with item_id.

        Returns:
            Feature dict of the sampled child nodes.
        """
        ids = _pa_ids_to_npy(input_data[self._item_id_field]).reshape(-1, 1)

        pos_nodes = self._pos_sampler.get(ids).layer_nodes(1)
        pos_fea_result = self._parse_nodes(pos_nodes)[1:]
        pos_result_dict = dict(zip(self._valid_attr_names[1:], pos_fea_result))

        return pos_result_dict

    @property
    def estimated_sample_num(self) -> int:
        """Estimated number of sampled num examples."""
        return min((2 ** (self._max_level - 1)), 800) * self._batch_size