# log_analyzer/log.py
"""
Copyright (c) 2021, Alibaba Group;
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import math
import pickle
import csv
import dataclasses
import numpy as np
from typing import Dict, List, Optional, Union
from utils.utils import CommType, CommGroup
from log_analyzer.utils import convert_size_to_msg, calc_bw_log
import copy
@dataclasses.dataclass
class LogItem:
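    """One record in the comm log: either a workload item still to be
    executed or a measured communication/computation event."""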
    comm_type: Optional[CommType] = None
    comm_group: Optional[CommGroup] = None
    comm_group_size: Optional[int] = None
    msg_size: float = 0
    stage: str = ""
    dst: Optional[int] = None
    src: Optional[int] = None
    additional: str = ""
    _elapsed_time: Optional[float] = None
    algbw: Optional[float] = None
    busbw: Optional[float] = None
    count: float = 1
@property
def elapsed_time(self) -> float:
return self._elapsed_time
@elapsed_time.setter
def elapsed_time(self, elapsed_time):
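        # Recording the elapsed time also derives the algorithm and bus
        # bandwidth via calc_bw_log.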
self._elapsed_time = elapsed_time
self.algbw, self.busbw = calc_bw_log(
self.comm_type, self.msg_size, elapsed_time, self.comm_group_size
)
def is_epoch_end(self):
return self.comm_type == CommType.epoch_end
def is_workload(self):
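        # Items created from a workload description carry no measured time yet.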
return self.elapsed_time is None
def view_as_ds_log(self):
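        # "ds" presumably refers to the DeepSpeed-style comm-log line format
        # that this output mirrors.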
log_str = f"[RANK 0] comm op: {self.comm_type} | comm group: {self.comm_group}"
log_str += " | time (ms): {:.2f}".format(self.elapsed_time)
if self.comm_type == CommType.computation or self.additional == 'overlap':
log_str += " | msg size: " + '0'
log_str += " | algbw (GB): " + '0'
log_str += " | busbw (GB): " + '0'
else:
log_str += " | msg size: " + convert_size_to_msg(self.msg_size)
log_str += " | algbw (GB): {:.2f} ".format(self.algbw)
log_str += " | busbw (GB): {:.2f} ".format(self.busbw)
return log_str
    def csv_header(self):
        return ",".join(self.__dict__.keys())
    def view_as_csv_line(self):
        return ",".join(str(getattr(self, k)) for k in self.__dict__.keys())
    def __str__(self):
        # Workload and measured items currently share this placeholder;
        # use view_as_ds_log / view_as_csv_line for a real rendering.
        return "None"
def _print_stage_log(stage_name: str, stage_count: int, comm_type_info: Dict,
                     primary_key: List[str], agg_key: List[str],
                     performance_key: List[str], busbw_key: List[str]):
    """Format the aggregated per-stage comm info as an aligned text table."""
    header = (
        f"{'Comm_Type':<15} {'Comm_Group':<12} {'Message_Size':<12} "
        f"{'Count':<12} {'Avg_Elapsed_Time ± Std':<24} {'Avg_BusBw ± Std':<24}\n"
    )
    separator = "-" * len(header) + "\n"
    log_str = separator + header + separator
    for pkey in sorted(comm_type_info.keys()):
        values = {}
        for i, pkey_name in enumerate(primary_key):
            value = pkey[i] if pkey_name != "msg_size" else convert_size_to_msg(pkey[i])
            values[pkey_name] = value
        for key in agg_key:
            value = comm_type_info[pkey][key]
            values[key] = convert_size_to_msg(value) if key == "msg_size" else f"{value:.2f}"
        for key in performance_key + busbw_key:
            value_list = comm_type_info[pkey][key]
            values[f"avg_{key}"] = f"{np.mean(value_list):.2f}±{np.std(value_list):.2f}"
        # Column widths match the header above so the table stays aligned.
        log_str += (
            f"{values['comm_type']:<15} {values['comm_group']:<12} "
            f"{values['msg_size']:<12} {values['count']:<12} "
            f"{values['avg__elapsed_time']:<24} {values['avg_busbw']:<24}\n"
        )
    return log_str
def _analyze_stage_log(comm_log: List[Dict], stage: str, comm_info: Dict[str, Dict]):
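    """Aggregate one epoch's comm log into comm_info[stage]: a coarse view
    keyed by (comm_type, comm_group) and a detailed view keyed by
    (comm_type, comm_group, msg_size)."""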
    def __update_info(
        info_dict,
        log,
        primary_key: List[str],
        agg_key: List[str],
        performance_key: List[str],
        busbw_key: List[str],
    ):
        # Group by the primary-key tuple; sum the agg_key fields and collect
        # per-call samples for the performance and busbw fields.
        pkey = tuple(log[key] for key in primary_key)
        if pkey not in info_dict:
            info_dict[pkey] = {key: 0 for key in agg_key}
            info_dict[pkey].update({key: [] for key in performance_key})
            info_dict[pkey].update({key: [] for key in busbw_key})
        for key in agg_key:
            info_dict[pkey][key] += log[key]
        for key in performance_key:
            info_dict[pkey][key].append(log[key])
        for key in busbw_key:
            info_dict[pkey][key].append(log[key])
if stage not in comm_info:
comm_info[stage] = {
"count": 0,
"comm_type_info": {},
"detailed_comm_type_info": {},
}
comm_info[stage]["count"] += 1
    # Keyed by (comm_type, comm_group): count and msg_size totals plus samples.
    comm_type_info = comm_info[stage]["comm_type_info"]
    # Keyed by (comm_type, comm_group, msg_size): count total plus samples.
    detailed_comm_type_info = comm_info[stage]["detailed_comm_type_info"]
for log in comm_log:
if log.comm_type != CommType.computation:
__update_info(
comm_type_info,
log.__dict__,
["comm_type", "comm_group"],
["count", "msg_size"],
["_elapsed_time"],
["busbw"],
)
__update_info(
detailed_comm_type_info,
log.__dict__,
["comm_type", "comm_group", "msg_size"],
["count"],
["_elapsed_time"],
["busbw"],
)
class Log:
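    """Accumulates LogItem records and groups them by epoch for analysis."""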
def __init__(self) -> None:
self.comm_logs = []
self.comm_log_each_epoch = [[]]
self.epoch_times = []
def add_comm_log(self, comm_log: LogItem):
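        # An epoch_end marker closes the current epoch: record its elapsed
        # time and open a new per-epoch bucket (consecutive markers collapse
        # into one).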
if (
comm_log.is_epoch_end()
and len(self.comm_logs) > 0
and not self.comm_logs[-1].is_epoch_end()
):
self.comm_logs.append(comm_log)
self.comm_log_each_epoch.append([])
self.epoch_times.append(comm_log.elapsed_time)
return
self.comm_logs.append(comm_log)
self.comm_log_each_epoch[-1].append(comm_log)
def analyze(self, print_fn=print):
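        """Summarize comm ops per stage: epoch 0 is "init", later epochs "train"."""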
comm_info: Dict[str, Dict] = {}
_analyze_stage_log(self.comm_log_each_epoch[0], "init", comm_info)
for e_log in self.comm_log_each_epoch[1:]:
_analyze_stage_log(e_log, "train", comm_info)
        for stage in comm_info.keys():
            if stage != "init":
                stage_count = comm_info[stage]["count"]
                detailed_comm_type_info = comm_info[stage]["detailed_comm_type_info"]
                log_str = _print_stage_log(
                    stage,
                    stage_count,
                    detailed_comm_type_info,
                    ["comm_type", "comm_group", "msg_size"],
                    ["count"],
                    ["_elapsed_time"],
                    ["busbw"],
                )
                print_fn(f"\n\tDetailed comm info for AICB {stage} stage\n{log_str}")
return comm_info
    def dump(self, filename):
        default_comm_folder_path = "results/comm_logs/"
        os.makedirs(default_comm_folder_path, exist_ok=True)
        # Keep the base name without its extension and write under the
        # default comm-log folder (splitext avoids truncating at dots that
        # appear earlier in the path).
        filename = os.path.splitext(os.path.basename(filename))[0]
        filename = os.path.join(default_comm_folder_path, filename)
        csv_filename = filename + "_log.csv"
        with open(csv_filename, "w") as f:
            f.write(self.comm_logs[0].csv_header() + "\n")
            for log_item in self.comm_logs:
                log_item_write = copy.deepcopy(log_item)
                if log_item_write.comm_type == CommType.computation:
                    # Computation items carry a shape tuple in msg_size;
                    # render it without commas so it stays one CSV field.
                    log_item_write.msg_size = (
                        "(" + " ".join(str(shape).replace(",", "") for shape in log_item_write.msg_size) + ")"
                    )
                f.write(log_item_write.view_as_csv_line() + "\n")
        return csv_filename
    @staticmethod
    def load(filename):
        # Swap the extension for .pkl and unpickle the stored Log object.
        parts = filename.split(".")
        parts[-1] = "pkl"
        filename = ".".join(parts)
        with open(filename, "rb") as f:
            return pickle.load(f)
def _get_elapsed_time(self):
return self.epoch_times
    def analyze_time(self, print_fn=print):
        # The first recorded time is the init stage; the rest are training
        # iterations. Keep init separate rather than popping it and then
        # reading epoch_times[0], which would report the wrong value under
        # the "Init time" column.
        init_time = self.epoch_times[0]
        iteration_times = self.epoch_times[1:]
        max_val = max(iteration_times)
        min_val = min(iteration_times)
        mean_val = sum(iteration_times) / len(iteration_times)
        std_val = math.sqrt(
            sum((x - mean_val) ** 2 for x in iteration_times) / len(iteration_times)
        )
        sorted_times = sorted(iteration_times)
        p90_val = sorted_times[int(len(sorted_times) * 0.9)]
        header = (
            f"{'Init time':<18} {'Max iteration time':<20} {'Min iteration time':<20} "
            f"{'Avg iteration time':<20} {'P90 iteration time':<20} {'Iteration time Std':<20}\n"
        )
        separator = "-" * len(header) + "\n"
        log_str = separator + header + separator
        log_str += (
            f"{init_time:<18.2f} {max_val:<20.2f} {min_val:<20.2f} "
            f"{mean_val:<20.2f} {p90_val:<20.2f} {std_val:<20.2f}\n"
        )
        print_fn(f"\n\tDetailed info for AICB iteration time\n{log_str}")
class Workload:
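    """An ordered list of LogItem entries describing ops to be executed."""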
def __init__(self) -> None:
self.workload = []
def append(self, log_item: Union[LogItem, Dict]):
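        # Accept either a ready-made LogItem or a dict describing one.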
if isinstance(log_item, LogItem):
self.workload.append(log_item)
return
if "stage" not in log_item:
log_item["stage"] = log_item["operation"] if "operation" in log_item else ""
if "comm_group" not in log_item:
assert (
log_item["comm_type"] == CommType.computation
), "comm_group is required for non-computation comm_type"
log_item["comm_group"] = CommGroup.all
self.workload.append(
LogItem(
comm_type=log_item["comm_type"],
comm_group=log_item["comm_group"],
comm_group_size=log_item["comm_group_size"],
msg_size=log_item["msg_size"],
stage=log_item["stage"],
src=log_item.get("src", None),
dst=log_item.get("dst", None),
                additional=log_item.get("additional", ""),
)
)
def extend(self, new_workload):
self.workload.extend(new_workload.workload)
    def dump(self, filename):
        default_folder_path = "results/mocked_workload/"
        os.makedirs(default_folder_path, exist_ok=True)
        # Keep the base name without its extension and write under the
        # default mocked-workload folder; any directory part of the given
        # filename is ignored since the output location is fixed.
        filename = os.path.splitext(os.path.basename(filename))[0]
        filename = os.path.join(default_folder_path, filename)
        csv_filename = filename + "_workload.csv"
        with open(csv_filename, "w") as f:
            f.write(self.workload[0].csv_header() + "\n")
            for log_item in self.workload:
                log_item_write = copy.deepcopy(log_item)
                if log_item_write.comm_type == CommType.computation:
                    # Computation items carry a shape tuple in msg_size;
                    # render it without commas so it stays one CSV field.
                    log_item_write.msg_size = (
                        "(" + " ".join(str(shape).replace(",", "") for shape in log_item_write.msg_size) + ")"
                    )
                f.write(log_item_write.view_as_csv_line() + "\n")
        print(f"Workload file generated: {csv_filename}")
print(f"Workload file generated:{csv_filename}")
    @staticmethod
    def load(filename):
        # Swap the extension for .pkl; the pickle holds (workload, args).
        parts = filename.split(".")
        parts[-1] = "pkl"
        filename = ".".join(parts)
        with open(filename, "rb") as f:
            workload, args = pickle.load(f)
        return workload, args