proxy/proxyserver/prefetch.go (243 lines of code) (raw):
package proxyserver
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/manifest/manifestlist"
"github.com/docker/distribution/manifest/schema2"
"github.com/uber-go/tally"
"github.com/uber/kraken/build-index/tagclient"
"github.com/uber/kraken/core"
"github.com/uber/kraken/origin/blobclient"
"github.com/uber/kraken/utils/httputil"
"github.com/uber/kraken/utils/log"
"go.uber.org/zap"
)
// StatusSuccess is the value placed in the "status" field of a prefetch
// response when prefetching was initiated.
const StatusSuccess = "success"

// StatusFailure is the value placed in the "status" field of a prefetch
// error response.
const StatusFailure = "failure"
// PrefetchHandler handles prefetch requests: it resolves an image tag to a
// manifest digest, expands the manifest (or manifest list) into layer digests,
// and downloads those layers in a background goroutine, discarding the bytes.
type PrefetchHandler struct {
// clusterClient downloads manifest and layer blobs by namespace and digest.
clusterClient blobclient.ClusterClient
// tagClient resolves an escaped "<namespace>/<name>" string to a manifest digest.
tagClient tagclient.Client
// tagParser splits the requested tag into namespace and name.
tagParser TagParser
// metrics is the "prefetch" sub-scope created by NewPrefetchHandler.
metrics tally.Scope
}
// Request and response payloads.
type (
	// prefetchBody is the JSON request body accepted by Handle.
	prefetchBody struct {
		Tag     string `json:"tag"`
		TraceId string `json:"trace_id"`
	}

	// prefetchResponse is the JSON body sent with 200 OK once prefetching
	// has been initiated.
	prefetchResponse struct {
		Tag        string `json:"tag"`
		Prefetched bool   `json:"prefetched"`
		Status     string `json:"status"`
		Message    string `json:"message"`
		TraceId    string `json:"trace_id"`
	}

	// prefetchError is the JSON body sent on failure. TraceId is omitted
	// from the output when empty.
	prefetchError struct {
		Error      string `json:"error"`
		Prefetched bool   `json:"prefetched"`
		Status     string `json:"status"`
		Message    string `json:"message"`
		TraceId    string `json:"trace_id,omitempty"`
	}
)
// TagParser splits a fully qualified image reference into the namespace and
// name that PrefetchHandler uses to resolve the tag.
type TagParser interface {
	ParseTag(tag string) (namespace, name string, err error)
}

// DefaultTagParser is the TagParser used when NewPrefetchHandler is given nil.
type DefaultTagParser struct{}

var _ TagParser = (*DefaultTagParser)(nil)

