chatlearn/synchronizer/parameter_sync_fsdp.py (47 lines of code) (raw):

"""fsdp to vllm parameter sync group"""

import ray

from chatlearn.utils import future
from chatlearn.runtime.dist_actor import DistModel
from chatlearn.utils.error_monitor import ErrorSignalActor


def flatten(lst: list, reverse=False):
    """Concatenate the sub-sequences of ``lst`` into one flat list.

    Args:
        lst: a list of sliceable sequences (e.g. per-replica rank lists).
        reverse: if True, each sub-sequence is reversed before it is
            appended; the order of the sub-sequences themselves is kept.

    Returns:
        A new list containing the elements of every sub-sequence.
    """
    result = []
    for item in lst:
        result += item[::-1] if reverse else item
    return result


class FSDP2VllmParameterSyncGroup:
    """Parameter sync group that pushes FSDP (src) weights into vLLM (dst) actors."""

    def __init__(
        self,
        src_model: DistModel,
        dst_model: DistModel,
        group_name: str,
        frequency: int,
        error_signal: ErrorSignalActor,
    ):
        """Store the two models and assign ranks to their replicas.

        Args:
            src_model: the FSDP training model that owns the up-to-date weights.
            dst_model: the vLLM model whose weights are updated.
            group_name: name of this sync group (stored only).
            frequency: sync frequency (stored; not used by the code in this file).
            error_signal: error-monitor actor (stored; not used by the code in this file).
        """
        self.src_model = src_model
        self.dst_model = dst_model
        self.group_name = group_name
        self.error_signal = error_signal
        self.frequency = frequency
        self.setup_collective_group()

    def setup_collective_group(self):
        """Give every replica of both models a consecutive rank offset.

        Calls ``_setup_ranks`` on each replica with a running offset that
        advances by that replica's ``actor_num``. The source model is handled
        first, so its offsets start at 0.
        """
        # we put src_model first, so we don't need to change the rank of training model
        rank_offset = 0
        for model in (self.src_model, self.dst_model):
            for replica in model.replicas:
                replica._setup_ranks(rank_offset)
                rank_offset += replica.actor_num

    def sync(self, *args, **kwargs): # pylint: disable=unused-argument
        """Push every FSDP parameter from the source actors to the vLLM actors.

        Source and destination actors are paired position-by-position after
        flattening ``all_ranks`` (each sub-list of the destination ranks is
        reversed first). For each parameter name reported by source actor 0,
        every source actor produces IPC handles for that parameter and its
        paired destination actor consumes them. All pairs are waited on before
        moving to the next parameter.

        Extra positional/keyword arguments are accepted for interface
        compatibility and ignored.
        """
        # for fsdp to vllm, we only need to find the src and dst actors that are on the same GPU.
        src_model_ranks = flatten(self.src_model.all_ranks)
        # adapt for model manager: models_to_revert
        dst_model_ranks = flatten(self.dst_model.all_ranks, reverse=True)

        # The actor pairing is independent of the parameter being synced, so
        # resolve it once instead of once per parameter.
        # NOTE(review): zip() silently drops unpaired ranks if the two models
        # have different actor counts -- confirm that cannot happen.
        actor_pairs = [
            (self.src_model.get_actor(src_rank), self.dst_model.get_actor(dst_rank))
            for src_rank, dst_rank in zip(src_model_ranks, dst_model_ranks)
        ]

        param_name_list = ray.get(self.src_model.get_actor(0).get_fsdp_param_name.remote())

        for param_name in param_name_list:
            refs = []
            for src_actor, dst_actor in actor_pairs:
                # The ObjectRef is passed straight to the destination task;
                # Ray resolves it there, so no driver-side ray.get is needed.
                reduce_data_ref = src_actor.get_weight_ipc_handles_by_name.remote(param_name)
                refs.append(dst_actor.update_weights_from_ipc_handles.remote(reduce_data_ref))
            # Wait on this parameter's update calls before starting the next
            # one (semantics of future.wait live in chatlearn.utils.future).
            future.wait(refs, return_output=True)