in tzrec/datasets/sampler.py [0:0]
def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
    """Sample positive and negative node features for a batch of items.

    The positive sampler is queried with the item ids and its layer-1
    nodes are taken as the non-leaf positives; their ids are sorted in
    ascending order per row and the item id itself is appended as the last
    column. When ``self._remain_ratio < 1.0`` only a random subset of the
    intermediate layers is kept. Negatives are then drawn layer by layer
    for every kept layer plus the last layer (``self._max_level - 1``).

    Args:
        input_data (dict): input data with item_id, i.e. it must contain
            the key ``self._item_id_field``.

    Returns:
        A ``(pos_result_dict, neg_result_dict)`` tuple. Both dicts are
        keyed by ``self._valid_attr_names[1:]``; the first holds the
        positive sampled features and the second the negative sampled
        features, reordered so that all negatives belonging to the same
        input row are adjacent.
        NOTE(review): the return annotation still says ``Dict`` although a
        2-tuple of dicts is returned; it is left unchanged here to avoid
        depending on a ``Tuple`` import.
    """
    ids = _pa_ids_to_npy(input_data[self._item_id_field]).reshape(-1, 1)
    batch_size = len(ids)
    attr_names = self._valid_attr_names[1:]
    num_fea = len(attr_names)
    # column vector of row numbers, used to broadcast per-row offsets below.
    row_ids = np.arange(batch_size).reshape(-1, 1)

    # positive node.
    pos_nodes = self._pos_sampler.get(ids).layer_nodes(1)

    # the ids of non-leaf nodes is arranged in ascending order.
    pos_non_leaf_ids = np.sort(pos_nodes.ids, axis=1)
    pos_ids = np.concatenate((pos_non_leaf_ids, ids), axis=1)
    pos_fea_result = self._parse_nodes(pos_nodes)[1:]

    # randomly select layers to keep
    drop_layers = self._remain_ratio < 1.0
    if drop_layers:
        remain_layer = np.random.choice(
            range(1, self._max_level - 1),
            int(round(self._remain_ratio * (self._max_level - 2))),
            replace=False,
            p=self._remain_p,
        )
    else:
        # np.arange keeps an integer dtype even when the range is empty,
        # which matters because the values are used as list indices below.
        remain_layer = np.arange(1, self._max_level - 1)
    remain_layer.sort()

    if drop_layers:
        # each row owns (max_level - 2) consecutive positive entries; keep
        # the ones of the remaining layers, row by row.
        pos_fea_index = (
            row_ids * (self._max_level - 2) + (remain_layer - 1)
        ).reshape(-1)
        pos_fea_result = [
            pos_fea_result[i].take(pos_fea_index) for i in range(num_fea)
        ]

    # layers that get negatives: the kept layers plus the last layer.
    sample_layers = np.append(remain_layer, self._max_level - 1)

    # negative sample layer by layer.
    neg_fea_layer = []
    for i in sample_layers:
        neg_nodes = self._neg_sampler_list[i - 1].get(
            pos_ids[:, i - 1], pos_ids[:, i - 1]
        )
        neg_fea_layer.append(self._parse_nodes(neg_nodes)[1:])

    # concatenate the features of each layer and
    # ensure that the negative sample features of the same user are adjacent.
    # cum_layer_num[i] = negatives per row contributed by kept layers below i.
    cum_layer_num = np.cumsum(
        [0]
        + [
            self._layer_num_sample[i] if i in remain_layer else 0
            for i in range(self._max_level - 1)
        ]
    )
    # For layer i the block of row j starts at
    #   batch_size * cum_layer_num[i] + j * layer_num_sample[i]
    # inside the layer-concatenated array. Broadcasting yields one
    # (batch_size, layer_num_sample[i]) matrix per layer; joining them along
    # axis 1 and flattening row-major gives row-major, layer-minor order.
    neg_fea_index = np.concatenate(
        [
            np.arange(self._layer_num_sample[i])
            + row_ids * self._layer_num_sample[i]
            + batch_size * cum_layer_num[i]
            for i in sample_layers
        ],
        axis=1,
    ).reshape(-1)
    neg_fea_result = [
        pa.concat_arrays([layer_fea[i] for layer_fea in neg_fea_layer]).take(
            neg_fea_index
        )
        for i in range(num_fea)
    ]

    pos_result_dict = dict(zip(attr_names, pos_fea_result))
    neg_result_dict = dict(zip(attr_names, neg_fea_result))
    return pos_result_dict, neg_result_dict