// ParseTag implements the TagParser interface.
//
// It expects tag strings in the format <hostname>/<namespace>/<imagename:tag>.
// The hostname is discarded. Everything after the namespace is returned as the
// name, so a nested repository such as "host/ns/team/img:tag" yields namespace
// "ns" and name "team/img:tag" instead of silently truncating to "team".
// An error is returned when fewer than three components are present or when
// the namespace or name component is empty.
func (p *DefaultTagParser) ParseTag(tag string) (namespace, name string, err error) {
	// SplitN keeps any further slashes inside the name component.
	parts := strings.SplitN(tag, "/", 3)
	if len(parts) < 3 || parts[1] == "" || parts[2] == "" {
		return "", "", fmt.Errorf("invalid tag format: %s", tag)
	}
	return parts[1], parts[2], nil
}
// NewPrefetchHandler returns a PrefetchHandler that downloads blobs through
// client and resolves tags through tagClient. A nil tagParser is replaced with
// a DefaultTagParser. All metrics are emitted under the "prefetch" sub-scope
// of metrics.
func NewPrefetchHandler(
	client blobclient.ClusterClient,
	tagClient tagclient.Client,
	tagParser TagParser,
	metrics tally.Scope,
) *PrefetchHandler {
	ph := &PrefetchHandler{
		clusterClient: client,
		tagClient:     tagClient,
		tagParser:     tagParser,
		metrics:       metrics.SubScope("prefetch"),
	}
	if ph.tagParser == nil {
		ph.tagParser = &DefaultTagParser{}
	}
	return ph
}
// newPrefetchSuccessResponse builds the payload reporting that prefetching
// of tag was initiated: Prefetched is true and Status is StatusSuccess.
func newPrefetchSuccessResponse(tag, msg, traceId string) *prefetchResponse {
	resp := new(prefetchResponse)
	resp.Tag = tag
	resp.Prefetched = true
	resp.Status = StatusSuccess
	resp.Message = msg
	resp.TraceId = traceId
	return resp
}
// newPrefetchError builds a failure payload. Its Error field is the standard
// HTTP status text for status (via http.StatusText), Prefetched is false, and
// Status is StatusFailure.
func newPrefetchError(status int, msg, traceId string) *prefetchError {
	reason := http.StatusText(status)
	return &prefetchError{
		Error:   reason,
		Status:  StatusFailure,
		Message: msg,
		TraceId: traceId,
	}
}
// writeJSON marshals payload and writes it with the given HTTP status and an
// "application/json" content type. If marshalling fails, a plain-text 500
// response is sent instead. A failed body write is only logged, because the
// status code has already been written at that point.
func writeJSON(w http.ResponseWriter, status int, payload interface{}) {
	w.Header().Set("Content-Type", "application/json")

	body, marshalErr := json.Marshal(payload)
	if marshalErr != nil {
		log.With("payload", payload).Errorf("Failed to marshal JSON: %v", marshalErr)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	w.WriteHeader(status)
	if _, writeErr := w.Write(body); writeErr != nil {
		log.With("payload", payload).Errorf("Failed to write response: %v", writeErr)
	}
}
// writeBadRequestError replies 400 Bad Request with a prefetchError payload.
func writeBadRequestError(w http.ResponseWriter, msg, traceId string) {
	const code = http.StatusBadRequest
	writeJSON(w, code, newPrefetchError(code, msg, traceId))
}

// writeInternalError replies 500 Internal Server Error with a prefetchError payload.
func writeInternalError(w http.ResponseWriter, msg, traceId string) {
	const code = http.StatusInternalServerError
	writeJSON(w, code, newPrefetchError(code, msg, traceId))
}

// writePrefetchResponse replies 200 OK with a success payload for tag.
func writePrefetchResponse(w http.ResponseWriter, tag, msg, traceId string) {
	payload := newPrefetchSuccessResponse(tag, msg, traceId)
	writeJSON(w, http.StatusOK, payload)
}
// Handle processes a prefetch request.
//
// The body is decoded as a prefetchBody. The tag is parsed into namespace and
// name, and resolved to a manifest digest through the tag client. The manifest
// (or manifest list) is downloaded and expanded into layer digests. The
// handler responds as soon as the layer list is known; the layers themselves
// are downloaded by a background goroutine.
//
// Responses:
//   - 400 for an undecodable body or a malformed tag.
//   - 500 when tag resolution or manifest download/processing fails.
//   - 200 once prefetching has been started.
func (ph *PrefetchHandler) Handle(w http.ResponseWriter, r *http.Request) {
	ph.metrics.Counter("requests").Inc(1)

	var reqBody prefetchBody
	if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
		writeBadRequestError(w, fmt.Sprintf("failed to decode request body: %s", err), "")
		log.With("error", err).Error("Failed to decode request body")
		return
	}
	logger := log.With("trace_id", reqBody.TraceId)

	namespace, tag, err := ph.tagParser.ParseTag(reqBody.Tag)
	if err != nil {
		logger.With("tag", reqBody.Tag, "error", err).Error("Failed to parse tag")
		writeBadRequestError(w, fmt.Sprintf("tag: %s, invalid tag format: %s", reqBody.Tag, err), reqBody.TraceId)
		return
	}

	tagRequest := url.QueryEscape(fmt.Sprintf("%s/%s", namespace, tag))
	digest, err := ph.tagClient.Get(tagRequest)
	if err != nil {
		logger.With("tag_request", tagRequest, "error", err).Error("Failed to resolve tag")
		writeInternalError(w, fmt.Sprintf("tag request: %s, failed to get tag: %s", tagRequest, err), reqBody.TraceId)
		return
	}
	logger.Infof("Namespace: %s, Tag: %s", namespace, tag)

	buf := &bytes.Buffer{}
	if err := ph.clusterClient.DownloadBlob(namespace, digest, buf); err != nil {
		logger.With("namespace", namespace, "error", err).Error("Failed to download manifest blob")
		writeInternalError(w, fmt.Sprintf("error downloading manifest blob: %s", err), reqBody.TraceId)
		return
	}

	// Process manifest (ManifestList or single Manifest).
	size, digests, err := ph.processManifest(logger, namespace, buf.Bytes())
	if err != nil {
		logger.With("namespace", namespace, "error", err).Error("Failed to process manifest")
		writeInternalError(w, fmt.Sprintf("failed to process manifest: %s", err), reqBody.TraceId)
		return
	}

	// ph.metrics is already the "prefetch" sub-scope (see NewPrefetchHandler);
	// calling SubScope("prefetch") again would nest this counter under a
	// second "prefetch" scope, unlike every other counter in this file.
	ph.metrics.Counter("initiated").Inc(1)
	writePrefetchResponse(w, reqBody.Tag, "prefetching initiated successfully", reqBody.TraceId)

	// Prefetch blobs asynchronously; the response has already been written.
	go ph.prefetchBlobs(logger, namespace, digests, size)
}
// prefetchBlobs downloads every digest in namespace concurrently, one
// goroutine per digest, discarding the content. It blocks until all downloads
// have finished.
//
// A download that fails with HTTP 202 Accepted is not treated as an error.
// Presumably the origin accepted the request and is still fetching the blob;
// TODO confirm against blobclient semantics. Any other failure is collected
// and logged after all downloads complete, and the "failed" counter is
// incremented once for the batch.
//
// size is the combined size of all layers in digests. It is added to the
// "bytes_downloaded" counter exactly once per batch, and only when no download
// hard-failed. Adding it per blob would multiply the total by len(digests).
func (ph *PrefetchHandler) prefetchBlobs(logger *zap.SugaredLogger, namespace string, digests []core.Digest, size int64) {
	var (
		wg      sync.WaitGroup
		mu      sync.Mutex
		errList []error
	)
	for _, d := range digests {
		wg.Add(1)
		go func(digest core.Digest) {
			defer wg.Done()
			blobStart := time.Now()
			err := ph.clusterClient.DownloadBlob(namespace, digest, ioutil.Discard)
			ph.metrics.Timer("blob_download_time").Record(time.Since(blobStart))
			if err == nil {
				return
			}
			if serr, ok := err.(httputil.StatusError); ok && serr.Status == http.StatusAccepted {
				return
			}
			mu.Lock()
			errList = append(errList, fmt.Errorf("digest %s, namespace %s, error downloading blob: %w", digest, namespace, err))
			mu.Unlock()
		}(d)
	}
	wg.Wait()

	if len(errList) > 0 {
		ph.metrics.Counter("failed").Inc(1)
		for _, err := range errList {
			logger.With("error", err).Error("Error downloading blob")
		}
		return
	}
	ph.metrics.Counter("bytes_downloaded").Inc(size)
}
// errNotManifestList is returned by tryProcessManifestList when the manifest
// bytes do not decode into a manifest list with at least one entry. It is the
// only error that makes processManifest fall back to single-manifest parsing.
var errNotManifestList = fmt.Errorf("not a valid manifest list")

