proxy/proxyserver/prefetch.go (370 lines of code) (raw):

package proxyserver

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/docker/distribution"
	"github.com/docker/distribution/manifest/manifestlist"
	"github.com/docker/distribution/manifest/schema2"
	"github.com/uber-go/tally"
	"github.com/uber/kraken/build-index/tagclient"
	"github.com/uber/kraken/core"
	"github.com/uber/kraken/origin/blobclient"
	"github.com/uber/kraken/utils/httputil"
	"github.com/uber/kraken/utils/log"
	"go.uber.org/zap"
)

// Constants for prefetch status reported in JSON responses.
const (
	StatusSuccess = "success"
	StatusFailure = "failure"
)

// PrefetchHandler handles prefetch requests: it resolves an image tag to its
// manifest(s) via the build-index, then warms the origin cluster's cache with
// the image's layer blobs.
type PrefetchHandler struct {
	clusterClient blobclient.ClusterClient // downloads/prefetches blobs from the origin cluster
	tagClient     tagclient.Client         // resolves tags to manifest digests via build-index
	tagParser     TagParser                // splits incoming tag strings into namespace and name

	minBlobSizeBytes int64 // Minimum size in bytes for a blob to be prefetched. 0 means no minimum.
	maxBlobSizeBytes int64 // Maximum size in bytes for a blob to be prefetched. 0 means no maximum.

	// v1Synchronous makes HandleV1 download blobs before returning instead of
	// in a background goroutine.
	v1Synchronous bool

	metrics            tally.Scope
	getManifestLatency tally.Histogram // latency of manifest blob downloads from origins
	getTagLatency      tally.Histogram // latency of tag->digest lookups against build-index
}

// blobInfo holds digest and size information for a blob.
type blobInfo struct {
	digest core.Digest
	size   int64
}

// Request and response payloads.

// prefetchBody is the JSON request body accepted by the prefetch endpoints.
type prefetchBody struct {
	Tag     string `json:"tag"`
	TraceId string `json:"trace_id"`
}

// prefetchResponse is the JSON body returned on success.
type prefetchResponse struct {
	Tag        string `json:"tag"`
	Prefetched bool   `json:"prefetched"`
	Status     string `json:"status"`
	Message    string `json:"message"`
	TraceId    string `json:"trace_id"`
}

// prefetchError is the JSON body returned on failure.
type prefetchError struct {
	Error      string `json:"error"`
	Prefetched bool   `json:"prefetched"`
	Status     string `json:"status"`
	Message    string `json:"message"`
	TraceId    string `json:"trace_id,omitempty"`
}

// TagParser splits a full image tag string into its namespace and name parts.
type TagParser interface {
	ParseTag(tag string) (namespace, name string, err error)
}

// DefaultTagParser is the TagParser used when none is injected.
type DefaultTagParser struct{}

