in generate/run_ioi_slurm.py [0:0]
def get_tp(model_name: str, revision: str) -> int:
    """Return a tensor-parallel size compatible with the model's head counts.

    Starts from the per-model "tp" entry in MODEL_CONFIGS (falling back to
    DEFAULT_TP) and, when the model config exposes ``num_attention_heads``
    and/or ``num_key_value_heads``, shrinks it to the largest value that
    divides the starting tp and every available head count.

    Args:
        model_name: Model identifier passed to ``AutoConfig.from_pretrained``
            and used as the key into MODEL_CONFIGS.
        revision: Model revision passed to ``AutoConfig.from_pretrained``.

    Returns:
        The (possibly reduced) tensor-parallel size. If the config cannot be
        loaded or inspected, the unadjusted default is returned.
    """
    default_tp = MODEL_CONFIGS.get(model_name, {}).get("tp", DEFAULT_TP)
    try:
        # NOTE(review): trust_remote_code=True lets the model repository run
        # its own config code; only use with model sources you trust.
        config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)

        # Both the attention-head count and the key/value-head count must be
        # divisible by tp. Folding each one into a running gcd yields the
        # largest tp that divides the default and every head count present.
        # NOTE(review): this is deliberately conservative; some engines also
        # accept tp > num_key_value_heads by replicating KV heads - confirm
        # whether that case should be allowed here.
        new_tp = default_tp
        for attr in ("num_attention_heads", "num_key_value_heads"):
            head_count = getattr(config, attr, None)
            if head_count is not None:
                new_tp = gcd(new_tp, head_count)

        if new_tp != default_tp:
            print(f"Adjusted tp for {model_name} from {default_tp} to {new_tp}")
        return new_tp
    except Exception as e:
        # Best-effort by design: any failure to load or read the config
        # (network error, unknown model, unexpected attribute type) must not
        # abort the caller, so fall back to the configured default.
        print(f"Could not get tp from config for {model_name}: {e}")
        return default_tp