// processManifest expands manifestBytes into the total layer size and the
// layer digests to prefetch. It handles both ManifestLists and single
// Manifests.
//
// The bytes are first tried as a manifest list. If they are one, its result
// is returned as-is, including any error from fetching or parsing the
// referenced manifests. Only when the bytes are not a manifest list are they
// parsed as a single schema2 manifest.
func (ph *PrefetchHandler) processManifest(logger *zap.SugaredLogger, namespace string, manifestBytes []byte) (int64, []core.Digest, error) {
	size, digests, err := ph.tryProcessManifestList(logger, namespace, manifestBytes)
	switch {
	case err == nil:
		return size, digests, nil
	case err != errNotManifestList:
		// A real failure while processing a manifest list. Falling back here
		// would decode the list as a schema2 manifest with no layers and
		// report a successful prefetch of nothing.
		return 0, nil, err
	}

	// Fallback to single manifest.
	var manifest schema2.Manifest
	if err := json.NewDecoder(bytes.NewReader(manifestBytes)).Decode(&manifest); err != nil {
		logger.With("namespace", namespace).Errorf("Failed to parse single manifest: %v", err)
		return 0, nil, fmt.Errorf("invalid single manifest: %w", err)
	}
	return ph.processLayers(manifest.Layers)
}

// tryProcessManifestList decodes manifestBytes as a manifest list and, if it
// contains at least one entry, expands it via processManifestList. It returns
// errNotManifestList when the bytes fail to decode or the list is empty.
func (ph *PrefetchHandler) tryProcessManifestList(logger *zap.SugaredLogger, namespace string, manifestBytes []byte) (int64, []core.Digest, error) {
	var manifestList manifestlist.ManifestList
	if err := json.NewDecoder(bytes.NewReader(manifestBytes)).Decode(&manifestList); err != nil || len(manifestList.Manifests) == 0 {
		return 0, nil, errNotManifestList
	}
	logger.With("namespace", namespace).Info("Processing manifest list")
	return ph.processManifestList(logger, namespace, manifestList)
}
// processManifestList downloads each manifest referenced by manifestList from
// namespace. It returns the combined layer size and the layer digests of all
// of them.
//
// A referenced manifest that cannot be downloaded is logged and skipped, so a
// partially available list still yields the layers that could be resolved.
// If every download fails, an error is returned instead of an empty result.
// This keeps the caller from reporting a prefetch of nothing as a success.
// A digest or manifest parse failure aborts the whole list.
func (ph *PrefetchHandler) processManifestList(logger *zap.SugaredLogger, namespace string, manifestList manifestlist.ManifestList) (int64, []core.Digest, error) {
	var (
		allDigests []core.Digest
		size       int64
		downloaded int
		lastErr    error
	)
	for _, descriptor := range manifestList.Manifests {
		manifestDigestHex := descriptor.Digest.Hex()
		digest, err := core.NewSHA256DigestFromHex(manifestDigestHex)
		if err != nil {
			return 0, nil, fmt.Errorf("failed to parse manifest digest %s: %w", manifestDigestHex, err)
		}

		buf := &bytes.Buffer{}
		if err := ph.clusterClient.DownloadBlob(namespace, digest, buf); err != nil {
			logger.With("namespace", namespace, "digest", digest).Errorf("Failed to download manifest blob: %s", err)
			lastErr = err
			continue
		}
		downloaded++

		var manifest schema2.Manifest
		if err := json.NewDecoder(buf).Decode(&manifest); err != nil {
			return 0, nil, fmt.Errorf("failed to parse manifest: %w", err)
		}
		l, digests, err := ph.processLayers(manifest.Layers)
		if err != nil {
			return 0, nil, err
		}
		size += l
		allDigests = append(allDigests, digests...)
	}

	if downloaded == 0 && lastErr != nil {
		return 0, nil, fmt.Errorf("failed to download any of %d manifests in list, last error: %w",
			len(manifestList.Manifests), lastErr)
	}
	return size, allDigests, nil
}
// processLayers converts layer descriptors into core.Digest values, in the
// same order, and sums their declared sizes. It stops at the first layer whose
// digest is rejected by core.NewSHA256DigestFromHex.
func (ph *PrefetchHandler) processLayers(layers []distribution.Descriptor) (int64, []core.Digest, error) {
	var total int64
	result := make([]core.Digest, len(layers))
	for i := range layers {
		d, err := core.NewSHA256DigestFromHex(layers[i].Digest.Hex())
		if err != nil {
			return 0, nil, fmt.Errorf("invalid layer digest: %w", err)
		}
		result[i] = d
		total += layers[i].Size
	}
	return total, result, nil
}