string RdmaMessage::CreateMessage()

in tensorflow_networking/verbs/rdma.cc [1274:1349]


// Serializes `rm` into the fixed-layout byte string that is sent as an RDMA
// control message.
//
// Only the fields that belong to rm.type_ are written (see the table below);
// every other byte of the fixed-size part is zero. The returned string is
// kMessageTotalBytes long, followed, for RDMA_MESSAGE_ERROR_STATUS only, by a
// 4-byte length prefix and a serialized ErrorStatusProto.
string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
  // Rdma Message format
  // type|name_size|name|step_id|request_index|remote_addr|rkey|is_dead|...
  //   1B|    2B   | 512|  8B   |     8B      |       8B  | 4B |    1B |...
  // ...|data_type|tensor_shape|tensor_bytes|error_status          |
  // ...|   XB    |    XB      |    8B      |size - 4B, proto - XB |
  //
  // ACK:             Imm-type: ACK
  // TENSOR_REQUEST:  Imm-type: MESSAGE
  //                  Fields: type, request_index, name, step_id, remote_addr,
  //                      rkey, is_dead, data_type, tensor_shape, tensor_bytes
  // META_DATA_UPDATE: Imm-type: MESSAGE
  //                  Fields: type, request_index, is_dead, data_type,
  //                      tensor_shape, tensor_bytes
  // TENSOR_RE_REQUEST: Imm-type: MESSAGE
  //                  Fields: type, request_index, name, step_id, remote_addr,
  //                      rkey, is_dead, data_type, tensor_shape, tensor_bytes
  // ERROR_STATUS:    Imm-type: MESSAGE
  //                  Fields: type, request_index, name, step_id, error_status
  //                  NOTE(review): this function does not write name/step_id
  //                  for ERROR_STATUS; confirm which side is authoritative.
  // Tensor content:  Imm-type: request_index
  const size_t size_prefix_bytes = sizeof(uint32_t);
  size_t message_size = kMessageTotalBytes;
  char message[kMessageTotalBytes + kErrorStatusMaxSize];
  // Fields that are not part of this message type (and any gaps between
  // fields) are still transmitted, so make sure they never carry stale stack
  // contents onto the wire.
  memset(message, 0, sizeof(message));

  // type
  message[kTypeStartIndex] = static_cast<char>(rm.type_);
  // request index
  memcpy(&message[kRequestIndexStartIndex], &rm.request_index_,
         sizeof(rm.request_index_));
  // name, step_id, remote_addr, rkey
  if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
      (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
    // Per the layout above the name field is immediately followed by step_id,
    // so its capacity is the distance between the two offsets.
    // NOTE(review): assumes kStepIdStartIndex directly follows the name field;
    // confirm against the constant definitions.
    const size_t name_capacity = kStepIdStartIndex - kNameStartIndex;
    size_t name_bytes = rm.name_.size();
    if (name_bytes > name_capacity) {
      LOG(ERROR) << "Tensor name (" << name_bytes << " bytes) "
                 << "is too big to fit in RDMA message (" << name_capacity
                 << " bytes). Truncated.";
      name_bytes = name_capacity;
    }
    memcpy(&message[kNameSizeStartIndex], &rm.name_size_,
           sizeof(rm.name_size_));
    memcpy(&message[kNameStartIndex], rm.name_.data(), name_bytes);
    memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
           sizeof(rm.remote_addr_));
    memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
    memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
  }
  // is_dead, data_type, tensor_shape, tensor_bytes
  if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
      (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
      (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
    memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));

    memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
           sizeof(rm.data_type_));
    // The shape object is copied as raw bytes, not serialized.
    memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
           sizeof(rm.tensor_shape_));
    memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
           sizeof(rm.tensor_bytes_));
  }
// checksum
#ifdef RDMA_DATA_VALIDATION
  memcpy(&message[kChecksumStartIndex], &rm.checksum_, sizeof(rm.checksum_));
#endif
  // error status
  if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
    ::grpc::Status gs = ToGrpcStatus(rm.status_);
    ErrorStatusProto gsProto;
    gsProto.set_error_code(gs.error_code());
    gsProto.set_error_message(gs.error_message());
    gsProto.set_error_details(gs.error_details());
    uint32_t gsProtoSize = gsProto.ByteSize();
    if (gsProtoSize + size_prefix_bytes > kErrorStatusMaxSize) {
      LOG(ERROR) << "Error status (" << gsProtoSize + size_prefix_bytes
                 << " bytes) "
                 << "is too big to fit in RDMA message (" << kErrorStatusMaxSize
                 << " bytes). Truncated.";
      // A serialized proto cannot simply be cut short: SerializeToArray
      // writes nothing when the buffer is smaller than the encoded size, and
      // a partial encoding would not parse on the receiving side. Shrink the
      // content instead: drop the details first, then trim the message.
      gsProto.clear_error_details();
      gsProtoSize = gsProto.ByteSize();
      if (gsProtoSize + size_prefix_bytes > kErrorStatusMaxSize) {
        const size_t excess =
            gsProtoSize + size_prefix_bytes - kErrorStatusMaxSize;
        string trimmed = gsProto.error_message();
        size_t keep = trimmed.size() > excess ? trimmed.size() - excess : 0;
        // Back up to a UTF-8 sequence boundary so the string field stays
        // valid UTF-8 after trimming.
        while (keep > 0 &&
               (static_cast<unsigned char>(trimmed[keep]) & 0xC0) == 0x80) {
          --keep;
        }
        trimmed.resize(keep);
        gsProto.set_error_message(trimmed);
        gsProtoSize = gsProto.ByteSize();
      }
    }
    // Last line of defence: never advertise more bytes than were written and
    // never run past the end of the buffer.
    if (gsProtoSize + size_prefix_bytes > kErrorStatusMaxSize ||
        !gsProto.SerializeToArray(
            &message[kErrorStatusStartIndex + size_prefix_bytes],
            gsProtoSize)) {
      LOG(ERROR) << "Failed to serialize error status into RDMA message.";
      gsProtoSize = 0;
    }
    // memcpy rather than a casted pointer store: the offset is not guaranteed
    // to be suitably aligned for uint32_t.
    memcpy(&message[kErrorStatusStartIndex], &gsProtoSize, sizeof(gsProtoSize));
    message_size += gsProtoSize + size_prefix_bytes;
  }
  return string(message, message_size);
}