lib/backend/s3backend/client.go (224 lines of code) (raw):
// Copyright (c) 2016-2019 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package s3backend
import (
"errors"
"fmt"
"io"
"path"
"github.com/uber-go/tally"
"github.com/uber/kraken/core"
"github.com/uber/kraken/lib/backend"
"github.com/uber/kraken/lib/backend/backenderrors"
"github.com/uber/kraken/lib/backend/namepath"
"github.com/uber/kraken/utils/log"
"github.com/uber/kraken/utils/rwutil"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"gopkg.in/yaml.v2"
"go.uber.org/zap"
)
const _s3 = "s3"
func init() {
backend.Register(_s3, &factory{})
}
type factory struct{}
func (f *factory) Create(
confRaw interface{}, masterAuthConfig backend.AuthConfig, stats tally.Scope, _ *zap.SugaredLogger) (backend.Client, error) {
confBytes, err := yaml.Marshal(confRaw)
if err != nil {
return nil, errors.New("marshal s3 config")
}
authConfBytes, err := yaml.Marshal(masterAuthConfig[_s3])
if err != nil {
return nil, errors.New("marshal s3 auth config")
}
var config Config
if err := yaml.Unmarshal(confBytes, &config); err != nil {
return nil, errors.New("unmarshal s3 config")
}
var userAuth UserAuthConfig
if err := yaml.Unmarshal(authConfBytes, &userAuth); err != nil {
return nil, errors.New("unmarshal s3 auth config")
}
return NewClient(config, userAuth, stats)
}
// Client implements a backend.Client for S3.
type Client struct {
config Config
pather namepath.Pather
stats tally.Scope
s3 S3
}
// Option allows setting optional Client parameters.
type Option func(*Client)
// WithS3 configures a Client with a custom S3 implementation.
func WithS3(s3 S3) Option {
return func(c *Client) { c.s3 = s3 }
}
// NewClient creates a new Client for S3.
func NewClient(
config Config, userAuth UserAuthConfig, stats tally.Scope, opts ...Option) (*Client, error) {
config.applyDefaults()
if config.Username == "" {
return nil, errors.New("invalid config: username required")
}
if config.Region == "" {
return nil, errors.New("invalid config: region required")
}
if config.Bucket == "" {
return nil, errors.New("invalid config: bucket required")
}
if !path.IsAbs(config.RootDirectory) {
return nil, errors.New("invalid config: root_directory must be absolute path")
}
pather, err := namepath.New(config.RootDirectory, config.NamePath)
if err != nil {
return nil, fmt.Errorf("namepath: %s", err)
}
auth, ok := userAuth[config.Username]
if !ok {
return nil, errors.New("auth not configured for username")
}
creds := credentials.NewStaticCredentials(
auth.S3.AccessKeyID, auth.S3.AccessSecretKey, auth.S3.SessionToken)
awsConfig := aws.NewConfig().WithRegion(config.Region).WithCredentials(creds)
if config.Endpoint != "" {
awsConfig = awsConfig.WithEndpoint(config.Endpoint)
}
if config.DisableSSL {
awsConfig = awsConfig.WithDisableSSL(config.DisableSSL)
}
if config.S3ForcePathStyle {
awsConfig = awsConfig.WithS3ForcePathStyle(config.S3ForcePathStyle)
}
api := s3.New(session.New(), awsConfig)
downloader := s3manager.NewDownloaderWithClient(api, func(d *s3manager.Downloader) {
d.PartSize = config.DownloadPartSize
d.Concurrency = config.DownloadConcurrency
})
uploader := s3manager.NewUploaderWithClient(api, func(u *s3manager.Uploader) {
u.PartSize = config.UploadPartSize
u.Concurrency = config.UploadConcurrency
})
client := &Client{config, pather, stats, join{api, downloader, uploader}}
for _, opt := range opts {
opt(client)
}
return client, nil
}
// Stat returns blob info for name.
func (c *Client) Stat(namespace, name string) (*core.BlobInfo, error) {
path, err := c.pather.BlobPath(name)
if err != nil {
return nil, fmt.Errorf("blob path: %s", err)
}
output, err := c.s3.HeadObject(&s3.HeadObjectInput{
Bucket: aws.String(c.config.Bucket),
Key: aws.String(path),
})
if err != nil {
if isNotFound(err) {
return nil, backenderrors.ErrBlobNotFound
}
return nil, err
}
var size int64
if output.ContentLength != nil {
size = *output.ContentLength
}
return core.NewBlobInfo(size), nil
}
// Download downloads the content from a configured bucket and writes the
// data to dst.
func (c *Client) Download(namespace, name string, dst io.Writer) error {
path, err := c.pather.BlobPath(name)
if err != nil {
return fmt.Errorf("blob path: %s", err)
}
// The S3 download API uses io.WriterAt to perform concurrent chunked download.
// We attempt to upcast dst to io.WriterAt for this purpose, else we download into
// in-memory buffer and drain it into dst after the download is finished.
writerAt, ok := dst.(io.WriterAt)
if !ok {
writerAt = rwutil.NewCappedBuffer(int(c.config.BufferGuard))
}
input := &s3.GetObjectInput{
Bucket: aws.String(c.config.Bucket),
Key: aws.String(path),
}
if _, err := c.s3.Download(writerAt, input); err != nil {
if isNotFound(err) {
return backenderrors.ErrBlobNotFound
}
return err
}
if capBuf, ok := writerAt.(*rwutil.CappedBuffer); ok {
if err = capBuf.DrainInto(dst); err != nil {
return err
}
}
return nil
}
// Upload uploads src to a configured bucket.
func (c *Client) Upload(namespace, name string, src io.Reader) error {
path, err := c.pather.BlobPath(name)
if err != nil {
return fmt.Errorf("blob path: %s", err)
}
input := &s3manager.UploadInput{
Bucket: aws.String(c.config.Bucket),
Key: aws.String(path),
Body: src,
}
_, err = c.s3.Upload(input, func(u *s3manager.Uploader) {
u.LeavePartsOnError = false // Delete the parts if the upload fails.
})
return err
}
func isNotFound(err error) bool {
awsErr, ok := err.(awserr.Error)
return ok && (awsErr.Code() == s3.ErrCodeNoSuchKey || awsErr.Code() == "NotFound")
}
// List lists names with start with prefix.
func (c *Client) List(prefix string, opts ...backend.ListOption) (*backend.ListResult, error) {
// For whatever reason, the S3 list API does not accept an absolute path
// for prefix. Thus, the root is stripped from the input and added manually
// to each output key.
options := backend.DefaultListOptions()
for _, opt := range opts {
opt(options)
}
// If paginiated is enabled use the maximum number of keys requests from thhe options,
// otherwise fall back to the configuration's max keys
maxKeys := int64(c.config.ListMaxKeys)
var continuationToken *string
if options.Paginated {
maxKeys = int64(options.MaxKeys)
// An empty continuationToken should be left as nil when sending paginated list
// requests to s3
if options.ContinuationToken != "" {
continuationToken = aws.String(options.ContinuationToken)
}
}
var names []string
nextContinuationToken := ""
err := c.s3.ListObjectsV2Pages(&s3.ListObjectsV2Input{
Bucket: aws.String(c.config.Bucket),
MaxKeys: aws.Int64(maxKeys),
Prefix: aws.String(path.Join(c.pather.BasePath(), prefix)[1:]),
ContinuationToken: continuationToken,
}, func(page *s3.ListObjectsV2Output, last bool) bool {
for _, object := range page.Contents {
if object.Key == nil {
log.With(
"prefix", prefix,
"object", object).Error("List encountered nil S3 object key")
continue
}
name, err := c.pather.NameFromBlobPath(path.Join("/", *object.Key))
if err != nil {
log.With("key", *object.Key).Errorf("Error converting blob path into name: %s", err)
continue
}
names = append(names, name)
}
if int64(len(names)) < maxKeys {
// Continue iterating pages to get more keys
return true
}
// Attempt to capture the continuation token before we stop iterating pages
if page.IsTruncated != nil && *page.IsTruncated && page.NextContinuationToken != nil {
nextContinuationToken = *page.NextContinuationToken
}
return false
})
if err != nil {
return nil, err
}
return &backend.ListResult{
Names: names,
ContinuationToken: nextContinuationToken,
}, nil
}