def get_tp()

in generate/run_ioi_slurm.py [0:0]


def get_tp(model_name: str, revision: str) -> int:
    """Return a tensor-parallel size compatible with the model's head counts.

    The starting value is the ``"tp"`` entry for ``model_name`` in
    ``MODEL_CONFIGS``, falling back to ``DEFAULT_TP``. The model config is
    then loaded with ``AutoConfig.from_pretrained`` and the value is reduced
    (via ``gcd``) whenever it does not fit ``num_attention_heads`` or
    ``num_key_value_heads``.

    Args:
        model_name: Model identifier passed to ``AutoConfig.from_pretrained``
            and used as the ``MODEL_CONFIGS`` key.
        revision: Model revision passed to ``AutoConfig.from_pretrained``.

    Returns:
        The (possibly reduced) tensor-parallel size. If the config cannot be
        loaded or inspected, the unadjusted default is returned.
    """
    default_tp = MODEL_CONFIGS.get(model_name, {}).get("tp", DEFAULT_TP)
    try:
        config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)

        tp = default_tp

        # Attention heads must be evenly divisible by tp; otherwise fall back
        # to the largest divisor of tp that also divides the head count.
        num_heads = getattr(config, "num_attention_heads", None)
        if num_heads and num_heads % tp != 0:
            tp = gcd(num_heads, tp)

        # KV heads: only adjust when tp neither divides the KV-head count nor
        # is a multiple of it. The gcd divides the current tp, so it still
        # divides num_attention_heads after the step above.
        # NOTE(review): keeping tp when it is a multiple of the KV-head count
        # assumes the serving engine replicates KV heads across ranks - confirm.
        num_kv_heads = getattr(config, "num_key_value_heads", None)
        if num_kv_heads and num_kv_heads % tp != 0 and tp % num_kv_heads != 0:
            tp = gcd(num_kv_heads, tp)

        if tp != default_tp:
            print(f"Adjusted tp for {model_name} from {default_tp} to {tp}")
        return tp
    except Exception as e:
        # Best effort: any failure to load/inspect the config keeps the default.
        print(f"Could not get tp from config for {model_name}: {e}")
        return default_tp