in tzrec/tools/tdm/gen_tree/tree_search_util.py [0:0]
def save_predict_edge(self) -> None:
    """Write the tree's parent-to-child edges used for prediction.

    Every edge carries weight 1.0. Before the tree edges, one extra edge
    from id -1 to the root is emitted so that graph-learn can get the
    root node.

    Destination depends on ``self.output_file``:

    * starts with ``odps://``: the edges are written as three columns
      (``src_id``, ``dst_id``, ``weight``) through the writer created for
      the path returned by
      ``_add_suffix_to_odps_table(self.output_file, "_predict_edge_table")``.
    * otherwise: the edges are written as tab-separated text, with a typed
      header line, to ``predict_edge_table.txt`` under ``self.output_file``.
    """
    edge_weight = 1.0

    if not self.output_file.startswith("odps://"):
        txt_path = os.path.join(self.output_file, "predict_edge_table.txt")
        with open(txt_path, "w") as fout:
            id_type = "string" if use_hash_node_id() else "int64"
            fout.write(f"src_id:{id_type}\tdst_id:{id_type}\tweight:float\n")
            # Synthetic edge -1 -> root, so graph-learn can get the root node.
            fout.write(f"-1\t{self.root.item_id}\t1.0\n")
            for depth in range(self.max_level):
                for parent in self.level_code[depth]:
                    fout.writelines(
                        f"{parent.item_id}\t{kid.item_id}\t{edge_weight}\n"
                        for kid in parent.children
                    )
        return

    table_path = _add_suffix_to_odps_table(
        self.output_file, "_predict_edge_table"
    )
    writer = create_writer(table_path, **self.dataset_kwargs)

    # Synthetic edge -1 -> root, so graph-learn can get the root node.
    # The -1 id is a string when hashed node ids are in use, an int otherwise.
    root_src = "-1" if use_hash_node_id() else -1
    root_dst = self.root.item_id

    # Collect (parent id, child id) pairs level by level, in tree order.
    pairs = [
        (parent.item_id, kid.item_id)
        for depth in range(self.max_level)
        for parent in self.level_code[depth]
        for kid in parent.children
    ]
    sources = [root_src, *(src for src, _ in pairs)]
    targets = [root_dst, *(dst for _, dst in pairs)]
    weights = [edge_weight] * len(sources)

    columns = OrderedDict(
        [
            ("src_id", pa.array(sources)),
            ("dst_id", pa.array(targets)),
            ("weight", pa.array(weights)),
        ]
    )
    writer.write(columns)
    writer.close()