def set_logger_verbosity_format()

in src/nanotron/logging/base.py [0:0]


def set_logger_verbosity_format(logging_level: str, parallel_context: ParallelContext):
    """Apply a log level and a colored, rank-aware log format.

    Builds a formatter whose metadata prefix looks like
    ``<time> [<LEVEL>|<category>|<ranks>|<node>]`` and installs it on a new
    stdout handler attached to the logger returned by ``get_logger()``, then
    forwards the level and formatter to ``set_verbosity`` / ``set_formatter``.

    Only parallel dimensions whose size is greater than 1 are shown, and the
    SLURM node name is appended when ``SLURMD_NODENAME`` is set.

    Args:
        logging_level: Key into ``log_levels`` selecting the level to apply.
        parallel_context: Provides the ``*_parallel_size`` values and the
            matching process groups used to query this process's ranks.

    Raises:
        Whatever ``log_levels[logging_level]`` raises for an unknown level
        (a ``KeyError`` if ``log_levels`` is a plain dict).
    """
    # 1. Conditional rank display - only show ranks if their size is > 1.
    # Process groups are only touched for dimensions that are actually used.
    node_name = os.environ.get("SLURMD_NODENAME")
    ranks = []

    if parallel_context.expert_parallel_size > 1:
        ranks.append(f"EP={dist.get_rank(parallel_context.ep_pg)}")
    if parallel_context.context_parallel_size > 1:
        ranks.append(f"CP={dist.get_rank(parallel_context.cp_pg)}")
    if parallel_context.data_parallel_size > 1:
        ranks.append(f"DP={dist.get_rank(parallel_context.dp_pg)}")
    if parallel_context.pipeline_parallel_size > 1:
        ranks.append(f"PP={dist.get_rank(parallel_context.pp_pg)}")
    if parallel_context.tensor_parallel_size > 1:
        ranks.append(f"TP={dist.get_rank(parallel_context.tp_pg)}")

    if node_name:
        ranks.append(node_name)

    # Join all ranks with separator
    ranks_str = "|".join(ranks)
    ranks_display = f"|{ranks_str}" if ranks_str else ""
    # The display string is embedded in a %-style template, so a literal "%"
    # (only possible via the environment-provided node name) must be doubled.
    ranks_display = ranks_display.replace("%", "%%")

    # Metadata portion shared by the default template and the per-level ones.
    metadata_fmt = f"%(asctime)s [%(levelname)s%(category)s{ranks_display}]"
    reset = "\033[0m"

    class SafeFormatter(Formatter):
        """Formatter tolerant of a missing ``category`` that colors by level.

        The per-level template is built locally in ``formatMessage`` rather
        than by swapping the formatter's internal style template, so this
        instance carries no per-record mutable state and can be shared by
        several handlers.
        """

        def format(self, record):
            # Normalise ``category`` so the template never fails on a missing
            # field, never prints "None", and always gets a leading "|".
            category = getattr(record, "category", None)
            if category is None or category == "":
                record.category = ""
            else:
                category = str(category)
                record.category = category if category.startswith("|") else f"|{category}"

            if not getattr(record, "separator", False):
                return super().format(record)

            # Bold for separators. The record may be handed to other handlers
            # afterwards, so the original message is restored even if
            # formatting raises.
            original_msg = record.msg
            record.msg = f"\033[1m{original_msg}\033[0m"
            try:
                return super().format(record)
            finally:
                record.msg = original_msg

        def formatMessage(self, record):
            # Choose color prefix based on log level
            if record.levelno == logging.WARNING:
                prefix = "\033[1;33m"  # Bold yellow for warnings
            elif record.levelno >= logging.ERROR:
                prefix = "\033[1;31m"  # Bold red for errors and critical
            elif record.levelno == logging.DEBUG:
                prefix = "\033[2;3;32m"  # Dim and italic green for debug
            else:
                prefix = "\033[2;3m"  # Dim and italic for other levels

            # Color is applied only to the metadata portion; the message
            # itself is left unstyled. ``format`` has already populated
            # ``record.message`` and ``record.asctime`` at this point.
            template = f"{prefix}{metadata_fmt}{reset}: %(message)s"
            return template % record.__dict__

    # The constructor template mirrors the default (dim/italic) styling; it is
    # also what the base class inspects to decide that ``asctime`` is needed.
    formatter = SafeFormatter(
        fmt=f"\033[2;3m{metadata_fmt}{reset}: %(message)s",
        datefmt="%m/%d %H:%M:%S",
    )
    log_level = log_levels[logging_level]

    # main root logger
    root_logger = get_logger()
    root_logger.setLevel(log_level)
    handler = NewLineStreamHandler(sys.stdout)
    handler.setLevel(log_level)
    handler.setFormatter(formatter)
    # NOTE(review): a new handler is attached on every call; presumably this
    # function is invoked once per process - confirm against callers.
    root_logger.addHandler(handler)

    # Nanotron
    set_verbosity(log_level)
    set_formatter(formatter=formatter)