def save()

in tzrec/tools/tdm/gen_tree/tree_search_util.py [0:0]


    def save(self, attr_delimiter: str = ",") -> None:
        """Save tree info as a node table and an edge table.

        If ``self.output_file`` starts with ``odps://``, both tables are written
        through ``create_writer`` to the paths returned by
        ``_add_suffix_to_odps_table`` for the ``_node_table`` and
        ``_edge_table`` suffixes. Otherwise ``self.output_file`` is used as a
        local directory (created if missing) that receives ``node_table.txt``
        and ``edge_table.txt``, each starting with a typed header line.

        Every node row holds the node id, a weight of 1.0 and a comma-joined
        feature string ``level,item_id[,attrs][,raw_attrs]``. One extra row
        with id -1 is emitted before the first node; it reuses that node's
        features with the level replaced by -1.

        Args:
            attr_delimiter: separator placed between the individual values of
                ``node.attrs`` and between those of ``node.raw_attrs``.
        """

        def _join_attrs(attrs, missing: str) -> str:
            # Attribute values must expose ``is_valid``
            # (presumably pyarrow scalars -- confirm against the tree builder).
            # ``missing`` has to be a str: str.join rejects non-string items.
            return attr_delimiter.join(
                str(attr) if attr.is_valid else missing for attr in attrs
            )

        def _node_rows(root_id):
            """Yield (node_id, feature_string) pairs, root placeholder first."""
            root_pending = True
            for level, nodes in enumerate(self.level_code):
                for node in nodes:
                    fields = [str(node.item_id)]
                    if node.attrs:
                        fields.append(_join_attrs(node.attrs, ""))
                    if node.raw_attrs:
                        # invalid raw attrs fall back to "0" (was the int 0,
                        # which made str.join raise TypeError)
                        fields.append(_join_attrs(node.raw_attrs, "0"))

                    # add a node with id -1 for graph-learn to get root node
                    if root_pending:
                        yield root_id, ",".join(["-1", *fields])
                        root_pending = False

                    yield node.item_id, ",".join([str(level), *fields])

        def _edges():
            """Yield (src, dst) pairs linking travel[0] to later entries."""
            for travel in self.travel_list:
                # do not include edge from leaf to root
                for depth in range(1, self.max_level):
                    yield travel[0], travel[depth]

        hash_ids = use_hash_node_id()

        if self.output_file.startswith("odps://"):
            ids = []
            features = []
            # the placeholder id follows the id type used by the other nodes
            for node_id, feature in _node_rows("-1" if hash_ids else -1):
                ids.append(node_id)
                features.append(feature)

            output_path = _add_suffix_to_odps_table(self.output_file, "_node_table")
            node_writer = create_writer(output_path, **self.dataset_kwargs)
            node_table_dict = OrderedDict()
            node_table_dict["id"] = pa.array(ids)
            node_table_dict["weight"] = pa.array([1.0] * len(ids))
            node_table_dict["features"] = pa.array(features)
            node_writer.write(node_table_dict)
            node_writer.close()

            src_ids = []
            dst_ids = []
            for src_id, dst_id in _edges():
                src_ids.append(src_id)
                dst_ids.append(dst_id)

            output_path = _add_suffix_to_odps_table(self.output_file, "_edge_table")
            edge_writer = create_writer(output_path, **self.dataset_kwargs)
            edge_table_dict = OrderedDict()
            edge_table_dict["src_id"] = pa.array(src_ids)
            edge_table_dict["dst_id"] = pa.array(dst_ids)
            edge_table_dict["weight"] = pa.array([1.0] * len(src_ids))
            edge_writer.write(edge_table_dict)
            edge_writer.close()

        else:
            # exist_ok avoids the check-then-create race of exists()+makedirs()
            os.makedirs(self.output_file, exist_ok=True)
            id_type = "string" if hash_ids else "int64"

            with open(os.path.join(self.output_file, "node_table.txt"), "w") as f:
                f.write(f"id:{id_type}\tweight:float\tfeature:string\n")
                for node_id, feature in _node_rows(-1):
                    f.write(f"{node_id}\t1.0\t{feature}\n")

            with open(os.path.join(self.output_file, "edge_table.txt"), "w") as f:
                f.write(f"src_id:{id_type}\tdst_id:{id_type}\tweight:float\n")
                for src_id, dst_id in _edges():
                    f.write(f"{src_id}\t{dst_id}\t1.0\n")