static void MultiPartCopy()

in tensorflow_io/core/filesystems/s3/s3_filesystem.cc [790:924]


// Copies `source` to `bucket_dst`/`object_dst` using an S3 multipart upload
// whose parts are filled by asynchronous UploadPartCopy requests.
//
// Parameters:
//   source      - value passed verbatim as the copy source of every part.
//   bucket_dst  - destination bucket.
//   object_dst  - destination key.
//   num_parts   - number of parts; part i covers bytes
//                 [i * chunk_size, i * chunk_size + chunk_size - 1], clamped
//                 to file_size - 1.
//   file_size   - size in bytes used to clamp the last part's range.
//   s3_file     - supplies the S3 client and the UPLOAD chunk size.
//   status      - out: TF_OK on success; otherwise the error of the first
//                 part that still fails on the last attempt, the error of
//                 CompleteMultipartUpload, or whatever AbortMultiPartCopy
//                 left in `status` if the abort itself did not end TF_OK.
//
// Each part is attempted at most kMaxAttempts times. When the copy cannot be
// finished the multipart upload is aborted before returning.
static void MultiPartCopy(const Aws::String& source,
                          const Aws::String& bucket_dst,
                          const Aws::String& object_dst, const size_t num_parts,
                          const uint64_t file_size, S3File* s3_file,
                          TF_Status* status) {
  TF_VLog(1, "MultiPartCopy from %s to %s/%s\n", source.c_str(),
          bucket_dst.c_str(), object_dst.c_str());
  Aws::S3::Model::CreateMultipartUploadRequest create_multipart_upload_request;
  create_multipart_upload_request.WithBucket(bucket_dst).WithKey(object_dst);

  // Called before s3_file->s3_client is dereferenced below.
  GetS3Client(s3_file);

  auto create_multipart_upload_outcome =
      s3_file->s3_client->CreateMultipartUpload(
          create_multipart_upload_request);
  if (!create_multipart_upload_outcome.IsSuccess())
    return TF_SetStatusFromAWSError(create_multipart_upload_outcome.GetError(),
                                    status);

  auto upload_id = create_multipart_upload_outcome.GetResult().GetUploadId();

  // Maximum number of times a single part is submitted before giving up.
  constexpr size_t kMaxAttempts = 3;
  // Part indices are handled as `int` throughout (context field, S3 part
  // numbers), so compare against an `int` count instead of mixing signedness.
  const int total_parts = static_cast<int>(num_parts);

  // Shared with every part's callback context through a pointer.
  int num_finished_parts = 0;
  // Keep track of `Outcome` of each upload part.
  Aws::Vector<EtagOutcome> etag_outcomes(num_parts);
  // Mutex handed to every callback context; held below while waiting for
  // `num_finished_parts` to reach the part count.
  absl::Mutex multi_part_copy_mutex;
  // Condition variable to be used with above mutex for synchronization.
  absl::CondVar multi_part_copy_cv;

  auto chunk_size =
      s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD];

  // %zu / %llu match the 64-bit arguments; %u would be a varargs mismatch.
  TF_VLog(1, "Copying from %s in %zu parts of size %llu each\n",
          source.c_str(), num_parts,
          static_cast<unsigned long long>(chunk_size));

  // Completion handler shared by all parts: recovers the typed context and
  // forwards to MultiPartCopyCallback.
  // NOTE(review): presumably that callback stores the part's outcome,
  // increments *num_finished_parts and signals the condition variable --
  // confirm against MultiPartCopyCallback.
  auto callback =
      [](const Aws::S3::S3Client* client,
         const Aws::S3::Model::UploadPartCopyRequest& request,
         const Aws::S3::Model::UploadPartCopyOutcome& outcome,
         const std::shared_ptr<const Aws::Client::AsyncCallerContext>&
             context) {
        auto multipart_context =
            std::static_pointer_cast<const MultipartCopyAsyncContext>(context);
        MultiPartCopyCallback(request, outcome, multipart_context);
      };

  for (size_t attempt = 1; attempt <= kMaxAttempts; ++attempt) {
    // Queue up every part that has not succeeded yet.
    for (int part_number = 0; part_number < total_parts; ++part_number) {
      if (etag_outcomes[part_number].IsSuccess()) continue;
      // Widen before multiplying so the offset is computed in 64 bits.
      const uint64_t start_pos =
          static_cast<uint64_t>(part_number) * chunk_size;
      uint64_t end_pos = start_pos + chunk_size - 1;
      if (end_pos >= file_size) end_pos = file_size - 1;

      Aws::String range =
          absl::StrCat("bytes=", start_pos, "-", end_pos).c_str();
      Aws::S3::Model::UploadPartCopyRequest upload_part_copy_request;
      upload_part_copy_request.WithBucket(bucket_dst)
          .WithKey(object_dst)
          .WithCopySource(source)
          .WithCopySourceRange(range)
          // S3 API partNumber starts from 1.
          .WithPartNumber(part_number + 1)
          .WithUploadId(upload_id);

      auto multi_part_context =
          Aws::MakeShared<MultipartCopyAsyncContext>("MultiPartCopyContext");
      multi_part_context->part_number = part_number;
      multi_part_context->num_finished_parts = &num_finished_parts;
      multi_part_context->etag_outcomes = &etag_outcomes;
      multi_part_context->multi_part_copy_mutex = &multi_part_copy_mutex;
      multi_part_context->multi_part_copy_cv = &multi_part_copy_cv;

      std::shared_ptr<const Aws::Client::AsyncCallerContext> context =
          multi_part_context;
      s3_file->s3_client->UploadPartCopyAsync(upload_part_copy_request,
                                              callback, context);
    }

    // Wait till they finish.
    {
      absl::MutexLock l(&multi_part_copy_mutex);
      // Re-check the counter after every wake-up as there could be false
      // notifications.
      while (num_finished_parts != total_parts) {
        multi_part_copy_cv.Wait(&multi_part_copy_mutex);
      }
    }

    // Check if there was any error for any part.
    bool any_part_failed = false;
    for (int part_number = 0; part_number < total_parts; ++part_number) {
      if (etag_outcomes[part_number].IsSuccess()) continue;
      if (attempt >= kMaxAttempts) {
        // Out of attempts: abort the upload and report. If the abort did not
        // leave `status` at TF_OK, that status is what the caller sees.
        AbortMultiPartCopy(bucket_dst, object_dst, upload_id, s3_file, status);
        if (TF_GetCode(status) != TF_OK) return;
        return TF_SetStatusFromAWSError(etag_outcomes[part_number].GetError(),
                                        status);
      }
      // Retry: un-count the part so the next wait blocks until it finishes.
      TF_Log(TF_ERROR,
             "Retrying failed copy of part %d due to an error with S3\n",
             part_number);
      num_finished_parts--;
      any_part_failed = true;
    }
    // Every part has an ETag; no need to spin through remaining attempts.
    if (!any_part_failed) break;
  }

  Aws::S3::Model::CompletedMultipartUpload completed_multipart_upload;
  // If there was an error still in any part, it would abort and return in the
  // above loop. We set the eTag of completed parts to the final
  // `completed_multipart_upload`. Note these parts have to be added in order.
  for (int part_number = 0; part_number < total_parts; ++part_number) {
    Aws::S3::Model::CompletedPart completed_part;
    completed_part.SetPartNumber(part_number + 1);
    completed_part.SetETag(etag_outcomes[part_number].GetResult());
    completed_multipart_upload.AddParts(completed_part);
  }

  Aws::S3::Model::CompleteMultipartUploadRequest
      complete_multipart_upload_request;
  complete_multipart_upload_request.WithBucket(bucket_dst)
      .WithKey(object_dst)
      .WithUploadId(upload_id)
      .WithMultipartUpload(completed_multipart_upload);
  auto complete_multipart_upload_outcome =
      s3_file->s3_client->CompleteMultipartUpload(
          complete_multipart_upload_request);
  if (complete_multipart_upload_outcome.IsSuccess())
    return TF_SetStatus(status, TF_OK, "");

  // Completion failed: abort the upload, then report the completion error
  // unless the abort already left a non-OK status.
  AbortMultiPartCopy(bucket_dst, object_dst, upload_id, s3_file, status);
  if (TF_GetCode(status) == TF_OK)
    return TF_SetStatusFromAWSError(
        complete_multipart_upload_outcome.GetError(), status);
}