generate/utils/get_context_length.py (26 lines of code) (raw):
from transformers import AutoConfig
from typing import Dict, Any
import logging
logger = logging.getLogger(__name__)
def get_context_length(model_name: str) -> int:
"""Get maximum context length from model config."""
try:
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# Check various possible context length attributes
context_length = (
getattr(config, 'max_position_embeddings', None) or
getattr(config, 'sliding_window', None) or
getattr(config, 'max_sequence_length', None) or
getattr(config, 'max_seq_len', None) or
4096 # Default fallback
)
# Some models (like Qwen) might have sliding_window disabled
if hasattr(config, 'use_sliding_window') and not config.use_sliding_window:
# If sliding window is disabled, use max_position_embeddings instead
context_length = getattr(config, 'max_position_embeddings', context_length)
# Cap to 32k
return min(context_length, 32768)
except Exception as e:
logger.warning(f"Could not get context length from config for {model_name}: {e}")
return 4096 # Default fallback
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, required=True)
args = parser.parse_args()
print(get_context_length(args.model_name))