// ParseTag implements the TagParser interface.
// Expects tag strings in the format <hostname>/<namespace>/<imagename:tag>.
func (p *DefaultTagParser) ParseTag(tag string) (namespace, name string, err error) { parts := strings.Split(tag, "/") if len(parts) < 3 { return "", "", fmt.Errorf("invalid tag format: %s", tag) } return parts[1], parts[2], nil } // NewPrefetchHandler constructs a new PrefetchHandler. func NewPrefetchHandler( client blobclient.ClusterClient, tagClient tagclient.Client, tagParser TagParser, metrics tally.Scope, minBlobSizeBytes int64, maxBlobSizeBytes int64, v1Synchronous bool, ) *PrefetchHandler { if tagParser == nil { tagParser = &DefaultTagParser{} } // Apply defaults if not configured if minBlobSizeBytes == 0 { minBlobSizeBytes = int64(DefaultPrefetchMinBlobSize) } if maxBlobSizeBytes == 0 { maxBlobSizeBytes = int64(DefaultPrefetchMaxBlobSize) } m := metrics.SubScope("prefetch") return &PrefetchHandler{ clusterClient: client, tagClient: tagClient, tagParser: tagParser, v1Synchronous: v1Synchronous, minBlobSizeBytes: minBlobSizeBytes, maxBlobSizeBytes: maxBlobSizeBytes, metrics: m, getManifestLatency: m.Histogram("download_manifest_latency", tally.MustMakeExponentialDurationBuckets(1*time.Second, 2, 12)), getTagLatency: m.Histogram("get_tag_latency", tally.MustMakeExponentialDurationBuckets(100*time.Millisecond, 2, 10)), } } // newPrefetchSuccessResponse constructs a successful response. func newPrefetchSuccessResponse(tag, msg, traceId string) *prefetchResponse { return &prefetchResponse{ Tag: tag, Prefetched: true, Status: StatusSuccess, Message: msg, TraceId: traceId, } } // newPrefetchError constructs an error response. func newPrefetchError(status int, msg, traceId string) *prefetchError { return &prefetchError{ Error: http.StatusText(status), Prefetched: false, Status: StatusFailure, Message: msg, TraceId: traceId, } } // writeJSON writes the JSON payload with the given HTTP status. 
func writeJSON(w http.ResponseWriter, status int, payload interface{}) { w.Header().Set("Content-Type", "application/json") response, err := json.Marshal(payload) if err != nil { log.With("payload", payload).Errorf("Failed to marshal JSON: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } // Write the status code and the encoded JSON response. w.WriteHeader(status) if _, err := w.Write(response); err != nil { log.With("payload", payload).Errorf("Failed to write response: %v", err) } } func writeBadRequestError(w http.ResponseWriter, msg, traceId string) { writeJSON(w, http.StatusBadRequest, newPrefetchError(http.StatusBadRequest, msg, traceId)) } func writeInternalError(w http.ResponseWriter, msg, traceId string) { writeJSON(w, http.StatusInternalServerError, newPrefetchError(http.StatusInternalServerError, msg, traceId)) } func writePrefetchResponse(w http.ResponseWriter, tag, msg, traceId string) { writeJSON(w, http.StatusOK, newPrefetchSuccessResponse(tag, msg, traceId)) } // HandleV1 processes the prefetch request. func (ph *PrefetchHandler) HandleV1(w http.ResponseWriter, r *http.Request) { input, errOccurred := ph.preparePrefetch(w, r) if errOccurred { return } ph.metrics.Counter("initiated").Inc(1) writePrefetchResponse(w, input.tag, "prefetching initiated successfully", input.traceID) if ph.v1Synchronous { ph.downloadBlobs(input) } else { // Download blobs asynchronously. go ph.downloadBlobs(input) } } type prefetchInput struct { blobs []blobInfo namespace string logger *zap.SugaredLogger tag string traceID string } // preparePrefetch parses the request, calls build-index to get the image manifest SHA, // downloads the manifest(s) from the origin cluster, parses them, and returns the blobs layers to prefetch. // If an error occurs, preparePrefetch returns the appropriate HTTP response. 
func (ph *PrefetchHandler) preparePrefetch(w http.ResponseWriter, r *http.Request) (res *prefetchInput, errOccurred bool) { ph.metrics.Counter("requests").Inc(1) var reqBody prefetchBody if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { writeBadRequestError(w, fmt.Sprintf("failed to decode request body: %s", err), "") log.With("error", err).Error("Failed to decode request body") return nil, true } logger := log. With("trace_id", reqBody.TraceId). With("image_tag", reqBody.Tag) namespace, tag, err := ph.tagParser.ParseTag(reqBody.Tag) if err != nil { writeBadRequestError(w, fmt.Sprintf("tag: %s, invalid tag format: %s", reqBody.Tag, err), reqBody.TraceId) return nil, true } tagRequest := url.QueryEscape(fmt.Sprintf("%s/%s", namespace, tag)) startTime := time.Now() digest, err := ph.tagClient.Get(tagRequest) if err != nil { ph.metrics.Counter("get_tag_error").Inc(1) logger.With("error", err).Error("Failed to get manifest tag") writeInternalError(w, fmt.Sprintf("tag request: %s, failed to get tag: %s", tagRequest, err), reqBody.TraceId) return nil, true } ph.getTagLatency.RecordDuration(time.Since(startTime)) logger.Infof("Namespace: %s, Tag: %s", namespace, tag) buf := &bytes.Buffer{} startTime = time.Now() if err := ph.clusterClient.DownloadBlob(context.Background(), namespace, digest, buf); err != nil { ph.metrics.Counter("download_manifest_error").Inc(1) logger.With("error", err).Error("Failed to download manifest blob") writeInternalError(w, fmt.Sprintf("error downloading manifest blob: %s", err), reqBody.TraceId) return nil, true } ph.getManifestLatency.RecordDuration(time.Since(startTime)) // Process manifest (ManifestList or single Manifest) blobs, err := ph.processManifest(logger, namespace, buf.Bytes()) if err != nil { writeInternalError(w, fmt.Sprintf("failed to process manifest: %s", err), reqBody.TraceId) return nil, true } return &prefetchInput{ blobs: blobs, namespace: namespace, logger: logger, tag: tag, traceID: reqBody.TraceId, }, 
false } // downloadBlobs downloads blobs in parallel. func (ph *PrefetchHandler) downloadBlobs(input *prefetchInput) { var wg sync.WaitGroup var mu sync.Mutex var errList []error for _, b := range input.blobs { if ph.shouldSkipPrefetch(b, input.logger) { continue } wg.Add(1) go func(blob blobInfo) { defer wg.Done() blobStart := time.Now() err := ph.clusterClient.DownloadBlob(context.Background(), input.namespace, blob.digest, io.Discard) blobDuration := time.Since(blobStart) ph.metrics.Timer("blob_download_time").Record(blobDuration) ph.metrics.Counter("bytes_downloaded").Inc(blob.size) if err != nil { if serr, ok := err.(httputil.StatusError); ok && serr.Status == http.StatusAccepted { return } mu.Lock() errList = append(errList, fmt.Errorf("digest %s, namespace %s, error downloading blob: %w", blob.digest, input.namespace, err)) mu.Unlock() } else { ph.metrics.Counter("blobs_downloaded").Inc(1) } }(b) } wg.Wait() if len(errList) > 0 { ph.metrics.Counter("failed").Inc(1) for _, err := range errList { input.logger.With("error", err).Error("Error downloading blob") } } } // Skip blobs that are outside the size range [min, max] func (ph *PrefetchHandler) shouldSkipPrefetch(b blobInfo, logger *zap.SugaredLogger) bool { if b.size < ph.minBlobSizeBytes { logger.With( "digest", b.digest, "size", b.size, "min_threshold", ph.minBlobSizeBytes, ).Infof("Skipping blob: size below minimum threshold") ph.metrics.Counter("blobs_skipped_too_small").Inc(1) return true } if b.size > ph.maxBlobSizeBytes { logger.With( "digest", b.digest, "size", b.size, "max_threshold", ph.maxBlobSizeBytes, ).Infof("Skipping blob: size exceeds maximum threshold") ph.metrics.Counter("blobs_skipped_too_large").Inc(1) return true } return false } // processManifest handles both ManifestLists and single Manifests. func (ph *PrefetchHandler) processManifest(logger *zap.SugaredLogger, namespace string, manifestBytes []byte) ([]blobInfo, error) { // Attempt to process as a manifest list. 
blobs, err := ph.tryProcessManifestList(logger, namespace, manifestBytes) if err == nil && len(blobs) > 0 { return blobs, nil } // Fallback to single manifest. var manifest schema2.Manifest if err := json.NewDecoder(bytes.NewReader(manifestBytes)).Decode(&manifest); err != nil { logger.With("namespace", namespace).Errorf("Failed to parse single manifest: %v", err) return nil, fmt.Errorf("invalid single manifest: %w", err) } return ph.processLayers(manifest.Layers) } // tryProcessManifestList attempts to decode a manifest list. func (ph *PrefetchHandler) tryProcessManifestList(logger *zap.SugaredLogger, namespace string, manifestBytes []byte) ([]blobInfo, error) { var manifestList manifestlist.ManifestList if err := json.NewDecoder(bytes.NewReader(manifestBytes)).Decode(&manifestList); err != nil || len(manifestList.Manifests) == 0 { return nil, fmt.Errorf("not a valid manifest list") } logger.With("namespace", namespace).Info("Processing manifest list") return ph.processManifestList(logger, namespace, manifestList) } // processManifestList processes a manifest list. 
func (ph *PrefetchHandler) processManifestList(logger *zap.SugaredLogger, namespace string, manifestList manifestlist.ManifestList) ([]blobInfo, error) { var allBlobs []blobInfo for _, descriptor := range manifestList.Manifests { manifestDigestHex := descriptor.Digest.Hex() digest, err := core.NewSHA256DigestFromHex(manifestDigestHex) if err != nil { return nil, fmt.Errorf("failed to parse manifest digest %s: %w", manifestDigestHex, err) } buf := &bytes.Buffer{} startTime := time.Now() if err := ph.clusterClient.DownloadBlob(context.Background(), namespace, digest, buf); err != nil { ph.metrics.Counter("download_manifest_error").Inc(1) logger.With("error", err).Error("Failed to download manifest blob") continue } ph.getManifestLatency.RecordDuration(time.Since(startTime)) var manifest schema2.Manifest if err := json.NewDecoder(buf).Decode(&manifest); err != nil { return nil, fmt.Errorf("failed to parse manifest: %w", err) } blobs, err := ph.processLayers(manifest.Layers) if err != nil { return nil, err } allBlobs = append(allBlobs, blobs...) } return allBlobs, nil } // processLayers converts layer descriptors to a list of blobInfo with size information. func (ph *PrefetchHandler) processLayers(layers []distribution.Descriptor) ([]blobInfo, error) { blobs := make([]blobInfo, 0, len(layers)) for _, layer := range layers { digest, err := core.NewSHA256DigestFromHex(layer.Digest.Hex()) if err != nil { return nil, fmt.Errorf("invalid layer digest: %w", err) } blobs = append(blobs, blobInfo{ digest: digest, size: layer.Size, }) } return blobs, nil } // HandleV2 is a *mostly* idempotent operation that preheats the origin cluster's cache with the provided image. // For each image layer: // - if it is not present, it is prefetched by the origins asynchronously. // - if it is present, no-op. 
// The operation is "mostly" idempotent, as while it does not cause image layer redownloads, // it ALWAYS 1) calls BI to get the manifest SHA and 2) downloads all image manifests from the origins. func (ph *PrefetchHandler) HandleV2(w http.ResponseWriter, r *http.Request) { input, errOccurred := ph.preparePrefetch(w, r) if errOccurred { return } err := ph.triggerPrefetchBlobs(input) if err != nil { writeInternalError(w, fmt.Sprintf("failed to trigger image prefetch: %s", err), input.traceID) input.logger.Errorf("Failed to trigger image prefetch") return } ph.metrics.Counter("initiated").Inc(1) writePrefetchResponse(w, input.tag, "prefetching initiated successfully", input.traceID) } // triggerPrefetchBlobs triggers a blob prefetch for all blobs in parallel. func (ph *PrefetchHandler) triggerPrefetchBlobs(input *prefetchInput) error { var wg sync.WaitGroup var mu sync.Mutex var errList []error for _, b := range input.blobs { if ph.shouldSkipPrefetch(b, input.logger) { continue } wg.Add(1) go func(digest core.Digest) { defer wg.Done() err := ph.clusterClient.PrefetchBlob(input.namespace, digest) if err != nil { mu.Lock() errList = append(errList, fmt.Errorf("digest %q, namespace %q, blob prefetch failure: %w", digest, input.namespace, err)) mu.Unlock() } }(b.digest) } wg.Wait() if len(errList) != 0 { return fmt.Errorf("at least one layer could not be prefetched: %w", errors.Join(errList...)) } return nil }