def get_context_length()

in generate/run_ioi_slurm.py [0:0]


def get_context_length(model_name: str, revision: str) -> int:
    """Return the maximum context length advertised by a model's config.

    The config is loaded with ``AutoConfig.from_pretrained`` and the first
    truthy value among ``max_position_embeddings``, ``sliding_window``,
    ``max_sequence_length`` and ``max_seq_len`` is used, in that order.
    ``sliding_window`` is skipped when the config explicitly sets
    ``use_sliding_window`` to a falsy value, because the window size is not
    in effect in that case.

    Args:
        model_name: Model identifier passed to ``AutoConfig.from_pretrained``.
        revision: Model revision passed to ``AutoConfig.from_pretrained``.

    Returns:
        The resolved context length, or 4096 when no attribute provides one
        or the config cannot be read. In both cases the result is limited to
        ``MAX_CTX_LENGTH`` when that module-level value is not ``None``.
    """
    default_length = 4096

    def _apply_cap(length):
        # The module-level cap is optional; None means "no cap".
        if MAX_CTX_LENGTH is None:
            return length
        return min(length, MAX_CTX_LENGTH)

    try:
        # NOTE(review): trust_remote_code=True lets the model repository run
        # its own config code on load - confirm this is intended for every
        # model_name that reaches this function.
        config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)

        # Candidate attributes in priority order.
        candidate_attrs = [
            'max_position_embeddings',
            'sliding_window',
            'max_sequence_length',
            'max_seq_len',
        ]

        # Some models (like Qwen) ship a sliding_window value while having
        # the feature switched off; a disabled window must not be mistaken
        # for the context length.
        if hasattr(config, 'use_sliding_window') and not config.use_sliding_window:
            candidate_attrs.remove('sliding_window')

        # First truthy candidate wins; missing, None, and 0 are all skipped
        # so the result is never None.
        context_length = default_length
        for attr in candidate_attrs:
            value = getattr(config, attr, None)
            if value:
                context_length = value
                break

        return _apply_cap(context_length)
    except Exception as e:
        # Best effort: any failure while loading or reading the config falls
        # back to the default, which is still subject to the cap.
        logger.warning(f"Could not get context length from config for {model_name}: {e}")
        return _apply_cap(default_length)