in tzrec/tools/tdm/gen_tree/tree_search_util.py [0:0]
def save(self, attr_delimiter: str = ",") -> None:
    """Save tree info as a node table and an edge table.

    When ``self.output_file`` starts with ``odps://`` the two tables are
    written through ``create_writer`` to ``<output_file>_node_table`` and
    ``<output_file>_edge_table``. Otherwise ``self.output_file`` is treated
    as a local directory (created if missing) that receives
    ``node_table.txt`` and ``edge_table.txt`` as tab-separated text files.

    Each node row holds the node id, a weight of 1.0 and a comma-joined
    feature string ``level,item_id[,attrs][,raw_attrs]``. An extra row with
    id -1 that copies the first node's features is emitted first so that
    graph-learn can get the root node.

    Args:
        attr_delimiter: separator placed between the individual values of
            ``node.attrs`` and of ``node.raw_attrs`` inside the feature
            string. The feature fields themselves are always joined by ",".
    """

    def _join_attrs(attrs, null_token: str) -> str:
        # Values whose ``is_valid`` is falsy are replaced by ``null_token``.
        return attr_delimiter.join(
            str(attr) if attr.is_valid else null_token for attr in attrs
        )

    def _feature_fields(level, node) -> list:
        # Build the string fields of one node's feature column.
        fields = [str(level), str(node.item_id)]
        if node.attrs:
            fields.append(_join_attrs(node.attrs, ""))
        if node.raw_attrs:
            # The placeholder must be the string "0": str.join raises
            # TypeError on non-str items, so an int 0 would abort the save.
            fields.append(_join_attrs(node.raw_attrs, "0"))
        return fields

    def _node_rows(root_id):
        # Yield (node_id, feature_string) for every node, level by level.
        first_node = True
        for level, nodes in enumerate(self.level_code):
            for node in nodes:
                fields = _feature_fields(level, node)
                # add a node with id -1 for graph-learn to get root node
                if first_node:
                    yield root_id, ",".join(["-1"] + fields[1:])
                    first_node = False
                yield node.item_id, ",".join(fields)

    def _edge_pairs():
        # Yield (src_id, dst_id) from the first entry of each travel path to
        # the following ``max_level - 1`` entries.
        for travel in self.travel_list:
            # do not include edge from leaf to root
            for i in range(self.max_level - 1):
                yield travel[0], travel[i + 1]

    if self.output_file.startswith("odps://"):
        output_path = _add_suffix_to_odps_table(self.output_file, "_node_table")
        node_writer = create_writer(output_path, **self.dataset_kwargs)
        try:
            ids = []
            features = []
            root_id = "-1" if use_hash_node_id() else -1
            for node_id, feature in _node_rows(root_id):
                ids.append(node_id)
                features.append(feature)
            node_table_dict = OrderedDict()
            node_table_dict["id"] = pa.array(ids)
            node_table_dict["weight"] = pa.array([1.0] * len(ids))
            node_table_dict["features"] = pa.array(features)
            node_writer.write(node_table_dict)
        finally:
            # Close the writer even when building or writing the table fails.
            node_writer.close()

        output_path = _add_suffix_to_odps_table(self.output_file, "_edge_table")
        edge_writer = create_writer(output_path, **self.dataset_kwargs)
        try:
            src_ids = []
            dst_ids = []
            for src_id, dst_id in _edge_pairs():
                src_ids.append(src_id)
                dst_ids.append(dst_id)
            edge_table_dict = OrderedDict()
            edge_table_dict["src_id"] = pa.array(src_ids)
            edge_table_dict["dst_id"] = pa.array(dst_ids)
            edge_table_dict["weight"] = pa.array([1.0] * len(src_ids))
            edge_writer.write(edge_table_dict)
        finally:
            edge_writer.close()
    else:
        # exist_ok avoids the check-then-create race of a separate exists().
        os.makedirs(self.output_file, exist_ok=True)
        id_type = "string" if use_hash_node_id() else "int64"
        with open(os.path.join(self.output_file, "node_table.txt"), "w") as f:
            f.write(f"id:{id_type}\tweight:float\tfeature:string\n")
            for node_id, feature in _node_rows(-1):
                f.write(f"{node_id}\t1.0\t{feature}\n")
        with open(os.path.join(self.output_file, "edge_table.txt"), "w") as f:
            f.write(f"src_id:{id_type}\tdst_id:{id_type}\tweight:float\n")
            for src_id, dst_id in _edge_pairs():
                f.write(f"{src_id}\t{dst_id}\t1.0\n")