in tzrec/tools/tdm/gen_tree/tree_cluster.py [0:0]
def _read(self) -> None:
    """Load item ids, embeddings and attributes from ``self.item_input_path``.

    The table is read via ``create_reader``, selecting only the item id,
    embedding, ``attr_fields`` and ``raw_attr_fields`` columns (the ``4096``
    argument is presumably the batch size -- confirm against
    ``create_reader``). For every row a ``TDMTreeNode`` leaf is created; its
    ``attrs`` / ``raw_attrs`` lists receive that row's values of
    ``attr_fields`` / ``raw_attr_fields`` in field order (presumably pyarrow
    scalars, as yielded by the batch columns -- verify).

    Sets:
        self.leaf_nodes: one ``TDMTreeNode`` per input row, in read order.
        self.data: ``np.ndarray`` built from the embedding column, row-aligned
            with ``self.leaf_nodes``. String-encoded embeddings are parsed
            as Python literals first.
    """
    import ast

    t1 = time.time()
    data = []
    self.leaf_nodes = []
    selected_cols = (
        {self.item_id_field, self.embedding_field}
        | set(self.attr_fields)
        | set(self.raw_attr_fields)
    )
    reader = create_reader(
        self.item_input_path,
        4096,
        selected_cols=list(selected_cols),
        **self.dataset_kwargs,
    )
    # Item ids are cast to string when hashed node ids are in use, otherwise
    # to int64; the choice does not depend on the batch, so make it once.
    id_type = pa.string() if use_hash_node_id() else pa.int64()
    for data_dict in reader.to_batches():
        ids = data_dict[self.item_id_field].cast(id_type).to_pylist()
        data += data_dict[self.embedding_field].to_pylist()
        batch_tree_nodes = [TDMTreeNode(item_id=one_id) for one_id in ids]
        # Columns of one batch have the same length, so zip pairs each node
        # with its own row value.
        for attr in self.attr_fields:
            for node, value in zip(batch_tree_nodes, data_dict[attr]):
                node.attrs.append(value)
        for attr in self.raw_attr_fields:
            for node, value in zip(batch_tree_nodes, data_dict[attr]):
                node.raw_attrs.append(value)
        self.leaf_nodes.extend(batch_tree_nodes)
    # Embeddings may be stored as their string representation, e.g.
    # "[0.1, 0.2]". They come from an external input table, so parse them
    # with ast.literal_eval (literals only) rather than eval, which would
    # execute arbitrary expressions. The `data` check avoids an IndexError
    # when the input table is empty.
    if data and isinstance(data[0], str):
        data = [ast.literal_eval(i) for i in data]
    self.data = np.array(data)
    t2 = time.time()
    logger.info(
        "Read data done, {} records read, elapsed: {}".format(
            len(self.leaf_nodes), t2 - t1
        )
    )