in src/nanotron/logging/base.py [0:0]
def set_logger_verbosity_format(logging_level: str, parallel_context: ParallelContext):
    """Configure the nanotron log level and a colourised, rank-aware line format.

    Lines are rendered as
    ``<asctime> [<LEVEL><|category><|EP=..|CP=..|DP=..|PP=..|TP=..|node>]: <message>``.
    Only parallel dimensions whose size is > 1 are listed, followed by ``$SLURMD_NODENAME``
    when it is set. The bracketed metadata is coloured by level, and records carrying a
    truthy ``separator`` attribute (e.g. passed via ``extra=``) get a bold message.

    Args:
        logging_level: Key into ``log_levels`` selecting the verbosity
            (presumably raises ``KeyError`` for an unknown key -- depends on ``log_levels``).
        parallel_context: Provides the parallel sizes and process groups shown in each line.

    Side effects:
        Sets the level of the logger returned by ``get_logger()``, attaches a stdout
        ``NewLineStreamHandler`` to it (replacing one attached by a previous call), and
        calls ``set_verbosity`` / ``set_formatter`` with the new level and formatter.
    """
    ranks_display = _build_ranks_display(parallel_context)
    # The ranks suffix is embedded in %-style format strings; escape "%" (e.g. coming from a
    # node name) so it is printed literally instead of being parsed as a format directive.
    ranks_fmt = ranks_display.replace("%", "%%")

    class SafeFormatter(Formatter):
        """Formatter that tolerates records without ``category`` and colours metadata by level."""

        def format(self, record: logging.LogRecord) -> str:
            # Always leave a string ``category`` on the record so this and any later
            # ``%(category)s`` format can render it (None/empty -> "").
            record.category = _normalize_category(getattr(record, "category", None))
            if not getattr(record, "separator", False):
                return super().format(record)
            # Bold separator messages for this formatting pass only; restore in ``finally``
            # so other handlers, and a failed format, never see the modified message.
            original_msg = record.msg
            record.msg = f"\033[1m{original_msg}\033[0m"
            try:
                return super().format(record)
            finally:
                record.msg = original_msg

        def formatMessage(self, record: logging.LogRecord) -> str:
            # Pick the level-coloured layout per record instead of rewriting
            # ``self._style._fmt``: this instance is also handed to ``set_formatter`` and may be
            # shared across handlers, so it is never mutated while formatting.
            prefix = _level_color_prefix(record.levelno)
            layout = f"{prefix}%(asctime)s [%(levelname)s%(category)s{ranks_fmt}]\033[0m: %(message)s"
            return layout % record.__dict__

    # The base fmt keeps ``usesTime()`` true (so ``asctime`` is populated) and is the layout
    # reported by the formatter's own style; per-record colouring happens in formatMessage.
    formatter = SafeFormatter(
        fmt=f"\033[2;3m%(asctime)s [%(levelname)s%(category)s{ranks_fmt}]\033[0m: %(message)s",
        datefmt="%m/%d %H:%M:%S",
    )

    log_level = log_levels[logging_level]

    # main root logger
    root_logger = get_logger()
    root_logger.setLevel(log_level)
    # Remove the handler attached by an earlier call so repeated calls don't print every line twice.
    for existing in list(root_logger.handlers):
        if getattr(existing, "_nanotron_verbosity_handler", False):
            root_logger.removeHandler(existing)
    handler = NewLineStreamHandler(sys.stdout)
    handler._nanotron_verbosity_handler = True  # marker used by the cleanup loop above
    handler.setLevel(log_level)
    handler.setFormatter(formatter)
    root_logger.addHandler(handler)

    # Nanotron
    set_verbosity(log_level)
    set_formatter(formatter=formatter)


def _build_ranks_display(parallel_context: ParallelContext) -> str:
    """Return the ``|EP=..|CP=..|DP=..|PP=..|TP=..|node`` suffix for log lines.

    A dimension is listed only when its parallel size is > 1; its process group is not
    queried otherwise. The SLURM node name is appended when ``SLURMD_NODENAME`` is set.
    Returns ``""`` when there is nothing to show.
    """
    ranks = []
    if parallel_context.expert_parallel_size > 1:
        ranks.append(f"EP={dist.get_rank(parallel_context.ep_pg)}")
    if parallel_context.context_parallel_size > 1:
        ranks.append(f"CP={dist.get_rank(parallel_context.cp_pg)}")
    if parallel_context.data_parallel_size > 1:
        ranks.append(f"DP={dist.get_rank(parallel_context.dp_pg)}")
    if parallel_context.pipeline_parallel_size > 1:
        ranks.append(f"PP={dist.get_rank(parallel_context.pp_pg)}")
    if parallel_context.tensor_parallel_size > 1:
        ranks.append(f"TP={dist.get_rank(parallel_context.tp_pg)}")
    node_name = os.environ.get("SLURMD_NODENAME")
    if node_name:
        ranks.append(node_name)
    return f"|{'|'.join(ranks)}" if ranks else ""


def _level_color_prefix(levelno: int) -> str:
    """Return the ANSI escape sequence that opens the metadata part of a line at ``levelno``."""
    if levelno == logging.WARNING:
        return "\033[1;33m"  # bold yellow for warnings
    if levelno >= logging.ERROR:
        return "\033[1;31m"  # bold red for errors and critical
    if levelno == logging.DEBUG:
        return "\033[2;3;32m"  # dim italic green for debug
    return "\033[2;3m"  # dim italic for every other level


def _normalize_category(category) -> str:
    """Return ``category`` as a ``|``-prefixed label, or ``""`` when it is missing or empty.

    Falsy values (including ``None``) map to ``""`` so a line never shows a literal ``None``.
    Non-string values are converted with ``str`` before the prefix check.
    """
    if not category:
        return ""
    category = str(category)
    return category if category.startswith("|") else f"|{category}"