# tzrec/benchmark/benchmark.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
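r"""Benchmark runner for tzrec.

For each pipeline config found under ``--pipeline_config_path``, this script
runs train_eval ``run_cnt`` times, averages the resulting eval metrics, and
compares the averages against the baselines in
``tzrec/benchmark/configs/base_eval_metric.json``.

Example invocation (flags are defined by the argument parser at the bottom of
this file; values shown are the defaults, and ``PYTHONPATH=.`` mirrors the
torchrun command used below):

    PYTHONPATH=. python tzrec/benchmark/benchmark.py \
        --pipeline_config_path tzrec/benchmark/configs \
        --experiment_path tmp
"""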
import argparse
import json
import os
import time
from typing import Any, Dict, List, Tuple
from tzrec.tests.utils import _standalone
from tzrec.utils import config_util, misc_util
TEXT_RESET = "\033[0m"
TEXT_BOLD_RED = "\033[1;31m"
TEXT_BOLD_GREEN = "\033[1;32m"
TEXT_BOLD_YELLOW = "\033[1;33m"
TEXT_BOLD_BLUE = "\033[1;34m"
TEXT_BOLD_CYAN = "\033[1;36m"
# pyre-ignore [2,3]
def print_error(*args, **kwargs):
"""Print error info."""
print(f"{TEXT_BOLD_RED}[ERROR]{TEXT_RESET}", *args, **kwargs)
# pyre-ignore [2,3]
def print_worse(*args, **kwargs):
"""Print train eval metric all worse info."""
print(f"{TEXT_BOLD_BLUE}[WORSE]{TEXT_RESET}", *args, **kwargs)
# pyre-ignore [2,3]
def print_better(*args, **kwargs):
"""Print train eval metric all better info."""
print(f"{TEXT_BOLD_GREEN}[BETTER]{TEXT_RESET}", *args, **kwargs)
# pyre-ignore [2,3]
def print_some_better_and_worse(*args, **kwargs):
"""Has some better and worse metric info."""
print(f"{TEXT_BOLD_CYAN}[HAS BETTER AND WORSE]{TEXT_RESET}", *args, **kwargs)
# pyre-ignore [2,3]
def print_balance(*args, **kwargs):
"""Train metric not much change."""
print(f"{TEXT_BOLD_YELLOW}[BALANCE]{TEXT_RESET}", *args, **kwargs)
def _get_benchmark_project() -> str:
"""Get ODPS project for benchmark."""
project = os.environ.get("CI_ODPS_PROJECT_NAME", "")
if "ODPS_CONFIG_FILE_PATH" in os.environ:
with open(os.environ["ODPS_CONFIG_FILE_PATH"], "r") as f:
for line in f.readlines():
values = line.split("=", 1)
if len(values) == 2 and values[0] == "project_name":
project = values[1].strip()
return project
def _modify_pipeline_config(
    pipeline_config_path: str,
    model_path: str,
    run_config_path: str,
) -> None:
    """Fill {PROJECT} placeholders and model_dir, then save the run config."""
pipeline_config = config_util.load_pipeline_config(pipeline_config_path)
pipeline_config.model_dir = model_path
project = _get_benchmark_project()
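    # The benchmark configs reference their input tables with a {PROJECT}
    # placeholder (e.g. "odps://{PROJECT}/tables/..."; the exact path form is
    # illustrative), which is filled with the resolved ODPS project name here.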
train_input_path = pipeline_config.train_input_path.format(PROJECT=project)
pipeline_config.train_input_path = train_input_path
eval_input_path = pipeline_config.eval_input_path.format(PROJECT=project)
pipeline_config.eval_input_path = eval_input_path
if pipeline_config.data_config.HasField("negative_sampler"):
sampler = pipeline_config.data_config.negative_sampler
sampler.input_path = sampler.input_path.format(PROJECT=project)
elif pipeline_config.data_config.HasField("negative_sampler_v2"):
sampler = pipeline_config.data_config.negative_sampler_v2
sampler.user_input_path = sampler.user_input_path.format(PROJECT=project)
sampler.item_input_path = sampler.item_input_path.format(PROJECT=project)
sampler.pos_edge_input_path = sampler.pos_edge_input_path.format(
PROJECT=project
)
elif pipeline_config.data_config.HasField("hard_negative_sampler"):
sampler = pipeline_config.data_config.hard_negative_sampler
sampler.user_input_path = sampler.user_input_path.format(PROJECT=project)
sampler.item_input_path = sampler.item_input_path.format(PROJECT=project)
sampler.hard_neg_edge_input_path = sampler.hard_neg_edge_input_path.format(
PROJECT=project
)
elif pipeline_config.data_config.HasField("hard_negative_sampler_v2"):
sampler = pipeline_config.data_config.hard_negative_sampler_v2
sampler.user_input_path = sampler.user_input_path.format(PROJECT=project)
sampler.item_input_path = sampler.item_input_path.format(PROJECT=project)
sampler.pos_edge_input_path = sampler.pos_edge_input_path.format(
PROJECT=project
)
sampler.hard_neg_edge_input_path = sampler.hard_neg_edge_input_path.format(
PROJECT=project
)
elif pipeline_config.data_config.HasField("tdm_sampler"):
sampler = pipeline_config.data_config.tdm_sampler
sampler.item_input_path = sampler.item_input_path.format(PROJECT=project)
sampler.edge_input_path = sampler.edge_input_path.format(PROJECT=project)
sampler.predict_edge_input_path = sampler.predict_edge_input_path.format(
PROJECT=project
)
config_util.save_message(pipeline_config, run_config_path)
def _benchmark_train_eval(
run_config_path: str,
log_path: str,
) -> bool:
"""Run train_eval for benchmark."""
cmd_str = (
f"PYTHONPATH=. torchrun {_standalone()} "
"--nnodes=1 --nproc-per-node=2 "
f"--log_dir {log_path} -r 3 -t 3 tzrec/train_eval.py "
f"--pipeline_config_path {run_config_path}"
)
return misc_util.run_cmd(cmd_str, log_path + ".log", timeout=6000)
def _get_config_paths(pipeline_config_paths: str) -> List[str]:
"""Get dir all pipeline config path."""
config_paths = []
if os.path.isfile(pipeline_config_paths):
config_paths.append(pipeline_config_paths)
elif os.path.isdir(pipeline_config_paths):
for root, _, files in os.walk(pipeline_config_paths):
for file in files:
if "base_eval_metric.json" != file:
config_paths.append(os.path.join(root, file))
else:
raise Exception(f"{pipeline_config_paths} is not a valid file or directory")
return config_paths
def _create_directory(path: str) -> str:
"""Create the directory if it doesn't exist."""
if not os.path.exists(path):
os.makedirs(path)
return path
def _get_train_metrics(path: str) -> Dict[str, Any]:
    """Read eval metrics from train_eval_result.txt under the model path."""
    eval_file = os.path.join(path, "train_eval_result.txt")
    with open(eval_file) as f:
        metrics = json.load(f)
    return metrics
def _compare_metrics(
metric_config: Dict[str, Any], train_metrics: List[Dict[str, Any]]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Compare model metrics and base metrics."""
base_metric = metric_config["label"]
threshold = metric_config["threshold"]
train_avg_metric = {}
change_metric = {}
better_name = []
worse_name = []
balance_name = []
for k, v in base_metric.items():
if isinstance(threshold, dict):
task_threshold = threshold[k]
else:
task_threshold = threshold
if len(train_metrics) > 0:
train_avg_v = sum([metric[k] for metric in train_metrics]) / len(
train_metrics
)
train_avg_metric[k] = train_avg_v
if train_avg_v - v >= task_threshold:
better_name.append(k)
elif train_avg_v - v <= -task_threshold:
worse_name.append(k)
else:
balance_name.append(k)
change_metric["better"] = better_name
change_metric["worse"] = worse_name
change_metric["balance"] = balance_name
return train_avg_metric, change_metric
def _print(
config_path: str,
run_cnt: int,
fail_cnt: int,
train_avg_metric: Dict[str, Any],
change_metric: Dict[str, Any],
) -> None:
"""Print train metrics."""
better_name = change_metric["better"]
worse_name = change_metric["worse"]
balance_name = change_metric["balance"]
    success_cnt = run_cnt - fail_cnt
    msg = f"config_path:{config_path}, fail_cnt:{fail_cnt}, success_cnt:{success_cnt}. "
    if fail_cnt >= success_cnt:
        msg = msg + f"metrics are unreliable, train metrics: {train_avg_metric}"
        print_error(msg)
    elif len(better_name) > 0 and len(worse_name) == 0:
        msg = (
            msg + f"better metric names:{better_name}, "
            f"all train metrics:{train_avg_metric}"
        )
        print_better(msg)
    elif len(worse_name) > 0 and len(better_name) == 0:
        msg = (
            msg
            + f"worse metric names:{worse_name}, all train metrics:{train_avg_metric}"
        )
        print_worse(msg)
    elif len(worse_name) > 0 and len(better_name) > 0:
        msg = (
            msg + f"worse metric names:{worse_name}, "
            f"better metric names:{better_name}, "
            f"all train metrics:{train_avg_metric}"
        )
        print_some_better_and_worse(msg)
    elif len(balance_name) > 0 and len(worse_name) == 0 and len(better_name) == 0:
        msg = (
            msg + "no metric is significantly better or worse, "
            f"all train metrics:{train_avg_metric}"
        )
        print_balance(msg)
    else:
        msg = msg + f"all train metrics:{train_avg_metric}"
        print(msg)
def main(
pipeline_config_paths: str,
experiment_path: str,
base_metric_path: str = "tzrec/benchmark/configs/base_eval_metric.json",
) -> None:
"""Run benchmarks."""
train_config_paths = _get_config_paths(pipeline_config_paths)
    with open(base_metric_path) as f:
        base_eval_metrics = json.load(f)
experiment_path = experiment_path + f"_{int(time.time())}"
print(f"******* We will save experiment is {experiment_path} *******")
models_path = _create_directory(os.path.join(experiment_path, "models"))
configs_path = _create_directory(os.path.join(experiment_path, "configs"))
logs_path = _create_directory(os.path.join(experiment_path, "logs"))
all_train_metrics = {}
all_train_metrics_info = {}
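    # For each benchmark config: run train_eval run_cnt times, keep the metrics
    # of the successful runs, and compare their average against the baseline
    # entry for that config.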
for old_config_path in train_config_paths:
metric_config = base_eval_metrics[old_config_path]
run_cnt = metric_config["run_cnt"]
train_metrics = []
fail_cnt = 0
for i in range(run_cnt):
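            # Derive a filesystem-safe, per-run name from the config path.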
file_path = (
old_config_path.replace("/", "_")
.replace("\\", "_")
.replace(".config", "")
)
file_path = file_path + f"_{i}"
new_config_path = os.path.join(configs_path, file_path + ".config")
model_path = os.path.join(models_path, file_path)
log_path = os.path.join(logs_path, file_path)
            _modify_pipeline_config(old_config_path, model_path, new_config_path)
success = _benchmark_train_eval(new_config_path, log_path)
if success:
train_metric = _get_train_metrics(model_path)
train_metrics.append(train_metric)
else:
fail_cnt += 1
train_avg_metric, change_metric = _compare_metrics(metric_config, train_metrics)
_print(old_config_path, run_cnt, fail_cnt, train_avg_metric, change_metric)
all_train_metrics[old_config_path] = train_avg_metric
print_info = {
"run_cnt": run_cnt,
"fail_cnt": fail_cnt,
"train_avg_metric": train_avg_metric,
"change_metric": change_metric,
}
all_train_metrics_info[old_config_path] = print_info
print("".join(["="] * 30))
print("".join(["="] * 30))
for old_config_path, print_info in all_train_metrics_info.items():
run_cnt = print_info["run_cnt"]
fail_cnt = print_info["fail_cnt"]
train_avg_metric = print_info["train_avg_metric"]
change_metric = print_info["change_metric"]
_print(old_config_path, run_cnt, fail_cnt, train_avg_metric, change_metric)
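    # Persist the averaged metrics per config as JSON for later inspection.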
benchmark_file = os.path.join(experiment_path, "benchmark_eval.txt")
with open(benchmark_file, "w") as f:
json.dump(all_train_metrics, f)
print("benchmark complete !!!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--pipeline_config_path",
type=str,
default="tzrec/benchmark/configs",
help="Path to pipeline config file.",
)
parser.add_argument(
"--experiment_path",
type=str,
default="tmp",
help="Path to experiment model save.",
)
args, extra_args = parser.parse_known_args()
main(
args.pipeline_config_path,
args.experiment_path,
)