// internal/gitaly/storage/raftmgr/grpc_transport.go
package raftmgr

import (
	"archive/tar"
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"sync"

	"gitlab.com/gitlab-org/gitaly/v16/internal/archive"
	"gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/config"
	"gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/storage"
	"gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/storage/mode"
	"gitlab.com/gitlab-org/gitaly/v16/internal/grpc/client"
	"gitlab.com/gitlab-org/gitaly/v16/internal/log"
	"gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb"
	"gitlab.com/gitlab-org/gitaly/v16/streamio"
	"go.etcd.io/raft/v3/raftpb"
	"golang.org/x/sync/errgroup"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"
)

// Transport defines the interface for exchanging Raft protocol messages between nodes.
type Transport interface {
	// Send dispatches a batch of Raft messages. It receives a context, a log reader used to
	// resolve the WAL directory of each log entry, the partition key, and the messages to
	// send. It returns an error if sending fails. The implementation must respect the input
	// context's cancellation.
	Send(ctx context.Context, logReader storage.LogReader, partitionKey *gitalypb.PartitionKey, messages []raftpb.Message) error
	// Receive receives a single Raft message for the given partition and processes it.
	Receive(ctx context.Context, partitionKey *gitalypb.PartitionKey, raftMsg raftpb.Message) error
	// SendSnapshot sends a snapshot of the partition's state to the node addressed by the message.
	SendSnapshot(ctx context.Context, partitionKey *gitalypb.PartitionKey, message raftpb.Message, snapshot *ReplicaSnapshot) error
}
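
// A minimal usage sketch (hypothetical wiring; the logger, config, routing table,
// registry, connection pool, log reader, partition key, and messages all come from
// the caller's setup):
//
//	transport := NewGrpcTransport(logger, cfg, routingTable, registry, pool)
//	if err := transport.Send(ctx, logReader, partitionKey, msgs); err != nil {
//		// err joins the failures from every destination node that could not be reached
//	}
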
// GrpcTransport is a gRPC transport implementation for sending Raft messages across nodes.
type GrpcTransport struct {
logger log.Logger
cfg config.Cfg
routingTable RoutingTable
registry ReplicaRegistry
connectionPool *client.Pool
mutex sync.Mutex
}
// NewGrpcTransport creates a new GrpcTransport instance.
func NewGrpcTransport(logger log.Logger, cfg config.Cfg, routingTable RoutingTable, registry ReplicaRegistry, conns *client.Pool) *GrpcTransport {
return &GrpcTransport{
logger: logger,
cfg: cfg,
routingTable: routingTable,
registry: registry,
connectionPool: conns,
}
}
// Send sends Raft messages to the appropriate nodes.
func (t *GrpcTransport) Send(ctx context.Context, logReader storage.LogReader, partitionKey *gitalypb.PartitionKey, messages []raftpb.Message) error {
messagesByNode, err := t.prepareRaftMessageRequests(ctx, logReader, partitionKey, messages)
if err != nil {
return fmt.Errorf("preparing raft messages: %w", err)
}
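	// Fan out one goroutine per destination address, collecting every node's error in
	// errCh so that a failure on one node does not mask failures on others.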
g := &errgroup.Group{}
errCh := make(chan error, len(messagesByNode))
	for addr, reqs := range messagesByNode {
		addr, reqs := addr, reqs // capture range variables for the goroutine
		g.Go(func() error {
memberID := reqs[0].GetReplicaId().GetMemberId()
if err := t.sendToNode(ctx, addr, reqs); err != nil {
errCh <- fmt.Errorf("node %d: %w", memberID, err)
return err
}
return nil
})
}
_ = g.Wait() // we are collecting errors in the errCh
close(errCh)
var errs []error
for err := range errCh {
errs = append(errs, err)
}
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
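
// prepareRaftMessageRequests groups outgoing messages by destination address, packing WAL
// log data into any normal entry that does not yet carry it.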
func (t *GrpcTransport) prepareRaftMessageRequests(ctx context.Context, logReader storage.LogReader, partitionKey *gitalypb.PartitionKey, msgs []raftpb.Message) (map[string][]*gitalypb.RaftMessageRequest, error) {
messagesByAddress := make(map[string][]*gitalypb.RaftMessageRequest)
messagesByAddressMutex := sync.Mutex{}
g := &errgroup.Group{}
	for _, msg := range msgs {
		msg := msg // capture the range variable so &msg below refers to this iteration's copy
		g.Go(func() error {
for j := range msg.Entries {
if msg.Entries[j].Type != raftpb.EntryNormal {
continue
}
			var raftMsg gitalypb.RaftEntry
			// t.mutex serializes access to Entry.Data, which may be shared between
			// messages that are prepared concurrently.
			t.mutex.Lock()
			err := proto.Unmarshal(msg.Entries[j].Data, &raftMsg)
			t.mutex.Unlock()
if err != nil {
return fmt.Errorf("unmarshalling entry type: %w", err)
}
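			// Entries proposed on this node reference their WAL directory on disk rather
			// than carrying the file contents, so pack the directory into the entry before
			// sending it over the wire.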
if raftMsg.GetData().GetPacked() == nil {
lsn := storage.LSN(msg.Entries[j].Index)
path := logReader.GetEntryPath(lsn)
if err := t.packLogData(ctx, lsn, &raftMsg, path); err != nil {
return fmt.Errorf("packing log data: %w", err)
}
}
data, err := proto.Marshal(&raftMsg)
if err != nil {
return fmt.Errorf("marshal entry: %w", err)
}
t.mutex.Lock()
msg.Entries[j].Data = data
t.mutex.Unlock()
}
replica, err := t.routingTable.Translate(partitionKey, msg.To)
if err != nil {
return fmt.Errorf("translate memberID %d: %w", msg.To, err)
}
addr := replica.GetMetadata().GetAddress()
messagesByAddressMutex.Lock()
			// The address is intentionally omitted from the request: it would increase the
			// payload size and is not needed on the receiver side.
messagesByAddress[addr] = append(messagesByAddress[addr], &gitalypb.RaftMessageRequest{
ClusterId: t.cfg.Raft.ClusterID,
ReplicaId: &gitalypb.ReplicaID{
PartitionKey: partitionKey,
MemberId: msg.To,
StorageName: replica.GetStorageName(),
},
Message: &msg,
})
messagesByAddressMutex.Unlock()
return nil
})
}
err := g.Wait()
if err != nil {
return nil, err
}
return messagesByAddress, nil
}
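
// sendToNode streams the given requests to a single node over a RaftService.SendMessage
// client stream and waits for the server's final response.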
func (t *GrpcTransport) sendToNode(ctx context.Context, addr string, reqs []*gitalypb.RaftMessageRequest) error {
	// Get a connection to the node from the pool.
	conn, err := t.connectionPool.Dial(ctx, addr, t.cfg.Auth.Token)
	if err != nil {
		return fmt.Errorf("get connection to address %s: %w", addr, err)
	}
	// The local name avoids shadowing the imported client package.
	raftClient := gitalypb.NewRaftServiceClient(conn)
	stream, err := raftClient.SendMessage(ctx)
if err != nil {
return fmt.Errorf("create stream to address %s: %w", addr, err)
}
for _, req := range reqs {
if err := stream.Send(req); err != nil {
return fmt.Errorf("send request to address %s: %w", addr, err)
}
}
if _, err := stream.CloseAndRecv(); err != nil {
return fmt.Errorf("close stream to address %s: %w", addr, err)
}
return nil
}
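
// packLogData archives the WAL log-entry directory at logEntryPath into a tarball and
// embeds it in the entry's LogData so that the receiver can reconstruct the files on disk.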
func (t *GrpcTransport) packLogData(ctx context.Context, lsn storage.LSN, message *gitalypb.RaftEntry, logEntryPath string) error {
var logData bytes.Buffer
if err := archive.WriteTarball(ctx, t.logger.WithFields(log.Fields{
"raft.component": "WAL archiver",
"raft.log_entry_lsn": lsn,
"raft.log_entry_path": logEntryPath,
}), &logData, logEntryPath, "."); err != nil {
return fmt.Errorf("archiving WAL log entry: %w", err)
}
message.Data = &gitalypb.RaftEntry_LogData{
LocalPath: []byte(logEntryPath),
Packed: logData.Bytes(),
}
return nil
}

// Receive processes a single inbound Raft message for the given partition, unpacking any
// embedded WAL log data before stepping the message into the replica.
func (t *GrpcTransport) Receive(ctx context.Context, partitionKey *gitalypb.PartitionKey, raftMsg raftpb.Message) error {
	// Look up the replica in the registry; all entries in this message are assumed to
	// belong to the same partition key.
replica, err := t.registry.GetReplica(partitionKey)
if err != nil {
return status.Errorf(codes.NotFound, "replica not found for partition %d: %v",
partitionKey.GetPartitionId(), err)
}
for _, entry := range raftMsg.Entries {
var msg gitalypb.RaftEntry
if err := proto.Unmarshal(entry.Data, &msg); err != nil {
return status.Errorf(codes.InvalidArgument, "failed to unmarshal message: %v", err)
}
if msg.GetData().GetPacked() != nil {
if err := unpackLogData(&msg, replica.GetEntryPath(storage.LSN(entry.Index))); err != nil {
return status.Errorf(codes.Internal, "failed to unpack log data: %v", err)
}
}
}
	// Step the message, with its unpacked entries, into the replica's Raft state machine.
if err := replica.Step(ctx, raftMsg); err != nil {
return status.Errorf(codes.Internal, "failed to step message: %v", err)
}
return nil
}
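
// unpackLogData extracts the tarball embedded in the entry into the WAL directory at
// logEntryPath, recreating the directories and regular files of the original log entry.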
func unpackLogData(msg *gitalypb.RaftEntry, logEntryPath string) error {
logData := msg.GetData().GetPacked()
if err := os.MkdirAll(filepath.Dir(logEntryPath), mode.Directory); err != nil {
return fmt.Errorf("creating WAL directory: %w", err)
}
tarReader := tar.NewReader(bytes.NewReader(logData))
for {
		header, err := tarReader.Next()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return fmt.Errorf("reading tar entry: %w", err)
		}
		actualName := header.Name
		// Guard against path traversal ("zip slip") via crafted entry names.
		if !filepath.IsLocal(actualName) {
			return fmt.Errorf("invalid entry name in archive: %q", actualName)
		}
		switch header.Typeflag {
		case tar.TypeDir:
			// Create the directory if it does not exist yet.
			if _, err := os.Stat(filepath.Join(logEntryPath, actualName)); os.IsNotExist(err) {
				if err := os.Mkdir(filepath.Join(logEntryPath, actualName), mode.Directory); err != nil {
					return fmt.Errorf("creating directory: %w", err)
				}
			}
case tar.TypeReg:
if err := func() error {
path := filepath.Join(logEntryPath, actualName)
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode.File)
if err != nil {
return fmt.Errorf("writing log entry file: %w", err)
}
defer f.Close()
if _, err := io.Copy(f, tarReader); err != nil {
return fmt.Errorf("writing log entry file: %w", err)
}
return nil
}(); err != nil {
return err
}
}
}
return nil
}
// SendSnapshot sends a snapshot of a partition to a specified node in the cluster.
func (t *GrpcTransport) SendSnapshot(ctx context.Context, pk *gitalypb.PartitionKey, message raftpb.Message, snapshot *ReplicaSnapshot) (returnedErr error) {
followerMemberID := message.To
// Find replica's address as recipient of snapshot
replica, err := t.routingTable.Translate(pk, followerMemberID)
if err != nil {
return fmt.Errorf("translate memberID %d: %w", followerMemberID, err)
}
addr := replica.GetMetadata().GetAddress()
	// Get the Raft client for the follower node; the local name avoids shadowing the
	// imported client package.
	raftClient, err := t.getRaftClient(ctx, addr)
	if err != nil {
		return err
	}
	// Create a client stream.
	stream, err := raftClient.SendSnapshot(ctx)
if err != nil {
return fmt.Errorf("failed to create stream: %w", err)
}
// Ensure stream is closed properly
defer func() {
if _, err := stream.CloseAndRecv(); err != nil {
returnedErr = errors.Join(returnedErr, formatError(err, followerMemberID, "close stream"))
}
}()
if err := stream.Send(&gitalypb.RaftSnapshotMessageRequest{
RaftSnapshotPayload: &gitalypb.RaftSnapshotMessageRequest_RaftMsg{
RaftMsg: &gitalypb.RaftMessageRequest{
ClusterId: t.cfg.Raft.ClusterID,
ReplicaId: &gitalypb.ReplicaID{
StorageName: replica.GetStorageName(),
PartitionKey: &gitalypb.PartitionKey{
AuthorityName: pk.GetAuthorityName(),
PartitionId: pk.GetPartitionId(),
},
},
Message: &message,
},
},
}); err != nil {
return fmt.Errorf("failed to send raft message: %w", err)
}
// Send snapshot data in chunks to the server
sw := streamio.NewWriter(func(p []byte) error {
select {
	case <-stream.Context().Done():
		return fmt.Errorf("context cancelled while sending snapshot: %w", stream.Context().Err())
default:
return stream.Send(&gitalypb.RaftSnapshotMessageRequest{
RaftSnapshotPayload: &gitalypb.RaftSnapshotMessageRequest_Chunk{
Chunk: p,
},
})
}
})
sent, err := io.Copy(sw, snapshot.file)
if err != nil {
return fmt.Errorf("failed to send chunk, %d bytes sent: %w", sent, err)
}
return
}
// getRaftClient returns a Raft client connection for the given address
func (t *GrpcTransport) getRaftClient(ctx context.Context, addr string) (gitalypb.RaftServiceClient, error) {
	// Get a connection to the node from the pool.
conn, err := t.connectionPool.Dial(ctx, addr, t.cfg.Auth.Token)
if err != nil {
return nil, fmt.Errorf("get connection to address %s: %w", addr, err)
}
return gitalypb.NewRaftServiceClient(conn), nil
}
// formatError formats gRPC errors with specific messages based on error codes.
// It handles common connection-related errors and uses the provided default message for other errors.
func formatError(err error, memberID uint64, defaultMsg string) error {
switch status.Code(err) {
case codes.Unavailable:
return fmt.Errorf("connection to node %d lost: %w", memberID, err)
case codes.Canceled:
return fmt.Errorf("node %d rejected request: connection canceled: %w", memberID, err)
case codes.Aborted:
return fmt.Errorf("node %d aborted request: %w", memberID, err)
default:
return fmt.Errorf("%s to node %d: %w", defaultMsg, memberID, err)
}
}