// AsyncSocket::WriteResult AsyncSSLSocket::performWrite()
//
// From folly/io/async/AsyncSSLSocket.cpp [1660:1850]

/**
 * Encrypt and write the given iovec array to the socket via SSL_write.
 *
 * Adjacent small buffers (smaller than minWriteSize_) are coalesced into a
 * single temporary buffer before being handed to SSL_write, so that each
 * call produces one reasonably-sized TLS record instead of many tiny ones.
 *
 * @param vec             Array of iovecs to write.
 * @param count           Number of entries in vec.
 * @param flags           WriteFlags for this write (EOR/CORK handling below).
 * @param countWritten    Out: number of iovec entries fully written.
 * @param partialWritten  Out: bytes written from the first partially-written
 *                        iovec (i.e. vec[*countWritten]).
 * @return WriteResult carrying the total application bytes written, or
 *         WRITE_ERROR with an SSLException on failure.
 *
 * Only valid once the handshake has completed (STATE_ESTABLISHED), except
 * for the STATE_UNENCRYPTED passthrough to the plain-socket implementation.
 */
AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
    const iovec* vec,
    uint32_t count,
    WriteFlags flags,
    uint32_t* countWritten,
    uint32_t* partialWritten) {
  // Before/without TLS, delegate directly to the unencrypted implementation.
  if (sslState_ == STATE_UNENCRYPTED) {
    return AsyncSocket::performWrite(
        vec, count, flags, countWritten, partialWritten);
  }
  // Writing mid-handshake is not supported; fail with EARLY_WRITE.
  if (sslState_ != STATE_ESTABLISHED) {
    LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
               << ", sslState=" << sslState_ << ", events=" << eventFlags_
               << "): "
               << "TODO: AsyncSSLSocket currently does not support calling "
               << "write() before the handshake has fully completed";
    return WriteResult(
        WRITE_ERROR, std::make_unique<SSLException>(SSLError::EARLY_WRITE));
  }

  // Declare a buffer used to hold small write requests.  It could point to a
  // memory block either on stack or on heap. If it is on heap, we release it
  // manually when scope exits
  char* combinedBuf{nullptr};
  SCOPE_EXIT {
    // Note, always keep this check consistent with what we do below
    if (combinedBuf != nullptr && minWriteSize_ > MAX_STACK_BUF_SIZE) {
      delete[] combinedBuf;
    }
  };

  *countWritten = 0;
  *partialWritten = 0;
  ssize_t totalWritten = 0;
  // Bytes of vec[i] that were already copied (and written) as part of the
  // previous iteration's coalesced buffer; carried across loop iterations.
  size_t bytesStolenFromNextBuffer = 0;
  for (uint32_t i = 0; i < count; i++) {
    const iovec* v = vec + i;
    // Skip the prefix of this iovec that the previous iteration consumed.
    size_t offset = bytesStolenFromNextBuffer;
    bytesStolenFromNextBuffer = 0;
    size_t len = v->iov_len - offset;
    const void* buf;
    if (len == 0) {
      (*countWritten)++;
      continue;
    }
    buf = ((const char*)v->iov_base) + offset;

    ssize_t bytes;
    // Number of subsequent iovecs fully absorbed into combinedBuf this pass.
    uint32_t buffersStolen = 0;
    auto sslWriteBuf = buf;
    if ((len < minWriteSize_) && ((i + 1) < count)) {
      // Combine this buffer with part or all of the next buffers in
      // order to avoid really small-grained calls to SSL_write().
      // Each call to SSL_write() produces a separate record in
      // the egress SSL stream, and we've found that some low-end
      // mobile clients can't handle receiving an HTTP response
      // header and the first part of the response body in two
      // separate SSL records (even if those two records are in
      // the same TCP packet).

      if (combinedBuf == nullptr) {
        if (minWriteSize_ > MAX_STACK_BUF_SIZE) {
          // Allocate the buffer on heap
          combinedBuf = new char[minWriteSize_];
        } else {
          // Allocate the buffer on stack.  Note: alloca'd memory lives until
          // the function returns (not the loop iteration); the nullptr guard
          // above ensures this runs at most once per call, so the stack does
          // not grow unboundedly.
          combinedBuf = (char*)alloca(minWriteSize_);
        }
      }
      assert(combinedBuf != nullptr);
      sslWriteBuf = combinedBuf;

      memcpy(combinedBuf, buf, len);
      do {
        // INVARIANT: i + buffersStolen == complete chunks serialized
        uint32_t nextIndex = i + buffersStolen + 1;
        bytesStolenFromNextBuffer =
            std::min(vec[nextIndex].iov_len, minWriteSize_ - len);
        if (bytesStolenFromNextBuffer > 0) {
          assert(vec[nextIndex].iov_base != nullptr);
          ::memcpy(
              combinedBuf + len,
              vec[nextIndex].iov_base,
              bytesStolenFromNextBuffer);
        }
        len += bytesStolenFromNextBuffer;
        if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
          // couldn't steal the whole buffer
          break;
        } else {
          // Absorbed the entire next iovec; keep going until combinedBuf
          // reaches minWriteSize_ or we run out of iovecs.
          bytesStolenFromNextBuffer = 0;
          buffersStolen++;
        }
      } while ((i + buffersStolen + 1) < count && (len < minWriteSize_));
    }

    // Advance any empty buffers immediately after.
    if (bytesStolenFromNextBuffer == 0) {
      while ((i + buffersStolen + 1) < count &&
             vec[i + buffersStolen + 1].iov_len == 0) {
        buffersStolen++;
      }
    }

    // From here, the write flow is as follows:
    //   - sslWriteImpl calls SSL_write, which encrypts the passed buffer.
    //   - SSL_write calls AsyncSSLSocket::bioWrite with the encrypted buffer.
    //   - AsyncSSLSocket::bioWrite calls AsyncSocket::sendSocketMessage(...).
    //
    // When sendSocketMessage calls sendMsg, WriteFlags are transformed into
    // ancillary data and/or sendMsg flags. If WriteFlag::EOR is in flags and
    // trackEor_ is set, then we should ensure that MSG_EOR is only passed to
    // sendmsg when the final byte of the orginally passed in buffer is being
    // written. Since the buffer originally passed to performWrite may be split
    // up and written over multiple calls to sendmsg, we have to take care to
    // unset the EOR flag if it was included in the WriteFlags passed in and
    // we're writing a buffer that does _not_ contain the final byte of the
    // orignally passed buffer.
    //
    // We handle EOR as follows:
    //   - We set currWriteFlags_ to the passed in WriteFlags.
    //   - If sslWriteBuf does NOT contain the last byte of the passed in iovec,
    //     then we set currBytesToFinalByte_ to folly::none. In bioWrite, we
    //     unset WriteFlags::EOR if it is set in currWriteFlags_.
    //   - If sslWriteBuf DOES contain the last byte of the passed in iovec,
    //     then we set bytesToFinalByte_ to int(len). In bioWrite, if the length
    //     of the passed in buffer >= currBytesToFinalByte_, then we leave the
    //     flags in currWriteFlags_ alone.
    //
    // What about timestamp flags?
    //   - We don't do any special handling for timestamping flags.
    //   - This may mean that more timestamps than necessary get generated, but
    //     that's OK; you already have to deal with that for timestamping due to
    //     the possibility of partial writes.
    //   - MSG_EOR used to be used for timestamping, but hasn't been for years.
    //
    // Finally, why even care about MSG_EOR, if not for timestamping?
    //   - If set, it is marked in the corresponding tcp_skb_cb; this can be
    //     useful when debugging.
    //   - The kernel uses it to decide whether socket buffers can be collapsed
    //     together (see tcp_skb_can_collapse_to).
    currWriteFlags_ = flags;
    uint32_t iovecWrittenToSslWriteBuf = i + buffersStolen + 1;
    CHECK_LE(iovecWrittenToSslWriteBuf, count);
    if (iovecWrittenToSslWriteBuf == count) { // last byte is in sslWriteBuf
      currBytesToFinalByte_ = len; // length of current buffer
    } else { // there are still remaining buffers / iovec to write
      currBytesToFinalByte_ = folly::none;
      currWriteFlags_ |= WriteFlags::CORK;
    }

    bytes = sslWriteImpl(ssl_.get(), sslWriteBuf, int(len));
    if (bytes <= 0) {
      int error = sslGetErrorImpl(ssl_.get(), int(bytes));
      if (error == SSL_ERROR_WANT_WRITE) {
        // The entire buffer needs to be passed in again, so *partialWritten
        // is set to the original offset where we started for this call to
        // performWrite(); see SSL_ERROR_WANT_WRITE documentation for details.
        //
        // The caller will register for write event if not already.
        *partialWritten = uint32_t(offset);
        return WriteResult(totalWritten);
      }
      // Any other SSL error is translated into a WriteResult by the helper.
      return interpretSSLError(int(bytes), error);
    }

    totalWritten += bytes;
    appBytesWritten_ += bytes;

    if (bytes == (ssize_t)len) {
      // The full iovec is written.
      (*countWritten) += 1 + buffersStolen;
      i += buffersStolen;
      // continue
    } else {
      // Partial write: walk forward through the iovecs covered by this
      // (possibly coalesced) write, counting those fully consumed, and
      // report the leftover bytes of the last one via *partialWritten.
      bytes += offset; // adjust bytes to account for all of v
      while (bytes >= (ssize_t)v->iov_len) {
        // We combined this buf with part or all of the next one, and
        // we managed to write all of this buf but not all of the bytes
        // from the next one that we'd hoped to write.
        bytes -= v->iov_len;
        (*countWritten)++;
        v = &(vec[++i]);
      }
      *partialWritten = uint32_t(bytes);
      return WriteResult(totalWritten);
    }
  }

  return WriteResult(totalWritten);
}