registry/storage/driver/s3-aws/common/parser.go (505 lines of code) (raw):

package common import ( "errors" "fmt" "math" "net/http" "slices" "strings" "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/service/s3" dcontext "github.com/docker/distribution/context" "github.com/docker/distribution/registry/storage/driver" "github.com/docker/distribution/registry/storage/driver/internal/parse" "github.com/hashicorp/go-multierror" v2_aws "github.com/aws/aws-sdk-go-v2/aws" ) const ( // Authentication parameters ParamAccessKey = "accesskey" ParamSecretKey = "secretkey" ParamSessionToken = "sessiontoken" // Region configuration ParamRegion = "region" ParamRegionEndpoint = "regionendpoint" // Bucket configuration ParamBucket = "bucket" ParamRootDirectory = "rootdirectory" ParamStorageClass = "storageclass" ParamObjectACL = "objectacl" ParamObjectOwnership = "objectownership" // Security and encryption ParamEncrypt = "encrypt" ParamKeyID = "keyid" ParamSecure = "secure" ParamSkipVerify = "skipverify" ParamV4Auth = "v4auth" // Chunk size configuration ParamChunkSize = "chunksize" ParamMultipartCopyChunkSize = "multipartcopychunksize" ParamMultipartCopyMaxConcurrency = "multipartcopymaxconcurrency" ParamMultipartCopyThresholdSize = "multipartcopythresholdsize" // Request configuration ParamMaxRequestsPerSecond = "maxrequestspersecond" ParamMaxRetries = "maxretries" // Path and logging configuration ParamPathStyle = "pathstyle" ParamParallelWalk = "parallelwalk" ParamLogLevel = "loglevel" // Log level values // * common for v1 and v2: LogLevelOff = "logoff" // * v1-specific LogLevelDebug = "logdebug" LogLevelDebugWithSigning = "logdebugwithsigning" LogLevelDebugWithHTTPBody = "logdebugwithhttpbody" LogLevelDebugWithRequestRetries = "logdebugwithrequestretries" LogLevelDebugWithRequestErrors = "logdebugwithrequesterrors" LogLevelDebugWithEventStreamBody = "logdebugwitheventstreambody" // * v2-specific LogSigning = "logsigning" LogRetries = "logretries" LogRequest = "logrequest" LogRequestWithBody = "logrequestwithbody" LogResponse = "logresponse" LogResponseWithBody = "logresponsewithbody" LogDeprecatedUsage = "logdeprecatedusage" LogRequestEventMessage = "logrequesteventmessage" LogResponseEventMessage = "logresponseeventmessage" // Storage class values StorageClassNone = "NONE" // Content verification ParamChecksumDisabled = "checksum_disabled" ParamChecksumAlgorithm = "checksum_algorithm" ) const ( // EnvDriverVersion defines the version of the S3 storage driver to use EnvDriverVersion = "S3_DRIVER_VERSION" // EnvAccessKey defines the AWS access key for S3 authentication EnvAccessKey = "AWS_ACCESS_KEY" // EnvSecretKey defines the AWS secret key for S3 authentication EnvSecretKey = "AWS_SECRET_KEY" // nolint: gosec // this is just and env var name // EnvBucket defines the target S3 bucket name EnvBucket = "S3_BUCKET" // EnvEncrypt enables server-side encryption for S3 objects EnvEncrypt = "S3_ENCRYPT" // EnvKeyID specifies the KMS key ID for server-side encryption EnvKeyID = "S3_KEY_ID" // EnvSecure enables HTTPS for S3 connections EnvSecure = "S3_SECURE" // EnvSkipVerify disables SSL certificate verification EnvSkipVerify = "S3_SKIP_VERIFY" // EnvV4Auth is used to disable AWS Signature Version 4 authentication EnvV4Auth = "S3_V4_AUTH" // EnvRegion specifies the AWS region for S3 operations EnvRegion = "AWS_REGION" // EnvObjectACL defines the access control list for S3 objects EnvObjectACL = "S3_OBJECT_ACL" // EnvRegionEndpoint specifies a custom S3 endpoint URL EnvRegionEndpoint = "REGION_ENDPOINT" // EnvSessionToken provides temporary AWS credentials EnvSessionToken = "AWS_SESSION_TOKEN" // nolint: gosec // this is just and env var name // EnvPathStyle enables path-style S3 URLs instead of virtual-hosted-style EnvPathStyle = "AWS_PATH_STYLE" // EnvMaxRequestsPerSecond limits the rate of S3 API requests EnvMaxRequestsPerSecond = "S3_MAX_REQUESTS_PER_SEC" // EnvMaxRetries specifies the maximum number of retry attempts for failed S3 operations EnvMaxRetries = "S3_MAX_RETRIES" // EnvLogLevel sets the logging verbosity for S3 operations EnvLogLevel = "S3_LOG_LEVEL" // EnvObjectOwnership configures the object ownership settings for the S3 bucket EnvObjectOwnership = "S3_OBJECT_OWNERSHIP" // EnvChecksumAlgorithm specifies the algorithm to use for checksums EnvChecksumDisabled = "S3_CHECKSUM_DISABLED" EnvChecksumAlgorithm = "S3_CHECKSUM_ALGORITHM" ) // validRegions maps known s3 region identifiers to region descriptors var validRegions = make(map[string]struct{}) var validStorageClassesV1 = []string{ StorageClassNone, s3.StorageClassStandard, s3.StorageClassReducedRedundancy, s3.StorageClassStandardIa, s3.StorageClassOnezoneIa, s3.StorageClassIntelligentTiering, s3.StorageClassOutposts, s3.StorageClassGlacierIr, s3.StorageClassExpressOnezone, } var validStorageClassesV2 = []string{ StorageClassNone, string(types.StorageClassStandard), string(types.StorageClassReducedRedundancy), string(types.StorageClassStandardIa), string(types.StorageClassOnezoneIa), string(types.StorageClassIntelligentTiering), string(types.StorageClassOutposts), string(types.StorageClassGlacierIr), string(types.StorageClassExpressOnezone), } func init() { partitions := endpoints.DefaultPartitions() for _, p := range partitions { for region := range p.Regions() { validRegions[region] = struct{}{} } } } // DriverParameters A struct that encapsulates all of the driver parameters // after all values have been set type DriverParameters struct { AccessKey string SecretKey string Bucket string Region string RegionEndpoint string Encrypt bool KeyID string Secure bool SkipVerify bool ChunkSize int64 MultipartCopyChunkSize int64 MultipartCopyMaxConcurrency int MultipartCopyThresholdSize int64 RootDirectory string StorageClass string ObjectACL string SessionToken string PathStyle bool MaxRequestsPerSecond int64 MaxRetries int64 ParallelWalk bool Logger dcontext.Logger // In order to keep the code dry, we reuse the same struct field and store // the result in a wider (i.e. uint64 instead of uint) type to accommodate // both sdks. LogLevel uint64 ObjectOwnership bool // v1 specific: V4Auth bool S3APIImpl S3WrapperIf // v2 specific: ChecksumDisabled bool ChecksumAlgorithm types.ChecksumAlgorithm Transport http.RoundTripper } // ParseLogLevelParamV1 parses given loglevel into a value that sdk v1 accepts. // The parameter itself is a comma-separated list of flags/loglevels that user // wants to enable. func ParseLogLevelParamV1(logger dcontext.Logger, param any) aws.LogLevelType { if param == nil { logger.Debugf("S3 logging level is not set, defaulting to %q", LogLevelOff) return aws.LogOff } if ll, ok := param.(aws.LogLevelType); ok { return ll } var res aws.LogLevelType var logLevelsSet []string for _, v := range strings.Split(strings.ToLower(param.(string)), ",") { switch v { case LogLevelOff: // LogLevelOff in the list of loglevels overrides all other log // levels and disables logging logger.Debugf("S3 logging level set to %q", LogLevelOff) return aws.LogOff case LogLevelDebug: res |= aws.LogDebug logLevelsSet = append(logLevelsSet, LogLevelDebug) case LogLevelDebugWithSigning: res |= aws.LogDebugWithSigning logLevelsSet = append(logLevelsSet, LogLevelDebugWithSigning) case LogLevelDebugWithHTTPBody: res |= aws.LogDebugWithHTTPBody logLevelsSet = append(logLevelsSet, LogLevelDebugWithHTTPBody) case LogLevelDebugWithRequestRetries: res |= aws.LogDebugWithRequestRetries logLevelsSet = append(logLevelsSet, LogLevelDebugWithRequestRetries) case LogLevelDebugWithRequestErrors: res |= aws.LogDebugWithRequestErrors logLevelsSet = append(logLevelsSet, LogLevelDebugWithRequestErrors) case LogLevelDebugWithEventStreamBody: res |= aws.LogDebugWithEventStreamBody logLevelsSet = append(logLevelsSet, LogLevelDebugWithEventStreamBody) // Check for v2 log levels that shouldn't be used with v1 case LogSigning, LogRetries, LogRequest, LogRequestWithBody, LogResponse, LogResponseWithBody, LogDeprecatedUsage, LogRequestEventMessage, LogResponseEventMessage: logger.Warnf("S3 driver v2 log level %q has been passed to S3 driver v1. Ignoring. Please adjust your configuration", v) default: logger.Warnf("unknown loglevel %q, S3 logging level set to %q", param, LogLevelOff) return aws.LogOff } } logger.Infof("S3 logging level set to %q", strings.Join(logLevelsSet, ",")) return res } // ParseLogLevelParamV2 parses given loglevel into a value that sdk v2 accepts. // The parameter itself is a comma-separated list of flags/loglevels that user // wants to enable. func ParseLogLevelParamV2(logger dcontext.Logger, param any) v2_aws.ClientLogMode { if param == nil { logger.Debugf("S3 logging level is not set, defaulting to %q", LogLevelOff) // aws sdk v2 does not have a constant for this: return v2_aws.ClientLogMode(0) } if ll, ok := param.(v2_aws.ClientLogMode); ok { return ll } var res v2_aws.ClientLogMode var logLevelsSet []string for _, v := range strings.Split(strings.ToLower(param.(string)), ",") { switch v { // LogLevelOff in the list of loglevels overrides all other log levels // and disables logging case LogLevelOff: logger.Debugf("S3 logging level set to %q", LogLevelOff) // aws sdk v2 does not have a constant for this: return v2_aws.ClientLogMode(0) case LogSigning: res |= v2_aws.LogSigning logLevelsSet = append(logLevelsSet, LogSigning) case LogRetries: res |= v2_aws.LogRetries logLevelsSet = append(logLevelsSet, LogRetries) case LogRequest: res |= v2_aws.LogRequest logLevelsSet = append(logLevelsSet, LogRequest) case LogRequestWithBody: res |= v2_aws.LogRequestWithBody logLevelsSet = append(logLevelsSet, LogRequestWithBody) case LogResponse: res |= v2_aws.LogResponse logLevelsSet = append(logLevelsSet, LogResponse) case LogResponseWithBody: res |= v2_aws.LogResponseWithBody logLevelsSet = append(logLevelsSet, LogResponseWithBody) case LogDeprecatedUsage: res |= v2_aws.LogDeprecatedUsage logLevelsSet = append(logLevelsSet, LogDeprecatedUsage) case LogRequestEventMessage: res |= v2_aws.LogRequestEventMessage logLevelsSet = append(logLevelsSet, LogRequestEventMessage) case LogResponseEventMessage: res |= v2_aws.LogResponseEventMessage logLevelsSet = append(logLevelsSet, LogResponseEventMessage) // Check for v1 log levels that shouldn't be used with v2 case LogLevelDebug, LogLevelDebugWithSigning, LogLevelDebugWithHTTPBody, LogLevelDebugWithRequestRetries, LogLevelDebugWithRequestErrors, LogLevelDebugWithEventStreamBody: logger.Warnf("S3 driver v1 log level %q has been passed to S3 driver v2. Ignoring. Please adjust your configuration", v) default: logger.Warnf("unknown loglevel %q, S3 logging level set to %q", param, LogLevelOff) return v2_aws.ClientLogMode(0) } } logger.Infof("S3 logging level set to %q", strings.Join(logLevelsSet, ",")) return res } func ParseParameters(driverVersion string, parameters map[string]any) (*DriverParameters, error) { var mErr *multierror.Error res := new(DriverParameters) // Providing no values for these is valid in case the user is authenticating // with an IAM on an ec2 instance (in which case the instance credentials will // be summoned when GetAuth is called) accessKey := parameters[ParamAccessKey] if accessKey == nil { accessKey = "" } res.AccessKey = fmt.Sprint(accessKey) secretKey := parameters[ParamSecretKey] if secretKey == nil { // nolint: gosec // G101 -- This is a false positive secretKey = "" } res.SecretKey = fmt.Sprint(secretKey) regionEndpoint := parameters[ParamRegionEndpoint] if regionEndpoint == nil { regionEndpoint = "" } res.RegionEndpoint = fmt.Sprint(regionEndpoint) regionName := parameters[ParamRegion] if regionName == nil || fmt.Sprint(regionName) == "" { err := fmt.Errorf("no %q parameter provided", ParamRegion) mErr = multierror.Append(mErr, err) } region := fmt.Sprint(regionName) res.Region = region // Don't check the region value if a custom endpoint is provided. if regionEndpoint == "" { if _, ok := validRegions[region]; !ok { err := fmt.Errorf("validating region provided: %v", region) mErr = multierror.Append(mErr, err) } } bucket := parameters[ParamBucket] if bucket == nil || fmt.Sprint(bucket) == "" { err := errors.New("no bucket parameter provided") mErr = multierror.Append(mErr, err) } res.Bucket = fmt.Sprint(bucket) encryptEnable, err := parse.Bool(parameters, ParamEncrypt, false) if err != nil { mErr = multierror.Append(mErr, err) } res.Encrypt = encryptEnable secureEnable, err := parse.Bool(parameters, ParamSecure, true) if err != nil { mErr = multierror.Append(mErr, err) } res.Secure = secureEnable skipVerifyEnable, err := parse.Bool(parameters, ParamSkipVerify, false) if err != nil { mErr = multierror.Append(mErr, err) } res.SkipVerify = skipVerifyEnable v4Enable, err := parse.Bool(parameters, ParamV4Auth, true) if err != nil { mErr = multierror.Append(mErr, err) } res.V4Auth = v4Enable keyID := parameters[ParamKeyID] if keyID == nil { keyID = "" } res.KeyID = fmt.Sprint(keyID) chunkSize, err := parse.Int64( parameters, ParamChunkSize, // NOTE(prozlach): We are using two buffers, each one can hold up to // chunksize bytes, and in some circumstances their contents may be // concatenated and committed as a single part. For example: // // https://gitlab.com/gitlab-org/container-registry/-/blob/0604f5b44093b9647dcbf5f4f7a0d6ab824ff0a4/registry/storage/driver/s3-aws/v2/s3.go?page=2#L1447-L1450 // // We need to halve the limit in order to not to exceed MaxChunkSize. DefaultChunkSize, MinChunkSize, MaxChunkSize/2, ) if err != nil { err := fmt.Errorf("converting %q to int64: %w", ParamChunkSize, err) mErr = multierror.Append(mErr, err) } res.ChunkSize = chunkSize multipartCopyChunkSize, err := parse.Int64( parameters, ParamMultipartCopyChunkSize, DefaultMultipartCopyChunkSize, MinChunkSize, MaxChunkSize, ) if err != nil { err := fmt.Errorf("converting %q to valid int64: %w", ParamMultipartCopyChunkSize, err) mErr = multierror.Append(mErr, err) } res.MultipartCopyChunkSize = multipartCopyChunkSize multipartCopyMaxConcurrency, err := parse.Int32( parameters, ParamMultipartCopyMaxConcurrency, DefaultMultipartCopyMaxConcurrency, 1, math.MaxInt32, ) if err != nil { err := fmt.Errorf("converting %q to valid int64: %w", ParamMultipartCopyMaxConcurrency, err) mErr = multierror.Append(mErr, err) } res.MultipartCopyMaxConcurrency = int(multipartCopyMaxConcurrency) multipartCopyThresholdSize, err := parse.Int64( parameters, ParamMultipartCopyThresholdSize, DefaultMultipartCopyThresholdSize, 0, MaxChunkSize, ) if err != nil { err := fmt.Errorf("converting %q to valid int64: %w", ParamMultipartCopyThresholdSize, err) mErr = multierror.Append(mErr, err) } res.MultipartCopyThresholdSize = multipartCopyThresholdSize rootDirectory := parameters[ParamRootDirectory] if rootDirectory == nil { rootDirectory = "" } res.RootDirectory = fmt.Sprint(rootDirectory) var storageClass string storageClassParam := parameters[ParamStorageClass] if storageClassParam != nil { storageClassString, ok := storageClassParam.(string) switch { case !ok: err := fmt.Errorf("the storageclass parameter must be a string: %v", storageClassParam) mErr = multierror.Append(mErr, err) case (driverVersion == V1DriverName || driverVersion == V1DriverNameAlt): // All valid storage class parameters are UPPERCASE, so be a bit more flexible here storageClassString = strings.ToUpper(storageClassString) if !slices.Contains(validStorageClassesV1, storageClassString) { err := fmt.Errorf( "the storageclass parameter must be one of %v, %v is invalid", strings.Join(validStorageClassesV1, ","), storageClassParam, ) mErr = multierror.Append(mErr, err) } else { storageClass = storageClassString } case driverVersion == V2DriverName: storageClassString = strings.ToUpper(storageClassString) if !slices.Contains(validStorageClassesV2, storageClassString) { err := fmt.Errorf( "the storageclass parameter must be one of %v, %v is invalid", strings.Join(validStorageClassesV2, ","), storageClassParam, ) mErr = multierror.Append(mErr, err) } else { storageClass = storageClassString } default: storageClass = storageClassString } } else { switch { case (driverVersion == V1DriverName || driverVersion == V1DriverNameAlt): storageClass = string(s3.StorageClassStandard) case driverVersion == V2DriverName: storageClass = string(types.StorageClassStandard) } } res.StorageClass = storageClass // Parse checksum_disabled parameter checksumDisabled, err := parse.Bool(parameters, ParamChecksumDisabled, false) if err != nil { mErr = multierror.Append(mErr, err) } res.ChecksumDisabled = checksumDisabled // Parse checksum algorithm defaultChecksumAlgorithm := types.ChecksumAlgorithmCrc64nvme checksumAlgorithmParam := parameters[ParamChecksumAlgorithm] if checksumDisabled { // If checksum_disabled is true, ignore checksum_algorithm if checksumAlgorithmParam != nil { logger := parameters[driver.ParamLogger].(dcontext.Logger) logger.Warnf("Both %s and %s parameters provided, %s takes precedence", ParamChecksumDisabled, ParamChecksumAlgorithm, ParamChecksumDisabled) } defaultChecksumAlgorithm = "" } else if checksumAlgorithmParam != nil { checksumAlgorithm, ok := checksumAlgorithmParam.(string) if !ok { err := fmt.Errorf("the checksum_algorithm parameter must be a string: %v", checksumAlgorithmParam) mErr = multierror.Append(mErr, err) } else { // Convert to uppercase for consistency and check if it's valid checksumAlgorithmTyped := (types.ChecksumAlgorithm)(strings.ToUpper(checksumAlgorithm)) // nolint: revive // max-control-nesting if !slices.Contains(checksumAlgorithmTyped.Values(), checksumAlgorithmTyped) { err := fmt.Errorf("the checksum_algorithm parameter must be one of %v, %q is invalid", checksumAlgorithmTyped.Values(), checksumAlgorithmParam) mErr = multierror.Append(mErr, err) } else { defaultChecksumAlgorithm = checksumAlgorithmTyped } } } res.ChecksumAlgorithm = defaultChecksumAlgorithm objectOwnership, err := parse.Bool(parameters, ParamObjectOwnership, false) if err != nil { mErr = multierror.Append(mErr, err) } res.ObjectOwnership = objectOwnership var objectACL string objectACLParam := parameters[ParamObjectACL] if objectACLParam != nil { if objectOwnership { err := fmt.Errorf("object ACL parameter should not be set when object ownership is enabled") mErr = multierror.Append(mErr, err) } else { objectACLString, ok := objectACLParam.(string) switch { case !ok: err := fmt.Errorf("object ACL parameter should be a string: %v", objectACLParam) mErr = multierror.Append(mErr, err) case (driverVersion == V1DriverName || driverVersion == V1DriverNameAlt) && !slices.Contains(s3.ObjectCannedACL_Values(), objectACLString): err := fmt.Errorf("object ACL parameter should be one of %v: %v", strings.Join(s3.ObjectCannedACL_Values(), ","), objectACLParam) mErr = multierror.Append(mErr, err) case driverVersion == V2DriverName && !slices.Contains(types.ObjectCannedACLPrivate.Values(), (types.ObjectCannedACL)(objectACLString)): // typecast: strValues := make([]string, len(types.ObjectCannedACLPrivate.Values())) for i, v := range types.ObjectCannedACLPrivate.Values() { strValues[i] = string(v) } err := fmt.Errorf("object ACL parameter should be one of %v: %v", strings.Join(strValues, ","), objectACLParam) mErr = multierror.Append(mErr, err) default: objectACL = objectACLString } } } else { switch { case (driverVersion == V1DriverName || driverVersion == V1DriverNameAlt): objectACL = string(s3.ObjectCannedACLPrivate) case driverVersion == V2DriverName: objectACL = string(types.ObjectCannedACLPrivate) } } res.ObjectACL = objectACL // If regionEndpoint is set, default to forcing pathstyle to preserve legacy behavior. defaultPathStyle := regionEndpoint != "" pathStyleBool, err := parse.Bool(parameters, ParamPathStyle, defaultPathStyle) if err != nil { mErr = multierror.Append(mErr, err) } res.PathStyle = pathStyleBool parallelWalkBool, err := parse.Bool(parameters, ParamParallelWalk, false) if err != nil { mErr = multierror.Append(mErr, err) } res.ParallelWalk = parallelWalkBool maxRequestsPerSecond, err := parse.Int64(parameters, ParamMaxRequestsPerSecond, DefaultMaxRequestsPerSecond, 0, math.MaxInt64) if err != nil { err = fmt.Errorf("converting maxrequestspersecond to valid int64: %w", err) mErr = multierror.Append(mErr, err) } res.MaxRequestsPerSecond = maxRequestsPerSecond maxRetries, err := parse.Int64( parameters, ParamMaxRetries, DefaultMaxRetries, 0, math.MaxInt64, ) if err != nil { err := fmt.Errorf("converting maxrequestspersecond to valid int64: %w", err) mErr = multierror.Append(mErr, err) } res.MaxRetries = maxRetries if err := mErr.ErrorOrNil(); err != nil { return nil, err } logger := parameters[driver.ParamLogger].(dcontext.Logger) res.Logger = logger switch driverVersion { case V1DriverName, V1DriverNameAlt: res.LogLevel = uint64(ParseLogLevelParamV1(logger, parameters[ParamLogLevel])) case V2DriverName: res.LogLevel = uint64(ParseLogLevelParamV2(logger, parameters[ParamLogLevel])) } return res, nil }