sharedlibraries/storage/multipart.go (357 lines of code) (raw):
/*
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package storage
import (
"bytes"
"context"
"encoding/xml"
"fmt"
"io"
"net/http"
"os"
"strings"
"sync"
"time"
"cloud.google.com/go/storage"
"google.golang.org/api/googleapi"
"golang.org/x/oauth2/google"
"golang.org/x/oauth2"
"github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries/log"
)
const (
	// defaultClientEndpoint is the GCS host used for XML multipart requests
	// when the ReadWriter does not specify a custom endpoint.
	defaultClientEndpoint = "storage.googleapis.com"
	// tokenScope is the OAuth2 scope requested for all auth tokens.
	tokenScope = "https://www.googleapis.com/auth/cloud-platform"
)
var (
	// defaultNewClient returns an *http.Client with the given overall request
	// timeout and transport. Declared as a var so tests can substitute a fake.
	defaultNewClient = func(timeout time.Duration, trans *http.Transport) httpClient {
		return &http.Client{Timeout: timeout, Transport: trans}
	}
	// defaultTransport builds a fresh transport for each client so every
	// upload worker gets its own connection pool (see NewMultipartWriter).
	defaultTransport = func() *http.Transport {
		return &http.Transport{
			Proxy:                 http.ProxyFromEnvironment,
			MaxConnsPerHost:       100,
			ForceAttemptHTTP2:     true,
			MaxIdleConns:          100,
			IdleConnTimeout:       10 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		}
	}
)
// httpClient abstracts the single Do method of *http.Client so that
// tests can inject a fake HTTP client.
type httpClient interface {
	Do(req *http.Request) (*http.Response, error)
}
// HTTPClient abstracts creating a new HTTP client to connect to GCS.
// Production code passes defaultNewClient; tests may substitute a fake.
type HTTPClient func(timeout time.Duration, trans *http.Transport) httpClient

// DefaultTokenGetter abstracts obtaining a default oauth2 token source.
type DefaultTokenGetter func(context.Context, ...string) (oauth2.TokenSource, error)

// JSONCredentialsGetter abstracts obtaining JSON oauth2 google credentials.
type JSONCredentialsGetter func(context.Context, []byte, ...string) (*google.Credentials, error)
// objectPart stores part numbers and ETag data for the final XML request.
type objectPart struct {
	// PartNumber is the 1-based index of the part within the upload.
	PartNumber int64 `xml:"PartNumber"`
	// ETag is the entity tag GCS returned when the part was uploaded.
	ETag string `xml:"ETag"`
}
// MultipartWriter is a writer following GCS multipart upload protocol.
// Shared mutable state (uploadErr, parts) is guarded by mu; worker
// availability is tracked through the idleWorkers channel.
type MultipartWriter struct {
	bucket        *storage.BucketHandle
	token         *oauth2.Token
	httpClient    httpClient
	objectName    string
	fileType      string // value of the "X-Backup-Type" metadata key
	baseURL       string // https://<bucket>.<endpoint>/<object>
	storageClass  string
	uploadID      string // assigned by initMultipartUpload
	partSizeBytes int64  // size of each part's buffer
	partNum       int64  // next part number to dispatch, starts at 1
	uploadErr     error  // first fatal worker error; guarded by mu
	maxRetries    int64
	retryBackoffInitial    time.Duration
	retryBackoffMax        time.Duration
	retryBackoffMultiplier float64
	mu            *sync.Mutex
	parts         map[int64]objectPart // completed parts keyed by part number; guarded by mu
	workers       []*uploadWorker
	currentWorker *uploadWorker // worker currently being filled by Write
	idleWorkers   chan *uploadWorker
	customTime    time.Time
	retentionMode string
	retentionTime time.Time
}
// uploadWorker will buffer and retry uploading a single part.
type uploadWorker struct {
	w          *MultipartWriter // parent writer; shared state guarded by w.mu
	httpClient httpClient       // dedicated client/transport per worker to prevent throttling
	buffer     []byte           // part-sized staging buffer
	offset     int64            // number of bytes currently buffered
	numRetries int64            // consecutive failures for the current part
}
// NewMultipartWriter creates a writer and workers for a multipart upload.
//
// It fetches an auth token, initiates the upload against the GCS XML API,
// and prepares XMLMultipartWorkers idle workers, each with a dedicated HTTP
// transport to prevent throttling. The returned writer streams data via
// Write and finalizes the object via Close.
func (rw *ReadWriter) NewMultipartWriter(ctx context.Context, newClient HTTPClient, tokenGetter DefaultTokenGetter, jsonCredentialsGetter JSONCredentialsGetter) (*MultipartWriter, error) {
	// Guard against configurations that would make Write() spin forever
	// dispatching zero-byte parts, or block forever waiting for a worker.
	if rw.ChunkSizeMb <= 0 {
		return nil, fmt.Errorf("invalid ChunkSizeMb: %v, must be positive", rw.ChunkSizeMb)
	}
	if rw.XMLMultipartWorkers <= 0 {
		return nil, fmt.Errorf("invalid XMLMultipartWorkers: %v, must be positive", rw.XMLMultipartWorkers)
	}
	if rw.XMLMultipartEndpoint == "" {
		rw.XMLMultipartEndpoint = defaultClientEndpoint
	}
	baseURL := fmt.Sprintf("https://%s.%s/%s", rw.BucketName, rw.XMLMultipartEndpoint, rw.ObjectName)
	token, err := token(ctx, rw.XMLMultipartServiceAccount, tokenGetter, jsonCredentialsGetter)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch auth token, err: %w", err)
	}
	w := &MultipartWriter{
		bucket:                 rw.BucketHandle,
		objectName:             rw.ObjectName,
		fileType:               rw.Metadata["X-Backup-Type"],
		token:                  token,
		httpClient:             newClient(10*time.Minute, defaultTransport()),
		baseURL:                baseURL,
		storageClass:           rw.StorageClass,
		partSizeBytes:          rw.ChunkSizeMb * 1024 * 1024,
		partNum:                1,
		maxRetries:             rw.MaxRetries,
		retryBackoffInitial:    rw.RetryBackoffInitial,
		retryBackoffMax:        rw.RetryBackoffMax,
		retryBackoffMultiplier: rw.RetryBackoffMultiplier,
		mu:                     &sync.Mutex{},
		parts:                  make(map[int64]objectPart),
		workers:                make([]*uploadWorker, rw.XMLMultipartWorkers),
		idleWorkers:            make(chan *uploadWorker, rw.XMLMultipartWorkers),
		customTime:             rw.CustomTime,
		retentionMode:          rw.ObjectRetentionMode,
		retentionTime:          rw.ObjectRetentionTime,
	}
	if w.uploadID, err = w.initMultipartUpload(); err != nil {
		return nil, fmt.Errorf("failed to init multipart upload, err: %w", err)
	}
	// Each worker needs a dedicated transport to prevent throttling.
	for i := 0; i < int(rw.XMLMultipartWorkers); i++ {
		w.workers[i] = &uploadWorker{
			w:          w,
			httpClient: newClient(10*time.Minute, defaultTransport()),
			buffer:     make([]byte, w.partSizeBytes),
		}
		w.idleWorkers <- w.workers[i]
	}
	return w, nil
}
// Write buffers a full part then asynchronously sends the data.
// It implements io.Writer: the returned count is the number of bytes
// buffered (all of p unless a prior part upload has permanently failed).
func (w *MultipartWriter) Write(p []byte) (int, error) {
	bytesWritten := 0
	for bytesWritten < len(p) {
		// Block until a worker is free to buffer the next part.
		if w.currentWorker == nil {
			w.currentWorker = <-w.idleWorkers
		}
		// uploadErr is written by worker goroutines while holding w.mu;
		// read it under the same lock to avoid a data race.
		w.mu.Lock()
		uploadErr := w.uploadErr
		w.mu.Unlock()
		if uploadErr != nil {
			w.abortMultipartUpload()
			return 0, uploadErr
		}
		n := copy(w.currentWorker.buffer[w.currentWorker.offset:], p[bytesWritten:])
		bytesWritten += n
		w.currentWorker.offset += int64(n)
		// A full part is buffered; dispatch the upload and release the slot.
		// The worker re-adds itself to idleWorkers when it finishes.
		if w.currentWorker.offset >= w.partSizeBytes {
			go w.currentWorker.uploadPartAsync(w.partNum)
			w.partNum++
			w.currentWorker = nil
		}
	}
	return bytesWritten, nil
}
// Close waits for all transfers to complete then generates the final object.
// On any upload failure the multipart upload is aborted and the first
// worker error is returned.
func (w *MultipartWriter) Close() error {
	// Flush the final, possibly partial, part.
	if w.currentWorker != nil {
		go w.currentWorker.uploadPartAsync(w.partNum)
		// Clear the slot so a repeated Close cannot dispatch the same part twice.
		w.currentWorker = nil
	}
	// Every worker sends itself to idleWorkers when its upload finishes, so
	// draining the channel waits for all in-flight uploads and establishes a
	// happens-before edge for reading uploadErr below.
	for i := 0; i < len(w.workers); i++ {
		<-w.idleWorkers
	}
	if w.uploadErr != nil {
		w.abortMultipartUpload()
		return w.uploadErr
	}
	if err := w.completeMultipartUpload(); err != nil {
		w.abortMultipartUpload()
		return err
	}
	return nil
}
// initMultipartUpload starts the multipart upload and returns the upload ID.
// It POSTs to <baseURL>?uploads and parses the UploadId element from the
// XML response body.
func (w *MultipartWriter) initMultipartUpload() (string, error) {
	req, err := http.NewRequest("POST", fmt.Sprintf("%s?uploads", w.baseURL), nil)
	if err != nil {
		return "", fmt.Errorf("failed to create request, err: %w", err)
	}
	req.Header.Add("Content-Length", "0")
	req.Header.Add("Date", time.Now().Format(http.TimeFormat))
	req.Header.Add("Content-Type", "application/octet-stream")
	if w.storageClass != "" {
		req.Header.Add("x-goog-storage-class", w.storageClass)
	}
	w.token.SetAuthHeader(req)
	resp, err := w.httpClient.Do(req)
	// CloseBody is nil-safe, so deferring before the error check is fine.
	defer googleapi.CloseBody(resp)
	if err != nil {
		return "", fmt.Errorf("failed to init multipart upload, err: %w", err)
	}
	if err := checkResponse(resp); err != nil {
		return "", fmt.Errorf("multipart upload failed, err: %w", err)
	}
	var result struct {
		UploadID string `xml:"UploadId"`
	}
	if err := xml.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", fmt.Errorf("failed to decode multipart upload result, err: %w", err)
	}
	return result.UploadID, nil
}
// abortMultipartUpload aborts the upload to free resources from the bucket.
// It returns a non-nil error if the DELETE request could not be built, sent,
// or was rejected by the server.
func (w *MultipartWriter) abortMultipartUpload() error {
	log.Logger.Infow("Aborting multipart upload", "object", w.objectName, "uploadID", w.uploadID)
	abortURL := fmt.Sprintf("%s?uploadId=%s", w.baseURL, w.uploadID)
	req, err := http.NewRequest("DELETE", abortURL, nil)
	if err != nil {
		log.Logger.Errorw("Failed to create request for abort", "object", w.objectName, "uploadID", w.uploadID, "err", err)
		return err
	}
	req.Header.Add("Content-Length", "0")
	req.Header.Add("Date", time.Now().Format(http.TimeFormat))
	w.token.SetAuthHeader(req)
	resp, err := w.httpClient.Do(req)
	defer googleapi.CloseBody(resp)
	if err == nil {
		// Treat a non-success status as a failure too. The previous version
		// logged this case but returned a nil error, hiding the failure.
		err = checkResponse(resp)
	}
	if err != nil {
		log.Logger.Errorw("Failed to abort multipart upload.", "object", w.objectName, "uploadID", w.uploadID, "err", err, "resp", resp)
		return err
	}
	log.Logger.Infow("Successfully aborted multipart upload.", "object", w.objectName, "uploadID", w.uploadID)
	return nil
}
// completeMultipartUpload sends the final POST request to complete the
// multipart upload, then writes final metadata to the object in the bucket.
func (w *MultipartWriter) completeMultipartUpload() error {
	// Build the ordered <CompleteMultipartUpload> body listing every part.
	bodyXML, err := completeMultipartUploadXML(w.parts)
	if err != nil {
		return fmt.Errorf("failed to build complete multipart upload XML for %v, err: %w", w.objectName, err)
	}
	completeURL := fmt.Sprintf("%s?uploadId=%s", w.baseURL, w.uploadID)
	req, err := http.NewRequest("POST", completeURL, strings.NewReader(bodyXML))
	if err != nil {
		return err
	}
	req.Header.Add("Content-Length", fmt.Sprintf("%v", len(bodyXML)))
	req.Header.Add("Date", time.Now().Format(http.TimeFormat))
	req.Header.Add("Content-Type", "application/xml")
	w.token.SetAuthHeader(req)
	resp, err := w.httpClient.Do(req)
	// CloseBody is nil-safe, so deferring before the error check is fine.
	defer googleapi.CloseBody(resp)
	if err != nil {
		return err
	}
	if err := checkResponse(resp); err != nil {
		return err
	}
	// XML headers will force this key to be lowercase, set it after the upload.
	update := storage.ObjectAttrsToUpdate{
		Metadata: map[string]string{"X-Backup-Type": w.fileType},
		CustomTime: w.customTime,
	}
	// Only request object retention when a mode was configured.
	if w.retentionMode != "" {
		update.Retention = &storage.ObjectRetention{
			Mode: w.retentionMode,
			RetainUntil: w.retentionTime,
		}
	}
	log.Logger.Infow("Updating object attrs", "object", w.objectName, "update", update)
	// NOTE(review): req.Context() here is just the request's default
	// (background) context since req was built with http.NewRequest.
	if _, err := w.bucket.Object(w.objectName).Update(req.Context(), update); err != nil {
		return err
	}
	return nil
}
// completeMultipartUploadXML creates the XML for the final POST request.
// All parts and their ETag data must be included; parts must form a
// contiguous sequence numbered 1..len(parts).
func completeMultipartUploadXML(parts map[int64]objectPart) (string, error) {
	type completeUpload struct {
		XMLName xml.Name     `xml:"CompleteMultipartUpload"`
		Parts   []objectPart `xml:"Part"`
	}
	// Emit parts in ascending part-number order, verifying none are missing.
	ordered := make([]objectPart, 0, len(parts))
	for num := int64(1); num <= int64(len(parts)); num++ {
		part, ok := parts[num]
		if !ok {
			return "", fmt.Errorf("part %v not contained in parts", num)
		}
		ordered = append(ordered, part)
	}
	var buf strings.Builder
	enc := xml.NewEncoder(&buf)
	enc.Indent("", " ")
	if err := enc.Encode(completeUpload{Parts: ordered}); err != nil {
		return "", err
	}
	return buf.String(), nil
}
// checkResponse verifies the response of http commands, returning any errors.
func checkResponse(resp *http.Response) error {
if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNoContent {
return nil
}
errStr := http.StatusText(resp.StatusCode)
debugID := resp.Header.Get("x-guploader-uploadid")
if resp.Body != nil {
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return fmt.Errorf("%w (failed to read response body); %s; x-guploader-uploadid=%s", readErr, errStr, debugID)
}
if bodyStr := string(body); bodyStr != "" {
errStr = bodyStr
}
}
return fmt.Errorf("%s; x-guploader-uploadid=%s", errStr, debugID)
}
// token fetches a token with default, service account or workload identity
// federation credentials.
// When serviceAccount is non-empty it is treated as a path to a JSON
// credentials file; otherwise application-default credentials are used.
func token(ctx context.Context, serviceAccount string, tokenGetter DefaultTokenGetter, jsonCredentialsGetter JSONCredentialsGetter) (*oauth2.Token, error) {
	if serviceAccount == "" {
		// TODO: Use token functions from go oauth2 libraries.
		tokenSource, err := tokenGetter(ctx, tokenScope)
		if err != nil {
			return nil, err
		}
		return tokenSource.Token()
	}
	serviceAccountBytes, err := os.ReadFile(serviceAccount)
	if err != nil {
		return nil, fmt.Errorf("failed to read service account file, err: %w", err)
	}
	cred, err := jsonCredentialsGetter(ctx, serviceAccountBytes, tokenScope)
	if err != nil {
		return nil, err
	}
	return cred.TokenSource.Token()
}
// uploadPartAsync continues upload attempts until a success,
// or the upload fails too many times.
//
// It runs on its own goroutine. Invariant: exactly one send to
// w.idleWorkers happens on every exit path, which is how Write and Close
// account for this worker finishing.
func (uw *uploadWorker) uploadPartAsync(partNum int64) {
	// NOTE(review): backoff appears to be a package-level helper defined
	// outside this chunk that yields growing pause durations — confirm.
	backoff := backoff(uw.w.retryBackoffInitial, uw.w.retryBackoffMax, uw.w.retryBackoffMultiplier)
	for {
		if err := uw.uploadPart(partNum); err != nil {
			uw.numRetries++
			if uw.numRetries > uw.w.maxRetries {
				log.Logger.Errorw("Max retries exceeded, cancelling operation.", "partNum", partNum, "numRetries", uw.numRetries, "maxRetries", uw.w.maxRetries, "objectName", uw.w.objectName, "error", err)
				// Record the fatal error under the writer's mutex; Write and
				// Close observe it and abort the whole upload.
				uw.w.mu.Lock()
				uw.w.uploadErr = fmt.Errorf("failed to upload part %v too many times, err: %w", partNum, err)
				uw.w.mu.Unlock()
				uw.w.idleWorkers <- uw
				return
			}
			log.Logger.Infow("Failed to upload data to Google Cloud Storage, retrying.", "partNum", partNum, "numRetries", uw.numRetries, "maxRetries", uw.w.maxRetries, "objectName", uw.w.objectName, "error", err)
			time.Sleep(backoff.Pause())
			continue
		}
		// Success: reset buffer state so the worker can accept the next part.
		uw.offset = 0
		uw.numRetries = 0
		uw.w.idleWorkers <- uw
		return
	}
}
// uploadPart uploads a part to the ongoing multipart upload.
// It PUTs the buffered bytes for partNum and records the returned ETag in
// the writer's parts map (under w.mu) for the final completion request.
func (uw *uploadWorker) uploadPart(partNum int64) error {
	data := uw.buffer[:uw.offset]
	partURL := fmt.Sprintf("%s?uploadId=%s&partNumber=%v", uw.w.baseURL, uw.w.uploadID, partNum)
	req, err := http.NewRequest("PUT", partURL, bytes.NewReader(data))
	if err != nil {
		return err
	}
	req.Header.Add("Content-Length", fmt.Sprintf("%d", len(data)))
	req.Header.Add("Date", time.Now().Format(http.TimeFormat))
	uw.w.token.SetAuthHeader(req)
	resp, err := uw.httpClient.Do(req)
	// CloseBody is nil-safe, so deferring before the error check is fine.
	defer googleapi.CloseBody(resp)
	if err != nil {
		return err
	}
	if err := checkResponse(resp); err != nil {
		return err
	}
	// The ETag must be returned to complete the upload later.
	etag := resp.Header.Get("ETag")
	if etag == "" {
		var respDump strings.Builder
		resp.Write(&respDump)
		return fmt.Errorf("uploadPart did not return in the expected format. Unable to get the ETag for part %v. Resp: %v", partNum, respDump.String())
	}
	uw.w.mu.Lock()
	defer uw.w.mu.Unlock()
	uw.w.parts[partNum] = objectPart{PartNumber: partNum, ETag: etag}
	return nil
}