internal/gitaly/service/raft/send_snapshot.go (59 lines of code) (raw):
package raft
import (
"errors"
"fmt"
"io"
"os"
"path/filepath"
"gitlab.com/gitlab-org/gitaly/v16/internal/safe"
"gitlab.com/gitlab-org/gitaly/v16/internal/structerr"
"gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb"
"gitlab.com/gitlab-org/gitaly/v16/streamio"
)
// SendSnapshot streams a snapshot from a leader to a follower node in the raft network.
func (s *Server) SendSnapshot(stream gitalypb.RaftService_SendSnapshotServer) (returnErr error) {
	// The first message on the stream carries the raft message describing the snapshot.
	clientMsg, err := stream.Recv()
	if err != nil {
		return structerr.NewInternal("receive error: %w", err)
	}
	raftMsg := clientMsg.GetRaftMsg()
	_, partitionKey, err := extractRaftMessageReq(raftMsg, s)
	if err != nil {
		return err
	}

	// The term and index below are read from the embedded raft message; without this
	// guard a request lacking it would dereference a nil pointer and panic the handler.
	msg := raftMsg.GetMessage()
	if msg == nil {
		return structerr.NewInvalidArgument("missing raft message")
	}

	fname := fmt.Sprintf("%016d-%016d-%016d%s", partitionKey.GetPartitionId(), msg.Term, msg.Index, ".snap")
	snapshotPath := filepath.Join(s.cfg.Raft.SnapshotDir, fname)
	snapshotFile, err := os.Create(snapshotPath)
	if err != nil {
		return fmt.Errorf("create snapshot file: %w", err)
	}
	defer func() {
		if returnErr == nil {
			return
		}
		// On failures before the explicit Close below (e.g. a receive error during the
		// copy) the file is still open, so close it to avoid leaking the descriptor and
		// so the removal also works on platforms that refuse to delete open files. If it
		// was already closed this just returns os.ErrClosed; any close error is
		// irrelevant here because the partial snapshot is being discarded.
		_ = snapshotFile.Close()
		returnErr = errors.Join(returnErr, os.Remove(snapshotPath))
	}()

	// Every following message carries a chunk of snapshot data. Recv returns io.EOF
	// once the client has finished sending, which ends the copy below.
	sr := streamio.NewReader(func() ([]byte, error) {
		clientMsg, err := stream.Recv()
		if err != nil {
			return nil, err
		}
		return clientMsg.GetChunk(), nil
	})
	snapshotSize, err := io.Copy(snapshotFile, sr)
	if err != nil {
		return structerr.NewInternal("write error: %w", err)
	}

	// Close explicitly (not only in the defer) so a close error is reported, and so the
	// file is fully written before it is synced by path below.
	if err := snapshotFile.Close(); err != nil {
		return fmt.Errorf("close snapshot file %q: %w", snapshotPath, err)
	}

	// Sync the snapshot file before acknowledging receipt to the sender.
	syncer := safe.NewSyncer()
	if err := syncer.Sync(stream.Context(), snapshotPath); err != nil {
		return fmt.Errorf("sync snapshot file: %w", err)
	}

	// All chunks have been received and stored; report where the snapshot lives and
	// how many bytes were written.
	if err := stream.SendAndClose(&gitalypb.RaftSnapshotMessageResponse{
		Destination:  snapshotPath,
		SnapshotSize: uint64(snapshotSize),
	}); err != nil {
		return fmt.Errorf("failed to send server message: %w", err)
	}

	return nil
}