def load_prm()

in src/sal/models/reward_models.py [0:0]


def load_prm(config: Config) -> PRM:
    """Build the PRM wrapper registered for ``config.prm_path``.

    The requested path is compared (with ``==``) against a fixed, ordered
    list of supported model identifiers; the first match wins and its
    wrapper class is constructed with the full ``config``.

    Args:
        config: Run configuration; within this function only ``prm_path``
            is read.

    Returns:
        A new instance of the wrapper class associated with
        ``config.prm_path``.

    Raises:
        NotImplementedError: If ``config.prm_path`` matches none of the
            supported identifiers.
    """
    # Kept local (not module-level) so the wrapper classes are looked up at
    # call time, exactly like the original chain of ``if`` statements.
    registry = (
        ("peiyi9979/math-shepherd-mistral-7b-prm", MathShepherd),
        ("RLHFlow/Llama3.1-8B-PRM-Deepseek-Data", RLHFFlow),
        ("Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B", SkyworkO1_1_5B),
        ("Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B", SkyworkO1_7B),
        ("Qwen/Qwen2.5-Math-PRM-7B", Qwen_2_5_Math_7B),
    )

    requested = config.prm_path
    # Linear scan with ``==`` (rather than a dict lookup) preserves the
    # original comparison semantics, including for unhashable values.
    for model_id, wrapper_cls in registry:
        if requested == model_id:
            return wrapper_cls(config)

    raise NotImplementedError(f"PRM {requested} not implemented")