def get_param_count_from_repo_id()

in src/open_r1/utils/hub.py [0:0]


def _param_count_from_name(repo_id: str) -> int:
    """Parse a parameter count out of the text of a repo ID, or return -1."""
    # <number>[x<number>]<unit>, e.g. "42m", "1.5b", "8x7b". The trailing
    # negative lookahead rejects a unit letter that merely starts a longer
    # word, so quantisation suffixes such as "4bit" / "8bit" are not read as
    # 4 or 8 billion parameters.
    pattern = r"(\d+(?:\.\d+)?)(?:x(\d+(?:\.\d+)?))?([bm])(?![a-z])"
    unit_scale = {"b": 1_000_000_000, "m": 1_000_000}

    param_counts = []
    for number1, number2, unit in re.findall(pattern, repo_id.lower()):
        number = float(number1)
        if number2:  # A second number means a product such as 8x7b.
            number *= float(number2)
        param_counts.append(number * unit_scale[unit])

    if not param_counts:
        return -1
    # A repo ID may contain several size-like tokens; keep the largest.
    # round() rather than int(): the float product can land a hair below the
    # intended integer, and truncation would then lose one unit.
    return round(max(param_counts))


def get_param_count_from_repo_id(repo_id: str) -> int:
    """Return the parameter count of the model identified by ``repo_id``.

    The safetensors metadata of the repo is tried first. If that lookup fails
    for any reason, the count is estimated from the repo ID itself by looking
    for size tokens such as ``42m``, ``1.5b``, ``0.5m`` or products such as
    ``8x7b`` (case-insensitive). When several tokens are present the largest
    one wins.

    Args:
        repo_id: Repository identifier, e.g. ``"org/model-7b"``.

    Returns:
        The parameter count, or ``-1`` if the metadata lookup failed and the
        repo ID contains no recognisable size token.
    """
    try:
        metadata = get_safetensors_metadata(repo_id)
        # NOTE(review): only the first entry of ``parameter_count`` is used;
        # if the mapping holds several entries the others are ignored.
        # Confirm this is intended.
        return list(metadata.parameter_count.values())[0]
    except Exception:
        # Deliberate best-effort: any failure of the metadata lookup falls
        # back to parsing the repo ID text.
        return _param_count_from_name(repo_id)