pkg/context/context.go (99 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
package context
import (
"errors"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/rs/zerolog"
)
// KubeConfigPath is the absolute filesystem path of the kubeconfig file.
const KubeConfigPath = "/opt/peerd/kubeconfig"
// Keys under which request-scoped values are stored in a Context (via Set/Get).
const (
	BlobRangeCtxKey = "blob_range"
	BlobUrlCtxKey   = "blob_url"
	// CorrelationIdCtxKey holds the request correlation ID written by FillCorrelationId.
	CorrelationIdCtxKey = "correlation_id"
	DigestCtxKey        = "digest"
	FileChunkCtxKey     = "file_chunk"
	// LoggerCtxKey holds the request logger that Logger reads.
	LoggerCtxKey    = "logger"
	NamespaceCtxKey = "namespace"
	RefTypeCtxKey   = "ref_type"
	ReferenceCtxKey = "reference"
)
// HTTP headers that peerd nodes attach to the requests they send each other.
const (
	// P2PHeaderKey is set to "true" on outbound peer requests; IsRequestFromAPeer checks it.
	P2PHeaderKey = "X-MS-Peerd-RequestFromPeer"
	// CorrelationHeaderKey carries the correlation ID so it survives the hop to another node.
	CorrelationHeaderKey = "X-MS-Peerd-CorrelationId"
	// NodeHeaderKey carries the sending node's name (NodeName).
	NodeHeaderKey = "X-MS-Peerd-Node"
)
// Shared log message strings for peer resolution and peer requests, kept as constants so every
// component emits (and log queries can match) the exact same text.
const (
	// Resolution lifecycle.
	PeerResolutionStartLog = "peer resolution start"
	PeerResolutionStopLog  = "peer resolution stop"

	// Resolution and request outcomes.
	PeerNotFoundLog            = "peer not found"
	PeerResolutionExhaustedLog = "peer resolution exhausted"
	PeerRequestErrorLog        = "peer request error"
)
// NodeName is this machine's hostname. SetOutboundHeaders sends it to peers in NodeHeaderKey.
var NodeName = func() string {
	// The error is intentionally discarded: whatever name os.Hostname returns is used as-is
	// (presumably "" on failure), because a missing node name only degrades diagnostics and must
	// not stop the package from initializing.
	name, _ := os.Hostname()
	return name
}()
// Context wraps a gin request context so request-specific information can be passed around to
// other components. Because *gin.Context is embedded, all of gin's accessors (Request, Param,
// Get/Set, GetString, ClientIP, ...) are available directly on Context.
type Context struct {
	*gin.Context
}

// FromContext wraps the given gin context. No copy is made, so the result shares state with c.
func FromContext(c *gin.Context) Context {
	var wrapped Context
	wrapped.Context = c
	return wrapped
}

// Copy returns a Context backed by a copy of the underlying gin context (gin's Context.Copy),
// suitable for use outside the request's scope.
func (c Context) Copy() Context {
	return FromContext(c.Context.Copy())
}
// IsRequestFromAPeer reports whether the current request was sent by a peer. That is the case
// when its P2PHeaderKey header is exactly "true", the value SetOutboundHeaders writes.
func IsRequestFromAPeer(c Context) bool {
	switch c.Request.Header.Get(P2PHeaderKey) {
	case "true":
		return true
	default:
		return false
	}
}
// FillCorrelationId stores the request's correlation ID in the context under CorrelationIdCtxKey.
// An ID supplied in the CorrelationHeaderKey request header is reused. Otherwise a new random
// UUID is generated, so every request ends up with a correlation ID.
func FillCorrelationId(c Context) {
	if incoming := c.Request.Header.Get(CorrelationHeaderKey); incoming != "" {
		c.Set(CorrelationIdCtxKey, incoming)
		return
	}
	c.Set(CorrelationIdCtxKey, uuid.New().String())
}
// SetOutboundHeaders stamps r with the headers every outbound request must carry:
//   - P2PHeaderKey marks the request as coming from a peer.
//   - CorrelationHeaderKey forwards the correlation ID stored in c under CorrelationIdCtxKey
//     (empty if none was set).
//   - NodeHeaderKey identifies this node (NodeName).
//
// Existing values for these headers are replaced.
func SetOutboundHeaders(r *http.Request, c Context) {
	hdr := r.Header
	hdr.Set(P2PHeaderKey, "true")
	hdr.Set(CorrelationHeaderKey, c.GetString(CorrelationIdCtxKey))
	hdr.Set(NodeHeaderKey, NodeName)
}
// Logger returns the logger stored in the context under LoggerCtxKey. The returned logger is
// enriched with request fields: correlation ID, request URL, Range header, whether the request
// came from a peer, client IP, and the sending node's name (NodeHeaderKey).
//
// Both *zerolog.Logger and zerolog.Logger values are accepted. If no usable logger is stored
// (missing, nil, or of another type), a warning is printed to stdout and a no-op logger is used,
// so the returned logger discards all events. This function never panics on a bad context value.
func Logger(c Context) zerolog.Logger {
	l := zerolog.Nop()

	obj, ok := c.Get(LoggerCtxKey)
	// A type switch instead of an unchecked assertion: a value of the wrong type (or a nil
	// pointer) must degrade to the no-op logger rather than panic inside a request handler.
	switch ctxLog := obj.(type) {
	case *zerolog.Logger:
		if ctxLog != nil {
			l = *ctxLog
		} else {
			fmt.Println("WARN: logger in context is nil")
		}
	case zerolog.Logger:
		l = ctxLog
	default:
		if !ok {
			fmt.Println("WARN: logger not found in context")
		} else {
			fmt.Printf("WARN: logger in context has unexpected type %T\n", obj)
		}
	}

	return l.With().
		Str("correlationid", c.GetString(CorrelationIdCtxKey)).
		Str("url", c.Request.URL.String()).
		Str("range", c.Request.Header.Get("Range")).
		Bool("requestfrompeer", IsRequestFromAPeer(c)).
		Str("clientip", c.ClientIP()).
		Str("clientname", c.Request.Header.Get(NodeHeaderKey)).
		Logger()
}
// BlobUrl extracts the blob URL from the incoming request URL.
//
// The upstream URL is taken from the route's "url" path parameter, with its leading "/" removed.
// The incoming request's raw query string (if any) is re-attached after a "?". When the request
// has no query string, no "?" is appended, so the result never ends in a dangling "?".
func BlobUrl(c Context) string {
	blobURL := strings.TrimPrefix(c.Param("url"), "/")
	if query := c.Request.URL.RawQuery; query != "" {
		return blobURL + "?" + query
	}
	return blobURL
}
// RangeStartIndex returns the start index of a byte range specified in the given range header value.
// It expects the range value to be in the format "bytes=startIndex-endIndex". endIndex may be empty
// for an open-ended range ("bytes=startIndex-"), and it is not validated.
//
// An error is returned when:
//   - the value is empty;
//   - the value is not a single "bytes" range (multi-range lists are rejected);
//   - there is no start index (suffix ranges such as "bytes=-500" do not name a start offset);
//   - the start index is not an unsigned decimal number that fits in an int64.
func RangeStartIndex(rangeValue string) (int64, error) {
	if rangeValue == "" {
		return 0, errors.New("no range header")
	}

	// Separate the unit from the range spec: "bytes" and "start-end".
	parts := strings.Split(rangeValue, "=")
	if len(parts) != 2 || parts[0] != "bytes" {
		return 0, errors.New("invalid range format")
	}

	// Separate the start and end offsets; anything but exactly one "-" is malformed.
	bounds := strings.Split(parts[1], "-")
	if len(bounds) != 2 {
		return 0, errors.New("invalid range format")
	}

	start := bounds[0]
	// RFC 9110 byte positions are plain digits. Reject suffix ranges (empty start) and explicit
	// signs such as "+5", which strconv would otherwise accept.
	if start == "" || start[0] < '0' || start[0] > '9' {
		return 0, fmt.Errorf("invalid range format: start index %q is not a non-negative decimal number", start)
	}

	// ParseInt with bitSize 64 works for offsets beyond 2 GiB on every platform.
	// strconv.Atoi is limited to 32 bits on 32-bit architectures.
	startIndex, err := strconv.ParseInt(start, 10, 64)
	if err != nil {
		return 0, fmt.Errorf("invalid range start index %q: %w", start, err)
	}
	return startIndex, nil
}