in tensorflow_networking/verbs/rdma.cc [1274:1349]
// Serializes `rm` into the fixed-layout byte string sent as an RDMA control
// message.
//
// Rdma Message format
// type|name_size|name|step_id|request_index|remote_addr|rkey|is_dead|...
//   1B|    2B   | 512|  8B   |     8B      |    8B     | 4B |  1B   |...
// ...|data_type|tensor_shape|tensor_bytes|error_status          |
// ...|   XB    |    XB      |    8B      |size - 4B, proto - XB |
//
// ACK:               Imm-type: ACK
// TENSOR_REQUEST:    Imm-type: MESSAGE
//   Fields: type, request_index, name, step_id, remote_addr,
//           rkey, is_dead, data_type, tensor_shape, tensor_bytes
// META_DATA_UPDATE:  Imm-type: MESSAGE
//   Fields: type, request_index, is_dead, data_type,
//           tensor_shape, tensor_bytes
// TENSOR_RE_REQUEST: Imm-type: MESSAGE
//   Fields: type, request_index, name, step_id, remote_addr,
//           rkey, is_dead, data_type, tensor_shape, tensor_bytes
// ERROR_STATUS:      Imm-type: MESSAGE
//   Fields: type, request_index, name, step_id, error_status
// Tensor content:    Imm-type: request_index
//
// Returns kMessageTotalBytes bytes for every message type; for ERROR_STATUS
// the 4-byte size prefix and the serialized ErrorStatusProto are appended.
// Fields that are not written for a given type are sent as zero bytes.
string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
  // Width of the length prefix stored in front of the serialized error proto.
  const size_t kProtoSizePrefixBytes = sizeof(uint32_t);

  size_t message_size = kMessageTotalBytes;
  char message[kMessageTotalBytes + kErrorStatusMaxSize];
  // Clear the buffer so that fields skipped for this message type go out as
  // zeros rather than leftover stack contents.
  memset(message, 0, sizeof(message));

  // type
  message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff;
  // request index
  memcpy(&message[kRequestIndexStartIndex], &rm.request_index_,
         sizeof(rm.request_index_));
  // name, step_id, remote_addr, rkey
  if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
      (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
    memcpy(&message[kNameSizeStartIndex], &rm.name_size_,
           sizeof(rm.name_size_));
    // NOTE(review): no bound check here; this assumes rm.name_.size() fits in
    // the name field of the layout above -- confirm callers enforce it.
    memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
    memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
           sizeof(rm.remote_addr_));
    memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
    memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
  }
  // is_dead, data_type, tensor_shape, tensor_bytes
  if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
      (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
      (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
    memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));
    memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
           sizeof(rm.data_type_));
    // The in-memory bytes of tensor_shape_ are copied as-is.
    // NOTE(review): presumably the object is safe to byte-copy -- confirm.
    memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
           sizeof(rm.tensor_shape_));
    memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
           sizeof(rm.tensor_bytes_));
  }
  // checksum
#ifdef RDMA_DATA_VALIDATION
  memcpy(&message[kChecksumStartIndex], &rm.checksum_, sizeof(rm.checksum_));
#endif
  // error status
  if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
    ::grpc::Status gs = ToGrpcStatus(rm.status_);
    ErrorStatusProto gsProto;
    gsProto.set_error_code(gs.error_code());
    gsProto.set_error_message(gs.error_message());
    gsProto.set_error_details(gs.error_details());
    uint32_t gsProtoSize = gsProto.ByteSize();
    if (gsProtoSize + kProtoSizePrefixBytes > kErrorStatusMaxSize) {
      LOG(ERROR) << "Error status (" << gsProtoSize + 4 << " bytes) "
                 << "is too big to fit in RDMA message (" << kErrorStatusMaxSize
                 << " bytes). Truncated.";
      // Serializing into a buffer smaller than the proto writes nothing, so
      // the proto itself has to shrink. Drop the details first; if that is
      // not enough, cut the tail off the error message.
      gsProto.clear_error_details();
      gsProtoSize = gsProto.ByteSize();
      if (gsProtoSize + kProtoSizePrefixBytes > kErrorStatusMaxSize) {
        const size_t excess =
            gsProtoSize + kProtoSizePrefixBytes - kErrorStatusMaxSize;
        auto* error_message = gsProto.mutable_error_message();
        size_t keep =
            error_message->size() > excess ? error_message->size() - excess : 0;
        // Back up to a character boundary so a multi-byte UTF-8 sequence is
        // never split (a split sequence could be rejected by the parser).
        while (keep > 0 &&
               (static_cast<unsigned char>((*error_message)[keep]) & 0xC0) ==
                   0x80) {
          --keep;
        }
        error_message->resize(keep);
        gsProtoSize = gsProto.ByteSize();
      }
    }
    // Store the size prefix with memcpy: the buffer offset carries no
    // alignment guarantee and a char array must not be written through a
    // uint32_t pointer.
    memcpy(&message[kErrorStatusStartIndex], &gsProtoSize, sizeof(gsProtoSize));
    if (!gsProto.SerializeToArray(
            &message[kErrorStatusStartIndex + kProtoSizePrefixBytes],
            gsProtoSize)) {
      LOG(ERROR) << "Failed to serialize error status into RDMA message.";
    }
    message_size += gsProtoSize + kProtoSizePrefixBytes;
  }
  return string(message, message_size);
}