func()

in registry/storage/driver/s3-aws/v2/s3.go [1333:1528]


// Write appends p to the object being uploaded and returns the number of
// bytes of p that were accepted.
//
// Incoming data is staged in w.buffer. Each time the buffer reaches
// w.driver.ChunkSize bytes it is handed to w.flush.
//
// Sometimes the most recently uploaded part is not exactly w.chunkSize bytes
// long. In that case the current multipart upload is first completed, and a
// new multipart upload is started to "even things out":
//   - every full w.chunkSize-sized range of the existing object is copied
//     server-side (UploadPartCopy) as a new part;
//   - any trailing partial range is downloaded into w.buffer, so it is
//     re-uploaded together with the data from p.
//
// It returns storagedriver.ErrAlreadyClosed, storagedriver.ErrAlreadyCommited
// or storagedriver.ErrAlreadyCanceled if the writer can no longer be used.
// If a flush fails part-way, the returned count includes the bytes that were
// already placed into the buffer, and w.size is advanced by that amount.
func (w *writer) Write(p []byte) (int, error) {
	ctx := context.Background()

	switch {
	case w.closed:
		return 0, storagedriver.ErrAlreadyClosed
	case w.committed:
		return 0, storagedriver.ErrAlreadyCommited
	case w.canceled:
		return 0, storagedriver.ErrAlreadyCanceled
	}

	// If the length of the last written part is different than chunkSize,
	// we need to make a new multipart upload to even things out.
	if len(w.parts) > 0 && *w.parts[len(w.parts)-1].Size != w.chunkSize {
		var completedUploadedParts completedParts
		for _, part := range w.parts {
			completedUploadedParts = append(completedUploadedParts, types.CompletedPart{
				ETag:              part.ETag,
				PartNumber:        part.PartNumber,
				ChecksumCRC32:     part.ChecksumCRC32,
				ChecksumCRC32C:    part.ChecksumCRC32C,
				ChecksumCRC64NVME: part.ChecksumCRC64NVME,
				ChecksumSHA1:      part.ChecksumSHA1,
				ChecksumSHA256:    part.ChecksumSHA256,
			})
		}

		sort.Sort(completedUploadedParts)

		_, err := w.driver.S3.CompleteMultipartUpload(
			ctx,
			&s3.CompleteMultipartUploadInput{
				Bucket:   ptr.String(w.driver.Bucket),
				Key:      ptr.String(w.key),
				UploadId: ptr.String(w.uploadID),
				MultipartUpload: &types.CompletedMultipartUpload{
					Parts: completedUploadedParts,
				},
			})
		if err != nil {
			_, errIn := w.driver.S3.AbortMultipartUpload(
				ctx,
				&s3.AbortMultipartUploadInput{
					Bucket:   ptr.String(w.driver.Bucket),
					Key:      ptr.String(w.key),
					UploadId: ptr.String(w.uploadID),
				})
			if errIn != nil {
				return 0, fmt.Errorf("aborting upload failed while handling error %w: %w", err, errIn)
			}
			return 0, err
		}

		inputArgs := &s3.CreateMultipartUploadInput{
			Bucket:               ptr.String(w.driver.Bucket),
			Key:                  ptr.String(w.key),
			ContentType:          w.driver.getContentType(),
			ACL:                  w.driver.getACL(),
			ServerSideEncryption: w.driver.getEncryptionMode(),
			StorageClass:         w.driver.getStorageClass(),
		}
		if !w.checksumDisabled {
			inputArgs.ChecksumAlgorithm = w.checksumAlgorithm
		}
		resp, err := w.driver.S3.CreateMultipartUpload(ctx, inputArgs)
		if err != nil {
			return 0, err
		}
		w.uploadID = *resp.UploadId

		partCount := (w.size + w.chunkSize - 1) / w.chunkSize
		w.parts = make([]types.Part, 0, partCount)
		partsMutex := new(sync.Mutex)

		g, gctx := errgroup.WithContext(ctx)

		// Reduce the client/server exposure to long lived connections regardless of
		// how many requests per second are allowed.
		g.SetLimit(w.multipartCopyMaxConcurrency)

		// NOTE(review): each goroutine captures the loop variable i; this
		// relies on Go 1.22+ per-iteration loop variable semantics.
		for i := int64(0); i < partCount; i++ {
			g.Go(func() error {
				// Check if any other goroutine has failed
				select {
				case <-gctx.Done():
					return gctx.Err()
				default:
				}

				startByte := w.chunkSize * i
				endByte := startByte + w.chunkSize - 1
				if endByte > w.size-1 {
					endByte = w.size - 1
					// NOTE(prozlach): Special case when there is simply not
					// enough data for a full chunk. It handles both cases when
					// there is only one chunk and multiple chunks plus partial
					// a one. We just slurp in what we have and carry on with
					// the data passed to the Write() call.
					byteRange := fmt.Sprintf("bytes=%d-%d", startByte, endByte)
					objResp, err := w.driver.S3.GetObject(
						gctx,
						&s3.GetObjectInput{
							Bucket: ptr.String(w.driver.Bucket),
							Key:    ptr.String(w.key),
							Range:  ptr.String(byteRange),
						})
					if err != nil {
						return fmt.Errorf("fetching object from backend during re-upload: %w", err)
					}
					defer objResp.Body.Close()

					// Only this (last) goroutine touches w.buffer, so no
					// locking is needed here.
					_, err = w.buffer.ReadFrom(objResp.Body)
					if err != nil {
						return fmt.Errorf("reading remaining bytes during data re-upload: %w", err)
					}
					return nil
				}

				// Part numbers are positive integers in the range 1 <= n <= 10000
				partNumber := i + 1

				// Specify the byte range to copy. `CopySourceRange` factors
				// in both starting and ending byte.
				byteRange := fmt.Sprintf("bytes=%d-%d", startByte, endByte)

				copyPartResp, err := w.driver.S3.UploadPartCopy(
					gctx,
					&s3.UploadPartCopyInput{
						Bucket:          ptr.String(w.driver.Bucket),
						CopySource:      ptr.String(w.driver.Bucket + "/" + w.key),
						CopySourceRange: ptr.String(byteRange),
						Key:             ptr.String(w.key),
						PartNumber:      ptr.Int32(int32(partNumber)), // nolint: gosec // partNumber will always be a non-negative number
						UploadId:        ptr.String(w.uploadID),
					})
				if err != nil {
					return fmt.Errorf("re-uploading chunk as multipart upload part: %w", err)
				}

				// NOTE(prozlach): We can't pre-allocate parts slice and then
				// address it with `i` if we want to handle the case where
				// there is not enough data for a single chunk. At the expense
				// of this small extra loop and extra mutex we get the same
				// path for both cases.
				partsMutex.Lock()
				for int64(len(w.parts))-1 < i {
					w.parts = append(w.parts, types.Part{})
				}

				// Add this part to our list
				w.parts[i] = types.Part{
					ETag:       copyPartResp.CopyPartResult.ETag,
					PartNumber: ptr.Int32(int32(partNumber)), // nolint: gosec // partNumber will always be a non-negative number
					// `CopySourceRange` factors in both starting and
					// ending byte, hence `+ 1`.
					Size: ptr.Int64(endByte - startByte + 1),
				}
				partsMutex.Unlock()

				return nil
			})
		}

		if err := g.Wait(); err != nil {
			// NOTE(review): on failure, the new multipart upload (w.uploadID)
			// is left open, and w.parts may contain zero-valued placeholder
			// entries for parts that never completed. The writer is likely
			// unusable afterwards; confirm how callers handle this.
			return 0, err
		}
	}

	var n int64

	for len(p) > 0 {
		neededBytes := w.driver.ChunkSize - int64(w.buffer.Len())

		switch {
		case neededBytes <= 0:
			// The buffer already holds at least a full chunk, e.g. a tail
			// fetched during the re-upload above. Flush it before accepting
			// more bytes. Without this branch the loop would make no
			// progress and spin forever.
		case int64(len(p)) < neededBytes:
			// Not enough data to complete a chunk: stage all of it and wait
			// for more writes.
			_, _ = w.buffer.Write(p) // err is always nil
			n += int64(len(p))
			p = nil
			continue
		default:
			// Top the buffer up to exactly one chunk, then flush it below.
			_, _ = w.buffer.Write(p[:neededBytes]) // err is always nil
			n += neededBytes
			p = p[neededBytes:]
		}

		if err := w.flush(); err != nil {
			w.size += n
			return int(n), err // nolint: gosec // n is never going to be negative
		}
	}

	w.size += n
	return int(n), nil // nolint: gosec // n is never going to be negative
}