def validate_sync_results()

in chatlearn/synchronizer/parameter_sync.py


    def validate_sync_results(self, send_actor, recv_actors, requires_grad, filter_fn=None, param_group="default"):
        assert param_group in ("default", "routed", "except_routed"), (
            f"param_group must be one of 'default', 'routed', or 'except_routed', got {param_group}."
        )

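        # inner helper: compares every synchronized parameter on the send actor against each recv actor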
        def validate():
            src_names, dst_names = self.set_sync_param_names(send_actor, recv_actors[0], requires_grad, filter_fn, param_group)
            # refresh the parameter-to-sync name lists on the send actor and every recv actor
            pipe_stage = self.get_actor_pipe_rank(send_actor)
            res = [send_actor.reset_sync_parameters.remote(src_names, pipe_stage)]
            for recv_actor in recv_actors:
                res.append(recv_actor.reset_sync_parameters.remote(dst_names, pipe_stage))
            future.wait(res)

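            # re-read the name lists registered for this pipe stage on both sides after the reset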
            src_names, dst_names = future.get([send_actor.get_parameter_to_sync_names.remote(pipe_stage),
                                               recv_actors[0].get_parameter_to_sync_names.remote(pipe_stage)])

            assert len(src_names) == len(dst_names), (
                f"expected src_names and dst_names to have the same length, got {len(src_names)} and {len(dst_names)}"
            )

            # check the value of src model and tgt model
            names = list(zip(src_names, dst_names))
            for src_name, dst_name in tqdm(names):
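                # fetch the trainer-side tensor; for the default/except_routed groups an extra flag
                # indicates whether inference TP exceeds trainer TP (tp_num_mapping > 1)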
                if param_group in ("default", "except_routed"):
                    src_tensor = future.get(send_actor.get_parameter_to_sync.remote(src_name, pipe_stage, True, self.tp_num_mapping > 1))
                elif param_group == "routed":
                    src_tensor = future.get(send_actor.get_parameter_to_sync.remote(src_name, pipe_stage, True))
                if src_tensor.isnan().any():
                    raise RuntimeError(f"weight {src_name} from send actor is nan, please check checkpoint or training process.")
                src_tensor_shape = src_tensor.shape
                for recv_actor in recv_actors:
                    dst_tensor = future.get(recv_actor.get_parameter_to_sync.remote(dst_name, pipe_stage, True))
                    if dst_tensor.isnan().any():
                        raise RuntimeError(f"weight {dst_name} in recv actor is nan, please check param sync.")
                    if param_group in ("default", "except_routed"):
                        if self.tp_num_mapping == 1:
                            # for trainer_tp == inference_tp
                            assert src_tensor.shape == dst_tensor.shape, (
                                f"after weight sync {src_name}: {src_tensor.shape} and {dst_name}: {dst_tensor.shape} do not match."
                            )
                            assert torch.allclose(src_tensor, dst_tensor, atol=1e-06), (
                                f"after weight sync {src_name}: {src_tensor} and {dst_name}: {dst_tensor} do not match."
                            )
                        else:
                            # for inference_tp % trainer_tp == 0 and inference_tp > trainer_tp
                            dst_tensor_shape = dst_tensor.shape
                            src_tensor = src_tensor.reshape(-1)
                            dst_tensor = dst_tensor.reshape(-1)
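                            # offset of this recv actor within the tp_num_mapping inference TP ranks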
                            tp_slice = self.actor2rank[recv_actor] % self.tp_num_mapping
                            if src_tensor.shape == dst_tensor.shape:
                                src_tensor_slice = src_tensor
                            else:
                                assert (
                                    src_tensor.shape[0] % dst_tensor.shape[0] == 0 and
                                    src_tensor.shape[0] // dst_tensor.shape[0] == self.tp_num_mapping
                                ), (
                                    f"the number of elements in src_tensor must be divisible by that of dst_tensor, "
                                    f"got src {src_name}: {src_tensor_shape} and dst {dst_name}: {dst_tensor_shape}."
                                )
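                                # compare this recv actor's shard against the matching contiguous
                                # slice of the flattened trainer tensor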
                                start = dst_tensor.shape[0] * tp_slice
                                end = start + dst_tensor.shape[0]
                                src_tensor_slice = src_tensor[start:end]
                            assert torch.allclose(src_tensor_slice, dst_tensor, atol=1e-06), (
                                f"after weight sync {src_name}_{tp_slice}: "
                                f"{src_tensor_slice.view(dst_tensor_shape)} and {dst_name}: {dst_tensor.view(dst_tensor_shape)} do not match."
                            )
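                    # routed expert weights are synced 1:1 (hep_num_mapping == 1), so the
                    # full tensors must match in both shape and value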
                    elif param_group == "routed":
                        assert self.hep_num_mapping == 1, (
                            f"expect hep_num_mapping to be 1 for routed params, got {self.hep_num_mapping}."
                        )
                        assert src_tensor.shape == dst_tensor.shape, (
                            f"after weight sync {src_name}: {src_tensor.shape} and {dst_name}: {dst_tensor.shape} do not match."
                        )
                        assert torch.allclose(src_tensor, dst_tensor, atol=1e-06), (
                            f"after weight sync {src_name}: {src_tensor} and {dst_name}: {dst_tensor} do not match."
                        )
            return True
        logger.info("Going to validate transmitted tensors...")
        validate()
        logger.info("Validation passed!")