in tensorflow_networking/seastar/seastar_client_tag.cc [63:174]
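
// Prepares |tag| for an outgoing RPC: serializes |request| into the body
// buffer, builds the fixed-size wire header, and installs the callbacks that
// decode the response metadata and materialize the response tensor before
// invoking |done|.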
void InitSeastarClientTag(protobuf::Message* request,
                          SeastarTensorResponse* response, StatusCallback done,
                          SeastarClientTag* tag, CallOptions* call_opts) {
  tag->req_body_buf_.len_ = request->ByteSize();
  tag->req_body_buf_.data_ = new char[tag->req_body_buf_.len_]();
  request->SerializeToArray(tag->req_body_buf_.data_, tag->req_body_buf_.len_);
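
  // Request header wire layout (byte offsets follow the memcpy calls below):
  //   [ 0,  8)  magic bytes "DEADBEEF"
  //   [ 8, 16)  value of the tag pointer, echoed back by the server so the
  //             response can be matched to this request
  //   [16, 20)  method_
  //   [20, 22)  status (only meaningful in responses, left zeroed here)
  //   [24, 32)  req_body_buf_.len_, the serialized request size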
  tag->req_header_buf_.len_ = SeastarClientTag::HEADER_SIZE;
  // Zero-initialize so the unwritten status bytes do not leak uninitialized
  // heap memory onto the wire.
  tag->req_header_buf_.data_ = new char[SeastarClientTag::HEADER_SIZE]();
  memcpy(tag->req_header_buf_.data_, "DEADBEEF", 8);
  memcpy(tag->req_header_buf_.data_ + 8, &tag, 8);
  memcpy(tag->req_header_buf_.data_ + 16, &tag->method_, 4);
  // The status segment is ignored in requests.
  // memcpy(tag->req_header_buf_.data_ + 20, &tag->status_, 2);
  memcpy(tag->req_header_buf_.data_ + 24, &tag->req_body_buf_.len_, 8);
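
  // Invoked once the response metadata has been read off the wire. It decodes
  // the SeastarMessage and points resp_tensor_buf_ at the memory the tensor
  // payload should land in, so that memcpy-able dtypes are received directly
  // into the destination tensor.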
  ParseMessageCallback wrapper_parse_message = [request, response, tag]() {
    SeastarMessage sm;
    SeastarMessage::DeserializeMessage(&sm, tag->resp_message_buf_.data_);

    response->SetIsDead(sm.is_dead_);
    response->SetDataType(sm.data_type_);
    bool can_memcpy = DataTypeCanUseMemcpy(sm.data_type_);
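
    // memcpy-able dtypes: allocate the destination tensor up front and
    // receive the payload directly into its backing memory. For GPU tensors
    // the payload is staged in gpu-compatible host memory and copied to the
    // device in wrapper_done below.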
    if (can_memcpy) {
      if (response->GetDevice()->tensorflow_gpu_device_info() &&
          (!response->GetOnHost())) {
        AllocatorAttributes alloc_attrs;
        alloc_attrs.set_gpu_compatible(true);
        alloc_attrs.set_on_host(true);
        Allocator* alloc = response->GetDevice()->GetAllocator(alloc_attrs);
        Tensor cpu_copy(alloc, sm.data_type_, sm.tensor_shape_);
        tag->resp_tensor_buf_.data_ =
            reinterpret_cast<char*>(DMAHelper::base(&cpu_copy));
        tag->resp_tensor_buf_.len_ = sm.tensor_bytes_;
        tag->resp_tensor_buf_.owned_ = false;
        response->SetTensor(cpu_copy);
      } else {
        Tensor val(response->GetAlloc(), sm.data_type_, sm.tensor_shape_);
        tag->resp_tensor_buf_.data_ =
            reinterpret_cast<char*>(DMAHelper::base(&val));
        tag->resp_tensor_buf_.len_ = sm.tensor_bytes_;
        tag->resp_tensor_buf_.owned_ = false;
        response->SetTensor(val);
      }
    } else {
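      // Non-memcpy dtypes arrive as a serialized TensorProto; receive the
      // bytes into a scratch buffer that wrapper_done parses below.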
      tag->resp_tensor_buf_.len_ = sm.tensor_bytes_;
      tag->resp_tensor_buf_.data_ = new char[tag->resp_tensor_buf_.len_]();
    }
    return Status();
  };
  tag->parse_message_ = std::move(wrapper_parse_message);
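
  // Wraps the caller's done callback: on success it finishes materializing
  // the response tensor (device copy-back for memcpy-able dtypes, proto
  // parsing otherwise), then invokes done and frees the tag.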
  StatusCallback wrapper_done = std::bind(
      [response, tag](StatusCallback done, const Status& s) {
        if (!s.ok()) {
          LOG(ERROR) << "wrapper_done: status not ok, code=" << s.code()
                     << ", error message=" << s.error_message();
          done(s);
          delete tag;
          return;
        }

        bool can_memcpy = DataTypeCanUseMemcpy(response->GetDataType());
        if (can_memcpy) {
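          // parse_message_ staged the tensor in gpu-compatible host memory;
          // copy it onto the device asynchronously and complete from the
          // copy's callback.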
          if (response->GetDevice()->tensorflow_gpu_device_info() &&
              (!response->GetOnHost())) {
            Tensor* gpu_copy =
                new Tensor(response->GetAlloc(), response->GetTensor().dtype(),
                           response->GetTensor().shape());
            DeviceContext* recv_dev_context =
                response->GetDevice()->tensorflow_gpu_device_info()
                    ->default_context;
            recv_dev_context->CopyCPUTensorToDevice(
                &response->GetTensor(), response->GetDevice(), gpu_copy,
                [gpu_copy, response, done, tag](const Status& s) {
                  CHECK(s.ok()) << "copy tensor from CPU to GPU failed";
                  response->SetTensor(*gpu_copy);
                  done(s);
                  delete gpu_copy;
                  delete tag;
                });
          } else {
            done(s);
            delete tag;
          }
        } else {
          // Non-memcpy dtypes: parse the received bytes as a TensorProto and
          // let the device materialize the tensor from it.
          ParseProtoUnlimited(&response->GetTensorProto(),
                              tag->resp_tensor_buf_.data_,
                              tag->resp_tensor_buf_.len_);
          Tensor val;
          Status status = response->GetDevice()->MakeTensorFromProto(
              response->GetTensorProto(), response->GetAllocAttributes(), &val);
          CHECK(status.ok()) << "make cpu tensor from proto failed";
          response->SetTensor(val);
          done(status);
          delete tag;
        }
      },
      std::move(done), std::placeholders::_1);
  tag->done_ = std::move(wrapper_done);

  tag->call_opts_ = call_opts;
  ProcessCallOptions(tag);
}