# scripts/get_tensor_parallel_size.py
# get_tensor_parallel_size(): pick a tensor-parallel degree compatible with a model's head count.

def get_tensor_parallel_size(model_name: str, revision: str = None, default_tp: int = 8) -> int:
    try:
        config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)
        num_heads = getattr(config, 'num_attention_heads', None)

        if num_heads is not None and num_heads % default_tp != 0:
            tp = gcd(num_heads, default_tp)
            return max(tp, 1)
        else:
            return default_tp
    except Exception as e:
        print(f"Warning: Failed to fetch config for {model_name}@{revision}: {e}")
        return default_tp