internal/gitaly/gitaly.go
// Package gitaly provides a client for interacting with GitLab's Gitaly service.
// It retrieves repository data, tracks file changes between Git revisions, and
// streams blob contents in batches for indexing. The package supports both the
// SHA1 and SHA256 object formats and manages repository connections via gRPC.
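//
// A minimal usage sketch (cfg, projectID, limitFileSize, fromSHA, toSHA, put,
// and del are caller-supplied; error handling elided):
//
//	client, _ := NewGitalyClient(ctx, cfg, projectID, limitFileSize)
//	defer client.Close()
//	client.FromHash, client.ToHash = fromSHA, toSHA
//	_ = client.EachFileChange(put, del)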
package gitaly
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
gitalyauth "gitlab.com/gitlab-org/gitaly/v16/auth"
gitalyclient "gitlab.com/gitlab-org/gitaly/v16/client"
pb "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb"
grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
const (
NullTreeSHA = "4b825dc642cb6eb9a060e54bf8d69288fbee4904" // SHA1("tree 0\0")
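// ZeroSHA is the all-zeros SHA1 object ID, reported when no starting revision exists.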
ZeroSHA = "0000000000000000000000000000000000000000"
NullTreeSHA256 = "6ef19b41225c5369f1c104d45d8d85efa9b057b53b14b4b9b939dd74decc5321" // SHA256("tree 0\0")
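// ZeroSHA256 is the all-zeros SHA256 object ID, reported when no starting revision exists.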
ZeroSHA256 = "0000000000000000000000000000000000000000000000000000000000000000"
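// FormatSha256 is the object format name Gitaly's ObjectFormat RPC reports for SHA256 repositories.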
FormatSha256 = "OBJECT_FORMAT_SHA256"
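// ClientName identifies this client in the correlation metadata attached to outgoing Gitaly RPCs.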
ClientName = "gitlab-zoekt-indexer"
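// SubmoduleFileMode is the Git tree entry mode (gitlink) used for submodules; such entries are skipped during indexing.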
SubmoduleFileMode = 0160000
IndexBatchSize = 128 // Much larger batches (around 10k revisions) can fail with: rpc error: code = Internal desc = processing blobs: rev-list: starting process: argument list too long
)
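// StorageConfig describes how to reach a repository on a Gitaly storage: the
// server address and token plus the storage name and repository paths.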
type StorageConfig struct {
Address string `json:"address"`
Token string `json:"token"`
StorageName string `json:"storage"`
RelativePath string `json:"relative_path"`
ProjectPath string `json:"project_path"`
TokenVersion int `json:"token_version"`
}
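// GitalyClient bundles a gRPC connection to Gitaly with the per-service clients
// used to inspect a single repository and walk its changes.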
type GitalyClient struct {
conn *grpc.ClientConn
repository *pb.Repository
blobServiceClient pb.BlobServiceClient
diffServiceClient pb.DiffServiceClient
repositoryServiceClient pb.RepositoryServiceClient
refServiceClient pb.RefServiceClient
commitServiceClient pb.CommitServiceClient
ctx context.Context
FromHash string
ToHash string
limitFileSize int64
}
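// File is a single blob scheduled for indexing, identified by its path and object ID.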
type File struct {
Path string
Content []byte
Oid string
Size int64
TooLarge bool
}
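// PutFunc receives each file that should be added to or updated in the index.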
type PutFunc func(file *File) error
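// DelFunc receives each path that should be removed from the index.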
type DelFunc func(path string) error
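// HashFinder exposes the starting revision and the default empty-tree SHA for
// the repository's object format.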
type HashFinder interface {
getFromHash() string
getDefaultSHAFromHash() (string, error)
}
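// NewGitalyClient dials Gitaly at the configured address and returns a client
// scoped to the given project's repository.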
func NewGitalyClient(ctx context.Context, config *StorageConfig, projectID uint32, limitFileSize int64) (*GitalyClient, error) {
var RPCCred credentials.PerRPCCredentials
if config.TokenVersion == 0 || config.TokenVersion == 2 {
RPCCred = gitalyauth.RPCCredentialsV2(config.Token)
} else {
return nil, errors.New("unknown token version")
}
connOpts := append(
gitalyclient.DefaultDialOpts,
grpc.WithPerRPCCredentials(RPCCred),
grpc.WithStreamInterceptor(
grpccorrelation.StreamClientCorrelationInterceptor(
grpccorrelation.WithClientName(ClientName),
),
),
grpc.WithUnaryInterceptor(
grpccorrelation.UnaryClientCorrelationInterceptor(
grpccorrelation.WithClientName(ClientName),
),
),
)
conn, err := gitalyclient.Dial(config.Address, connOpts)
if err != nil {
return nil, fmt.Errorf("did not connect: %w", err)
}
repository := &pb.Repository{
StorageName: config.StorageName,
RelativePath: config.RelativePath,
GlProjectPath: config.ProjectPath,
GlRepository: fmt.Sprint(projectID),
}
client := &GitalyClient{
conn: conn,
repository: repository,
blobServiceClient: pb.NewBlobServiceClient(conn),
diffServiceClient: pb.NewDiffServiceClient(conn),
repositoryServiceClient: pb.NewRepositoryServiceClient(conn),
refServiceClient: pb.NewRefServiceClient(conn),
commitServiceClient: pb.NewCommitServiceClient(conn),
limitFileSize: limitFileSize,
ctx: ctx,
}
return client, nil
}
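// Close closes the underlying gRPC connection.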
func (gc *GitalyClient) Close() {
gc.conn.Close() //nolint:errcheck,gosec
}
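// IsValidSHA reports whether the given revision resolves to a commit in the repository.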
func (gc *GitalyClient) IsValidSHA(SHA string) bool {
request := &pb.FindCommitRequest{
Repository: gc.repository,
Revision: []byte(SHA),
}
commit, err := gc.commitServiceClient.FindCommit(gc.ctx, request)
return err == nil && commit.Commit != nil
}
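// GetCurrentSHA returns the commit ID at the tip of the default branch. It
// returns an empty string if the repository or its default branch does not exist.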
func (gc *GitalyClient) GetCurrentSHA() (string, error) {
repoExistsResponse, err := gc.repositoryServiceClient.RepositoryExists(gc.ctx, &pb.RepositoryExistsRequest{
Repository: gc.repository,
})
if err != nil {
return "", err
}
if repoExistsResponse.Exists {
defaultBranchName, err := gc.findDefaultBranchName()
if err != nil {
return "", err
}
if len(defaultBranchName) == 0 {
return "", nil
}
request := &pb.FindCommitRequest{
Repository: gc.repository,
Revision: defaultBranchName,
}
response, err := gc.commitServiceClient.FindCommit(gc.ctx, request)
if err != nil {
return "", fmt.Errorf("cannot look up HEAD: %w", err)
}
return response.Commit.Id, nil
} else {
return "", nil
}
}
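// findDefaultBranchName returns the repository's default branch name as reported by Gitaly.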
func (gc *GitalyClient) findDefaultBranchName() ([]byte, error) {
request := &pb.FindDefaultBranchNameRequest{
Repository: gc.repository,
}
response, err := gc.refServiceClient.FindDefaultBranchName(gc.ctx, request)
if err != nil {
return nil, fmt.Errorf("cannot find a default branch: %w", err)
}
return response.Name, nil
}
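// getDefaultSHAFromHash returns the empty-tree SHA that matches the repository's
// object format (SHA1 or SHA256).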
func (gc *GitalyClient) getDefaultSHAFromHash() (string, error) {
request := &pb.ObjectFormatRequest{Repository: gc.repository}
response, err := gc.repositoryServiceClient.ObjectFormat(gc.ctx, request)
if err != nil {
return "", fmt.Errorf("could not call rpc.ObjectFormat: %w", err)
}
if response.Format.String() == FormatSha256 {
return NullTreeSHA256, nil
}
return NullTreeSHA, nil
}
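// determineFromHash normalizes the starting revision: the zero SHAs and an
// empty value are mapped to the empty-tree SHA of the repository's object
// format, so the diff covers the full tree.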
func determineFromHash(hf HashFinder) (string, error) {
hash := hf.getFromHash()
switch hash {
case ZeroSHA:
return NullTreeSHA, nil
case ZeroSHA256:
return NullTreeSHA256, nil
case "":
return hf.getDefaultSHAFromHash()
}
return hash, nil
}
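// getFromHash returns the configured starting revision; it satisfies HashFinder.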
func (gc *GitalyClient) getFromHash() string {
return gc.FromHash
}
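// EachFileChange walks the paths that changed between FromHash and ToHash.
// Deleted paths (and the old paths of renames) are passed to del; added,
// modified, copied, and renamed paths are grouped by blob ID, and their
// contents are fetched and passed to put in batches of IndexBatchSize.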
func (gc *GitalyClient) EachFileChange(put PutFunc, del DelFunc) error {
var err error
gc.FromHash, err = determineFromHash(gc)
if err != nil {
return fmt.Errorf("determine from hash: %w", err)
}
request := &pb.FindChangedPathsRequest{
Repository: gc.repository,
Requests: []*pb.FindChangedPathsRequest_Request{{
Type: &pb.FindChangedPathsRequest_Request_TreeRequest_{
TreeRequest: &pb.FindChangedPathsRequest_Request_TreeRequest{
LeftTreeRevision: gc.FromHash,
RightTreeRevision: gc.ToHash,
},
},
}},
}
ctx, cancel := context.WithCancel(gc.ctx)
defer cancel()
stream, err := gc.diffServiceClient.FindChangedPaths(ctx, request)
if err != nil {
return fmt.Errorf("find changed paths: %w", err)
}
pathsByBlobID := map[string][]string{}
for {
c, errFindChangedPathsResp := stream.Recv()
if errFindChangedPathsResp == io.EOF { //nolint:errorlint
break
}
if errFindChangedPathsResp != nil {
return fmt.Errorf("recv: %w", errFindChangedPathsResp)
}
for _, change := range c.Paths {
// Submodules are skipped to mirror the go-git implementation. Handling them
// via Gitaly may not be expensive, but that still needs investigation.
if change.OldMode == SubmoduleFileMode || change.NewMode == SubmoduleFileMode {
continue
}
switch change.GetStatus() {
case pb.ChangedPaths_DELETED:
if err = del(string(change.Path)); err != nil {
return fmt.Errorf("del: %w", err)
}
case pb.ChangedPaths_RENAMED:
if err = del(string(change.OldPath)); err != nil {
return fmt.Errorf("del: %w", err)
}
// Fallthrough to index the blob at its new path.
fallthrough
case pb.ChangedPaths_ADDED, pb.ChangedPaths_MODIFIED, pb.ChangedPaths_COPIED:
pathsByBlobID[change.NewBlobId] = append(pathsByBlobID[change.NewBlobId], string(change.Path))
case pb.ChangedPaths_TYPE_CHANGE:
slog.Warn("status is not supported to perform indexing", "status", change.GetStatus(), "repoId", gc.repository.GlRepository)
default:
slog.Warn("status is not supported to perform indexing", "status", change.GetStatus(), "repoId", gc.repository.GlRepository)
}
}
}
revisions := make([]string, 0, IndexBatchSize)
for blobID := range pathsByBlobID {
revisions = append(revisions, blobID)
if len(revisions) == IndexBatchSize {
err = gc.bulkIndex(ctx, revisions, pathsByBlobID, put)
if err != nil {
return fmt.Errorf("bulkIndex: %w", err)
}
revisions = revisions[:0]
}
}
err = gc.bulkIndex(ctx, revisions, pathsByBlobID, put) // index the last remaining batch
if err != nil {
return fmt.Errorf("bulkIndex: %w", err)
}
return nil
}
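// bulkIndex lists the blobs for one batch of revisions and emits a File for
// every path that references each blob. Blob contents can span multiple stream
// messages, so chunks are accumulated until the next blob's OID (or the end of
// the stream) is seen.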
func (gc *GitalyClient) bulkIndex(ctx context.Context, revisions []string, pathsByBlobID map[string][]string, put PutFunc) error {
if len(revisions) == 0 {
return nil
}
listBlobsRequest := &pb.ListBlobsRequest{
Repository: gc.repository,
Revisions: revisions,
BytesLimit: gc.limitFileSize,
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
blobsStream, err := gc.blobServiceClient.ListBlobs(ctx, listBlobsRequest)
if err != nil {
return fmt.Errorf("ListBlobs: %w", err)
}
streamStart := true
var data []byte
var oid string
var size int64
for {
listblobsResponse, err := blobsStream.Recv()
if err == io.EOF { //nolint:errorlint
if err = gc.buildFilesForOid(pathsByBlobID[oid], oid, data, size, put); err != nil {
return err
}
break
}
if err != nil {
return fmt.Errorf("ListBlobs.Recv: %w", err)
}
for _, blob := range listblobsResponse.GetBlobs() {
if !streamStart && blob.Oid != "" {
if err = gc.buildFilesForOid(pathsByBlobID[oid], oid, data, size, put); err != nil {
return err
}
data = nil
}
streamStart = false
data = append(data, blob.Data...)
if blob.Oid != "" {
oid = blob.Oid
size = blob.Size
}
}
}
return nil
}
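// buildFilesForOid invokes put once per path mapped to the given blob, marking
// the file as TooLarge when its size exceeds the configured limit.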
func (gc *GitalyClient) buildFilesForOid(paths []string, oid string, data []byte, size int64, put PutFunc) error {
for _, path := range paths {
file := &File{
Path: path,
Oid: oid,
Content: data,
Size: size,
TooLarge: size > gc.limitFileSize,
}
if err := put(file); err != nil {
return fmt.Errorf("put: %w", err)
}
}
return nil
}