in chatlearn/synchronizer/parameter_sync.py [0:0]
def validate_sync_results(self, send_actor, recv_actors, requires_grad, filter_fn=None, param_group="default"):
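    """Validate that parameters on the send actor match those on every recv actor.

    Resets the per-pipeline-stage sync name lists on the send and recv actors, then
    compares each pair of synced tensors: raises on NaN values and asserts on shape or
    value mismatches, slicing the source tensor when the inference TP size exceeds the
    trainer TP size (tp_num_mapping > 1).
    """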
    assert param_group in ("default", "routed", "except_routed"), (
        f"param_group must be one of 'default', 'routed', or 'except_routed', got {param_group}."
    )

    def validate():
        src_names, dst_names = self.set_sync_param_names(send_actor, recv_actors[0], requires_grad, filter_fn, param_group)
        # refresh the sync parameter name lists on the send actor and every recv actor
        pipe_stage = self.get_actor_pipe_rank(send_actor)
        res = [send_actor.reset_sync_parameters.remote(src_names, pipe_stage)]
        for recv_actor in recv_actors:
            res.append(recv_actor.reset_sync_parameters.remote(dst_names, pipe_stage))
        future.wait(res)
        src_names, dst_names = future.get([send_actor.get_parameter_to_sync_names.remote(pipe_stage),
                                           recv_actors[0].get_parameter_to_sync_names.remote(pipe_stage)])
        assert len(src_names) == len(dst_names), (
            f"expect the length of src_names and dst_names to be the same, got {len(src_names)} and {len(dst_names)}"
        )
        # check the value of src model and tgt model
        names = list(zip(src_names, dst_names))
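        # fetch each source tensor from the send actor; for non-routed groups, the extra
        # flag (self.tp_num_mapping > 1) indicates whether trainer and inference TP sizes differ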
        for src_name, dst_name in tqdm(names):
            if param_group in ("default", "except_routed"):
                src_tensor = future.get(send_actor.get_parameter_to_sync.remote(src_name, pipe_stage, True, self.tp_num_mapping > 1))
            elif param_group == "routed":
                src_tensor = future.get(send_actor.get_parameter_to_sync.remote(src_name, pipe_stage, True))
            if src_tensor.isnan().any():
                raise RuntimeError(f"weight {src_name} from send actor contains NaN, please check the checkpoint or training process.")
            src_tensor_shape = src_tensor.shape
            for recv_actor in recv_actors:
                dst_tensor = future.get(recv_actor.get_parameter_to_sync.remote(dst_name, pipe_stage, True))
                if dst_tensor.isnan().any():
                    raise RuntimeError(f"weight {dst_name} in recv actor contains NaN, please check param sync.")
                if param_group in ("default", "except_routed"):
                    if self.tp_num_mapping == 1:
                        # for trainer_tp == inference_tp
                        assert src_tensor.shape == dst_tensor.shape, (
                            f"after weight sync {src_name}: {src_tensor.shape} and {dst_name}: {dst_tensor.shape} do not match."
                        )
                        assert torch.allclose(src_tensor, dst_tensor, atol=1e-06), (
                            f"after weight sync {src_name}: {src_tensor} and {dst_name}: {dst_tensor} do not match."
                        )
                    else:
                        # for inference_tp % trainer_tp == 0 and inference_tp > trainer_tp
                        dst_tensor_shape = dst_tensor.shape
                        src_tensor = src_tensor.reshape(-1)
                        dst_tensor = dst_tensor.reshape(-1)
                        tp_slice = self.actor2rank[recv_actor] % self.tp_num_mapping
                        if src_tensor.shape == dst_tensor.shape:
                            src_tensor_slice = src_tensor
                        else:
                            assert (
                                src_tensor.shape[0] % dst_tensor.shape[0] == 0 and
                                src_tensor.shape[0] // dst_tensor.shape[0] == self.tp_num_mapping
                            ), (
                                f"the number of elements in src_tensor must be divisible by that of dst_tensor, "
                                f"got src {src_name}: {src_tensor_shape} and dst {dst_name}: {dst_tensor_shape}."
                            )
                            start = dst_tensor.shape[0] * tp_slice
                            end = start + dst_tensor.shape[0]
                            src_tensor_slice = src_tensor[start:end]
                        assert torch.allclose(src_tensor_slice, dst_tensor, atol=1e-06), (
                            f"after weight sync {src_name}_{tp_slice}: "
                            f"{src_tensor_slice.view(dst_tensor_shape)} and {dst_name}: {dst_tensor.view(dst_tensor_shape)} do not match."
                        )
                elif param_group == "routed":
                    # routed parameters are compared directly and currently require hep_num_mapping == 1
                    assert self.hep_num_mapping == 1
                    assert src_tensor.shape == dst_tensor.shape, (
                        f"after weight sync {src_name}: {src_tensor.shape} and {dst_name}: {dst_tensor.shape} do not match."
                    )
                    assert torch.allclose(src_tensor, dst_tensor, atol=1e-06), (
                        f"after weight sync {src_name}: {src_tensor} and {dst_name}: {dst_tensor} do not match."
                    )
        return True

    logger.info("Going to validate transmitted tensors...")
    validate()
    logger.info("Validation passed!")