in tzrec/tools/tdm/gen_tree/tree_generator.py [0:0]
def _read(self) -> List[TDMTreeNode]:
    """Read every row of the item input table into a leaf ``TDMTreeNode``.

    Only the item id, category id, attribute and raw-attribute columns are
    requested from the reader. Each row becomes one node, in reader order.

    Per batch:
      * item ids are cast to ``pa.string()`` when ``use_hash_node_id()`` is
        true, otherwise to ``pa.int64()``, then converted to Python values;
      * category ids are cast to strings, with nulls replaced by ``""``;
      * for each field in ``self.attr_fields`` (then ``self.raw_attr_fields``)
        the column element at the row's position is appended, unconverted,
        to the node's ``attrs`` (resp. ``raw_attrs``) list.

    Returns:
        List of leaf nodes, one per input row.
    """
    wanted_cols = {self.item_id_field, self.cate_id_field}
    wanted_cols |= set(self.attr_fields)
    wanted_cols |= set(self.raw_attr_fields)
    # NOTE(review): 4096 is presumably the reader batch size — confirm
    # against create_reader's signature.
    reader = create_reader(
        self.item_input_path,
        4096,
        selected_cols=list(wanted_cols),
        **self.dataset_kwargs,
    )

    nodes = []
    for batch in reader.to_batches():
        # Hashed node ids are kept as strings; otherwise ids are int64.
        id_type = pa.string() if use_hash_node_id() else pa.int64()
        item_ids = batch[self.item_id_field].cast(id_type).to_pylist()
        cate_column = batch[self.cate_id_field].cast(pa.string()).fill_null("")
        batch_nodes = [
            TDMTreeNode(item_id=item_id, cate=cate)
            for item_id, cate in zip(item_ids, cate_column.to_pylist())
        ]

        # Attribute values are stored exactly as indexed from the column
        # (no to_pylist conversion), attrs first, then raw attrs.
        for field in self.attr_fields:
            column = batch[field]
            for row, node in enumerate(batch_nodes):
                node.attrs.append(column[row])
        for field in self.raw_attr_fields:
            column = batch[field]
            for row, node in enumerate(batch_nodes):
                node.raw_attrs.append(column[row])

        nodes.extend(batch_nodes)
    return nodes