internal/gitaly/service/raft/send_message.go (64 lines of code) (raw):

package raft import ( "errors" "io" "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/storage/raftmgr" "gitlab.com/gitlab-org/gitaly/v16/internal/structerr" "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb" ) // SendMessage is a gRPC method for sending a Raft message across nodes. func (s *Server) SendMessage(stream gitalypb.RaftService_SendMessageServer) error { for { req, err := stream.Recv() if err != nil { if errors.Is(err, io.EOF) { break } return structerr.NewInternal("receive error: %w", err) } replicaID, partitionKey, err := extractRaftMessageReq(req, s) if err != nil { return err } storageName := replicaID.GetStorageName() node, ok := s.node.(*raftmgr.Node) if !ok { return structerr.NewInternal("node is not Raft-enabled") } storageManager, err := node.GetStorage(storageName) if err != nil { return structerr.NewInternal("get storage manager: %w", err) } raftStorage, ok := storageManager.(*raftmgr.RaftEnabledStorage) if !ok { return structerr.NewInternal("storage is not Raft-enabled") } transport := raftStorage.GetTransport() if transport == nil { return structerr.NewInternal("transport not available") } if err := transport.Receive(stream.Context(), partitionKey, *req.GetMessage()); err != nil { return structerr.NewInternal("receive error: %w", err) } } return stream.SendAndClose(&gitalypb.RaftMessageResponse{}) } func extractRaftMessageReq(req *gitalypb.RaftMessageRequest, s *Server) (*gitalypb.ReplicaID, *gitalypb.PartitionKey, error) { replicaID := req.GetReplicaId() partitionKey := replicaID.GetPartitionKey() authorityName := partitionKey.GetAuthorityName() partitionID := partitionKey.GetPartitionId() // The cluster ID protects Gitaly from cross-cluster interactions, which could potentially corrupt the clusters. // This is particularly crucial after disaster recovery so that an identical cluster is restored from backup. if req.GetClusterId() == "" { return nil, nil, structerr.NewInvalidArgument("cluster_id is required") } // Let's assume we have a single cluster per node for now. if req.GetClusterId() != s.cfg.Raft.ClusterID { return nil, nil, structerr.NewPermissionDenied("message from wrong cluster: got %q, want %q", req.GetClusterId(), s.cfg.Raft.ClusterID) } if authorityName == "" { return nil, nil, structerr.NewInvalidArgument("authority_name is required") } if partitionID == 0 { return nil, nil, structerr.NewInvalidArgument("partition_id is required") } return replicaID, partitionKey, nil }