def _read()

in tzrec/tools/tdm/gen_tree/tree_cluster.py [0:0]


    def _read(self) -> None:
        """Read item ids, embeddings and attributes from the item input.

        Populates two attributes on ``self``:

        * ``self.leaf_nodes``: one ``TDMTreeNode`` per row read, with the
          values of ``self.attr_fields`` appended to ``attrs`` and the values
          of ``self.raw_attr_fields`` appended to ``raw_attrs``.
        * ``self.data``: ``np.array`` built from the values of
          ``self.embedding_field``, in the same row order as ``leaf_nodes``.

        Embeddings stored as strings are parsed as Python literals before the
        array is built.
        """
        # Local import so the block stays self-contained; literal_eval only
        # accepts Python literals and never executes code, unlike eval().
        import ast

        t1 = time.time()
        data = []

        self.leaf_nodes = []

        selected_cols = (
            {self.item_id_field, self.embedding_field}
            | set(self.attr_fields)
            | set(self.raw_attr_fields)
        )
        reader = create_reader(
            self.item_input_path,
            4096,  # NOTE(review): presumably the reader batch size - confirm
            selected_cols=list(selected_cols),
            **self.dataset_kwargs,
        )

        for data_dict in reader.to_batches():
            # Ids are cast to string when hashed node ids are in use,
            # otherwise to int64.
            id_type = pa.string() if use_hash_node_id() else pa.int64()
            ids = data_dict[self.item_id_field].cast(id_type).to_pylist()
            data += data_dict[self.embedding_field].to_pylist()

            batch_tree_nodes = [TDMTreeNode(item_id=one_id) for one_id in ids]

            # Attribute values are appended as yielded by the column (not
            # converted via to_pylist), matching the per-index access used
            # previously.
            for attr in self.attr_fields:
                for node, value in zip(batch_tree_nodes, data_dict[attr]):
                    node.attrs.append(value)

            for attr in self.raw_attr_fields:
                for node, value in zip(batch_tree_nodes, data_dict[attr]):
                    node.raw_attrs.append(value)

            self.leaf_nodes.extend(batch_tree_nodes)

        # Guard against an empty input: data[0] would raise IndexError.
        if data and isinstance(data[0], str):
            # Embeddings serialized as text (e.g. "[0.1, 0.2]") are parsed
            # with literal_eval rather than eval, so content in the input
            # table can never be executed as code.
            data = [ast.literal_eval(i) for i in data]
        self.data = np.array(data)
        t2 = time.time()

        logger.info(
            "Read data done, {} records read, elapsed: {}".format(
                len(self.leaf_nodes), t2 - t1
            )
        )