in trl/trainer/grpo_trainer.py
def _move_model_to_vllm(self):
    # For DeepSpeed ZeRO-3 and FSDP, all sharded parameters must be gathered before their full values can be read
    deepspeed_plugin = self.accelerator.state.deepspeed_plugin
    zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
    if zero_stage_3:
        import deepspeed

        gather_if_zero3 = deepspeed.zero.GatheredParameters
    else:
        gather_if_zero3 = nullcontext
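    # With FSDP (or no sharding at all), nullcontext makes the gather a no-op; the FSDP
    # branches below gather shards themselves via _sync_fsdp_params_to_vllm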

    if is_peft_model(self.model):
        # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as
        # merging adapters in a sharded manner is not supported.
        # TODO: does this work with FSDP?
        with gather_if_zero3(list(self.model.parameters())):
            self.model.merge_adapter()
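            # Merging folds the adapter deltas (e.g. LoRA A/B matrices) into the base
            # weights, so vLLM receives plain base-model tensors in their original shapes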

            # Update vLLM weights while parameters are gathered
            if self.is_fsdp_enabled:  # note: if using FSDP, gather_if_zero3 is nullcontext
                # For PEFT with FSDP we need to use the memory-efficient post-order traversal
                self._sync_fsdp_params_to_vllm(self.model)
            else:
                # DeepSpeed ZeRO-3 with PEFT
                for name, param in self.model.named_parameters():
                    # When using PEFT, we need to recover the original parameter name and discard some parameters
                    name = name.removeprefix("base_model.model.").replace(".base_layer", "")
                    if self.model.prefix in name:  # skip the adapter weights themselves (e.g. "lora_" params)
                        continue
                    # When using modules_to_save, strip its prefix and discard the frozen original module
                    if "original_module" in name:
                        continue
                    name = name.replace("modules_to_save.default.", "")
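                    # e.g. a typical PEFT name "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight"
                    # ends up as "model.layers.0.self_attn.q_proj.weight"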

                    if self.vllm_mode == "server" and self.accelerator.is_main_process:
                        self.vllm_client.update_named_param(name, param.data)
                    elif self.vllm_mode == "colocate":
                        llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                        llm_model.load_weights([(name, param.data)])

            # Unmerge adapters while parameters are still gathered
            self.model.unmerge_adapter()
            # Parameters will automatically be repartitioned when exiting the context
    else:
        # For non-PEFT models, simply gather (if needed) and update each parameter individually
        if self.is_fsdp_enabled:
            self._sync_fsdp_params_to_vllm(self.model)  # use memory-efficient post-order traversal for FSDP
        else:
            for name, param in self.model.named_parameters():
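                # Gather one parameter at a time so that, under ZeRO-3, only a single
                # full tensor is materialized in memory at any moment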
                with gather_if_zero3([param]):
                    if self.vllm_mode == "server" and self.accelerator.is_main_process:
                        self.vllm_client.update_named_param(name, param.data)
                    elif self.vllm_mode == "colocate":
                        llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                        llm_model.load_weights([(name, param.data)])

    # Reset the prefix cache on vLLM: KV entries computed with the old weights are stale
    if self.vllm_mode == "server" and self.accelerator.is_main_process:
        self.vllm_client.reset_prefix_cache()
    elif self.vllm_mode == "colocate":
        self.llm.reset_prefix_cache()
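
For reference, here is a minimal sketch of what the memory-efficient post-order
traversal behind _sync_fsdp_params_to_vllm can look like. This is an illustration,
not the actual TRL implementation: it assumes the FSDP1 summon_full_params API, and
the push_weight callback is a hypothetical stand-in for the server/colocate dispatch
shown above.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def sync_fsdp_params(module, push_weight, prefix="", visited=None):
    # Post-order: visit children first, so each inner FSDP unit is gathered and
    # released on its own; peak memory is bounded by one unit, not the whole model
    if visited is None:
        visited = set()
    for child_name, child in module.named_children():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        sync_fsdp_params(child, push_weight, prefix=child_prefix, visited=visited)
    if isinstance(module, FSDP):
        # recurse=False materializes only the parameters owned by this FSDP unit
        with FSDP.summon_full_params(module, recurse=False, writeback=False):
            for param_name, param in module.named_parameters():
                full_name = f"{prefix}.{param_name}" if prefix else param_name
                full_name = full_name.replace("_fsdp_wrapped_module.", "")  # drop FSDP wrapper names
                if full_name not in visited:  # skip params already pushed by a child unit
                    visited.add(full_name)
                    push_weight(full_name, param.data)  # e.g. vllm_client.update_named_param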