def save_predict_edge()

in tzrec/tools/tdm/gen_tree/tree_search_util.py [0:0]


    def _collect_predict_edges(self, root_src_id):
        """Collect the (src, dst, weight) edge columns of the tree.

        The first edge goes from ``root_src_id`` to the root. It is a synthetic
        edge that lets graph-learn find the root node. It is followed by every
        parent->child edge of the nodes in levels ``0 .. self.max_level - 1``
        (``self.level_code[level]``). All edges have weight ``1.0``.

        Args:
            root_src_id: id used as the source of the synthetic root edge
                (``"-1"`` when hashed string node ids are used, else ``-1``).

        Returns:
            A tuple ``(src_ids, dst_ids, weights)`` of equal-length lists.
        """
        src_ids = [root_src_id]
        dst_ids = [self.root.item_id]
        for level in range(self.max_level):
            for node in self.level_code[level]:
                for child in node.children:
                    src_ids.append(node.item_id)
                    dst_ids.append(child.item_id)
        weights = [1.0] * len(src_ids)
        return src_ids, dst_ids, weights

    def save_predict_edge(self) -> None:
        """Save the tree's parent->child edges for prediction.

        If ``self.output_file`` starts with ``odps://``, the edges are written
        with ``create_writer(..., **self.dataset_kwargs)``. The target is the
        table named by ``self.output_file`` with the ``_predict_edge_table``
        suffix added. The columns are ``src_id``, ``dst_id`` and ``weight``.

        Otherwise ``self.output_file`` is used as a directory. A tab-separated
        ``predict_edge_table.txt`` is written into it, with a typed header line
        followed by one edge per line.
        """
        use_hash = use_hash_node_id()
        # Build all edge data up front so that a failure during traversal
        # cannot leave an output writer/file half-open.
        src_ids, dst_ids, weights = self._collect_predict_edges(
            "-1" if use_hash else -1
        )
        if self.output_file.startswith("odps://"):
            output_path = _add_suffix_to_odps_table(
                self.output_file, "_predict_edge_table"
            )
            edge_table_dict = OrderedDict()
            edge_table_dict["src_id"] = pa.array(src_ids)
            edge_table_dict["dst_id"] = pa.array(dst_ids)
            edge_table_dict["weight"] = pa.array(weights)
            writer = create_writer(output_path, **self.dataset_kwargs)
            writer.write(edge_table_dict)
            writer.close()
        else:
            id_type = "string" if use_hash else "int64"
            edge_file = os.path.join(self.output_file, "predict_edge_table.txt")
            # Explicit encoding: hashed node ids are strings, so do not depend
            # on the platform's locale encoding.
            with open(edge_file, "w", encoding="utf-8") as f:
                f.write(f"src_id:{id_type}\tdst_id:{id_type}\tweight:float\n")
                f.writelines(
                    f"{src}\t{dst}\t{w}\n"
                    for src, dst, w in zip(src_ids, dst_ids, weights)
                )