in compiler/cpp/src/thrift/generate/t_cpp_generator.cc [2503:3086]
/**
 * Generates the client class for a service in the requested style.
 *
 * Where the output goes:
 *  - The class declaration goes to f_header_. It contains the constructors,
 *    the protocol/channel accessors, and one declaration per service
 *    function plus its send_/recv_ helpers. It also contains the protected
 *    state, but only for a service that does not extend another.
 *  - The method definitions go to f_service_tcc_ when gen_templates_ is
 *    set, and to f_service_ otherwise.
 *
 * Styles:
 *  - "Cob": the client is built from a TAsyncChannel and a protocol factory.
 *    It implements <Service>CobClIf and sends through the channel with a
 *    callback.
 *  - "Concurrent": the client shares a connection through a
 *    TConcurrentClientSyncInfo and matches responses by seqid.
 *  - Any other value: the default client, which calls send_ and then recv_.
 *
 * @param tservice the service whose client is generated
 * @param style    client flavour, appended to the generated class name
 */
void t_cpp_generator::generate_service_client(t_service* tservice, string style) {
  // Only the Cob client implements a differently named interface
  // (<Service>CobClIf...). ifstyle is also forwarded to function_signature().
  string ifstyle;
  if (style == "Cob") {
    ifstyle = "CobCl";
  }
  // Method definitions are written to the .tcc stream when templates are
  // generated, otherwise to the service source stream.
  std::ostream& out = (gen_templates_ ? f_service_tcc_ : f_service_);
  string template_header, template_suffix, short_suffix, protocol_type, _this;
  if (gen_templates_) {
    template_header = "template <class Protocol_>\n";
    short_suffix = "T";
    template_suffix = "T<Protocol_>";
    protocol_type = "Protocol_";
    // Presumably needed so members inherited from a (dependent) templated
    // base client are found by template name lookup.
    _this = "this->";
  } else {
    protocol_type = "::apache::thrift::protocol::TProtocol";
  }
  string prot_ptr = "std::shared_ptr< " + protocol_type + ">";
  string client_suffix = "Client" + template_suffix;
  string if_suffix = "If";
  if (style == "Cob") {
    if_suffix += template_suffix;
  }
  string extends = "";
  string extends_client = "";
  if (tservice->get_extends() != nullptr) {
    // TODO(simpkins): If gen_templates_ is enabled, we currently assume all
    // parent services were also generated with templates enabled.
    extends = type_name(tservice->get_extends());
    extends_client = ", public " + extends + style + client_suffix;
  }
  // Generate the header portion
  if (style == "Concurrent") {
    f_header_ << "// The \'concurrent\' client is a thread safe client that correctly handles\n"
                 "// out of order responses. It is slower than the regular client, so should\n"
                 "// only be used when you need to share a connection among multiple threads\n";
  }
  f_header_ << template_header << "class " << service_name_ << style << "Client" << short_suffix
            << " : "
            << "virtual public " << service_name_ << ifstyle << if_suffix << extends_client << " {"
            << '\n' << " public:" << '\n';
  indent_up();
  if (style != "Cob") {
    // Constructor taking one protocol used for both input and output.
    f_header_ << indent() << service_name_ << style << "Client" << short_suffix << "(" << prot_ptr
              << " prot";
    if (style == "Concurrent") {
      f_header_ << ", std::shared_ptr< ::apache::thrift::async::TConcurrentClientSyncInfo> sync";
    }
    f_header_ << ") ";
    if (extends.empty()) {
      if (style == "Concurrent") {
        f_header_ << ": sync_(sync)" << '\n';
      }
      f_header_ << "{" << '\n';
      f_header_ << indent() << " setProtocol" << short_suffix << "(prot);" << '\n' << indent()
                << "}" << '\n';
    } else {
      // Delegate to the parent service's client. Only a root client declares
      // the protocol members (see the protected section emitted below).
      f_header_ << ":" << '\n';
      f_header_ << indent() << " " << extends << style << client_suffix << "(prot, prot";
      if (style == "Concurrent") {
        f_header_ << ", sync";
      }
      f_header_ << ") {}" << '\n';
    }
    // Constructor taking separate input and output protocols.
    f_header_ << indent() << service_name_ << style << "Client" << short_suffix << "(" << prot_ptr
              << " iprot, " << prot_ptr << " oprot";
    if (style == "Concurrent") {
      f_header_ << ", std::shared_ptr< ::apache::thrift::async::TConcurrentClientSyncInfo> sync";
    }
    f_header_ << ") ";
    if (extends.empty()) {
      if (style == "Concurrent") {
        f_header_ << ": sync_(sync)" << '\n';
      }
      f_header_ << "{" << '\n';
      f_header_ << indent() << " setProtocol" << short_suffix << "(iprot,oprot);" << '\n'
                << indent() << "}" << '\n';
    } else {
      // Break the line after ':' exactly like the single-protocol constructor
      // above, so the base-class initializer starts on its own indented line.
      f_header_ << ":" << '\n';
      f_header_ << indent() << " " << extends << style << client_suffix << "(iprot, oprot";
      if (style == "Concurrent") {
        f_header_ << ", sync";
      }
      f_header_ << ") {}" << '\n';
    }
    // Create the setProtocol methods. Only the root client owns the protocol
    // members they assign.
    if (extends.empty()) {
      f_header_ << " private:" << '\n';
      // 1: one parameter, forwards to the two-parameter overload
      f_header_ << indent() << "void setProtocol" << short_suffix << "(" << prot_ptr << " prot) {"
                << '\n';
      f_header_ << indent() << "setProtocol" << short_suffix << "(prot,prot);" << '\n';
      f_header_ << indent() << "}" << '\n';
      // 2: two parameters, keeps the shared_ptrs and caches raw pointers
      f_header_ << indent() << "void setProtocol" << short_suffix << "(" << prot_ptr << " iprot, "
                << prot_ptr << " oprot) {" << '\n';
      f_header_ << indent() << " piprot_=iprot;" << '\n' << indent() << " poprot_=oprot;" << '\n'
                << indent() << " iprot_ = iprot.get();" << '\n' << indent()
                << " oprot_ = oprot.get();" << '\n';
      f_header_ << indent() << "}" << '\n';
      f_header_ << " public:" << '\n';
    }
    // Generate getters for the protocols.
    // Note that these are not currently templated for simplicity.
    // TODO(simpkins): should they be templated?
    f_header_ << indent()
              << "std::shared_ptr< ::apache::thrift::protocol::TProtocol> getInputProtocol() {"
              << '\n' << indent() << " return " << _this << "piprot_;" << '\n' << indent() << "}"
              << '\n';
    f_header_ << indent()
              << "std::shared_ptr< ::apache::thrift::protocol::TProtocol> getOutputProtocol() {"
              << '\n' << indent() << " return " << _this << "poprot_;" << '\n' << indent() << "}"
              << '\n';
  } else /* if (style == "Cob") */ {
    // The Cob client is built from an async channel plus a protocol factory.
    // Requests and responses are staged in TMemoryBuffer transports.
    f_header_ << indent() << service_name_ << style << "Client" << short_suffix << "("
              << "std::shared_ptr< ::apache::thrift::async::TAsyncChannel> channel, "
              << "::apache::thrift::protocol::TProtocolFactory* protocolFactory) :" << '\n';
    if (extends.empty()) {
      // piprot_/poprot_ read itrans_/otrans_ in their initializers. That is
      // only safe because the protected section below declares
      // channel_, itrans_ and otrans_ before piprot_ and poprot_.
      f_header_ << indent() << " channel_(channel)," << '\n' << indent()
                << " itrans_(new ::apache::thrift::transport::TMemoryBuffer())," << '\n'
                << indent() << " otrans_(new ::apache::thrift::transport::TMemoryBuffer()),"
                << '\n';
      if (gen_templates_) {
        // TProtocolFactory classes return generic TProtocol pointers.
        // We have to dynamic cast to the Protocol_ type we are expecting.
        f_header_ << indent() << " piprot_(::std::dynamic_pointer_cast<Protocol_>("
                  << "protocolFactory->getProtocol(itrans_)))," << '\n' << indent()
                  << " poprot_(::std::dynamic_pointer_cast<Protocol_>("
                  << "protocolFactory->getProtocol(otrans_))) {" << '\n';
        // Throw a TException if either dynamic cast failed.
        f_header_ << indent() << " if (!piprot_ || !poprot_) {" << '\n' << indent()
                  << " throw ::apache::thrift::TException(\""
                  << "TProtocolFactory returned unexpected protocol type in " << service_name_
                  << style << "Client" << short_suffix << " constructor\");" << '\n' << indent()
                  << " }" << '\n';
      } else {
        f_header_ << indent() << " piprot_(protocolFactory->getProtocol(itrans_))," << '\n'
                  << indent() << " poprot_(protocolFactory->getProtocol(otrans_)) {" << '\n';
      }
      f_header_ << indent() << " iprot_ = piprot_.get();" << '\n' << indent()
                << " oprot_ = poprot_.get();" << '\n' << indent() << "}" << '\n';
    } else {
      f_header_ << indent() << " " << extends << style << client_suffix
                << "(channel, protocolFactory) {}" << '\n';
    }
  }
  if (style == "Cob") {
    // NOTE(review): this emits the service's doc comment immediately before
    // getChannel() rather than before the class; confirm this is intended.
    generate_java_doc(f_header_, tservice);
    f_header_ << indent()
              << "::std::shared_ptr< ::apache::thrift::async::TAsyncChannel> getChannel() {" << '\n'
              << indent() << " return " << _this << "channel_;" << '\n' << indent() << "}" << '\n';
    if (!gen_no_client_completion_) {
      f_header_ << indent() << "virtual void completed__(bool /* success */) {}" << '\n';
    }
  }
  // Declare every service function (as an override of the interface) and
  // its send_/recv_ helpers.
  vector<t_function*> functions = tservice->get_functions();
  vector<t_function*>::const_iterator f_iter;
  for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
    generate_java_doc(f_header_, *f_iter);
    indent(f_header_) << function_signature(*f_iter, ifstyle)
                      << " override;" << '\n';
    // TODO(dreiss): Use private inheritance to avoid generating these in cob-style.
    if (style == "Concurrent" && !(*f_iter)->is_oneway()) {
      // concurrent clients need to move the seqid from the send function to the
      // recv function. Oneway methods don't have a recv function, so we don't need to
      // move the seqid for them. Attempting to do so would result in a seqid leak.
      t_function send_function(g_type_i32, /*returning seqid*/
                               string("send_") + (*f_iter)->get_name(),
                               (*f_iter)->get_arglist());
      indent(f_header_) << function_signature(&send_function, "") << ";" << '\n';
    } else {
      t_function send_function(g_type_void,
                               string("send_") + (*f_iter)->get_name(),
                               (*f_iter)->get_arglist());
      indent(f_header_) << function_signature(&send_function, "") << ";" << '\n';
    }
    if (!(*f_iter)->is_oneway()) {
      if (style == "Concurrent") {
        // Concurrent recv_ takes the seqid produced by send_.
        t_field seqIdArg(g_type_i32, "seqid");
        t_struct seqIdArgStruct(program_);
        seqIdArgStruct.append(&seqIdArg);
        t_function recv_function((*f_iter)->get_returntype(),
                                 string("recv_") + (*f_iter)->get_name(),
                                 &seqIdArgStruct);
        indent(f_header_) << function_signature(&recv_function, "") << ";" << '\n';
      } else {
        t_struct noargs(program_);
        t_function recv_function((*f_iter)->get_returntype(),
                                 string("recv_") + (*f_iter)->get_name(),
                                 &noargs);
        indent(f_header_) << function_signature(&recv_function, "") << ";" << '\n';
      }
    }
  }
  indent_down();
  // Client state lives only in the root client of an inheritance chain;
  // derived clients reach it through their base class.
  if (extends.empty()) {
    f_header_ << " protected:" << '\n';
    indent_up();
    if (style == "Cob") {
      f_header_ << indent()
                << "::std::shared_ptr< ::apache::thrift::async::TAsyncChannel> channel_;" << '\n'
                << indent()
                << "::std::shared_ptr< ::apache::thrift::transport::TMemoryBuffer> itrans_;" << '\n'
                << indent()
                << "::std::shared_ptr< ::apache::thrift::transport::TMemoryBuffer> otrans_;"
                << '\n';
    }
    f_header_ <<
      indent() << prot_ptr << " piprot_;" << '\n' <<
      indent() << prot_ptr << " poprot_;" << '\n' <<
      indent() << protocol_type << "* iprot_;" << '\n' <<
      indent() << protocol_type << "* oprot_;" << '\n';
    if (style == "Concurrent") {
      f_header_ <<
        indent() << "std::shared_ptr< ::apache::thrift::async::TConcurrentClientSyncInfo> sync_;" << '\n';
    }
    indent_down();
  }
  f_header_ << "};" << '\n' << '\n';
  if (gen_templates_) {
    // Output a backwards compatibility typedef using
    // TProtocol as the template parameter.
    f_header_ << "typedef " << service_name_ << style
              << "ClientT< ::apache::thrift::protocol::TProtocol> " << service_name_ << style
              << "Client;" << '\n' << '\n';
  }
  string scope = service_name_ + style + client_suffix + "::";
  // Generate client method implementations
  for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
    // Concurrent two-way calls capture the seqid returned by send_ and pass
    // it on to recv_; every other combination leaves these empty.
    string seqIdCapture;
    string seqIdUse;
    string seqIdCommaUse;
    if (style == "Concurrent" && !(*f_iter)->is_oneway()) {
      seqIdCapture = "int32_t seqid = ";
      seqIdUse = "seqid";
      seqIdCommaUse = ", seqid";
    }
    string funname = (*f_iter)->get_name();
    // Open function
    if (gen_templates_) {
      indent(out) << template_header;
    }
    indent(out) << function_signature(*f_iter, ifstyle, scope) << '\n';
    scope_up(out);
    indent(out) << seqIdCapture << "send_" << funname << "(";
    // Get the struct of function call params
    t_struct* arg_struct = (*f_iter)->get_arglist();
    // Forward the function arguments, comma-separated, to send_
    const vector<t_field*>& fields = arg_struct->get_members();
    vector<t_field*>::const_iterator fld_iter;
    bool first = true;
    for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
      if (first) {
        first = false;
      } else {
        out << ", ";
      }
      out << (*fld_iter)->get_name();
    }
    out << ");" << '\n';
    if (style != "Cob") {
      if (!(*f_iter)->is_oneway()) {
        out << indent();
        if (!(*f_iter)->get_returntype()->is_void()) {
          // Complex return types are filled through the _return
          // out-parameter; base types are returned by value.
          if (is_complex_type((*f_iter)->get_returntype())) {
            out << "recv_" << funname << "(_return" << seqIdCommaUse << ");" << '\n';
          } else {
            out << "return recv_" << funname << "(" << seqIdUse << ");" << '\n';
          }
        } else {
          out << "recv_" << funname << "(" << seqIdUse << ");" << '\n';
        }
      }
    } else {
      // Cob style: the call only sends. The channel receives `cob` bound to
      // this client, the output buffer, and, for two-way calls, the input
      // buffer.
      if (!(*f_iter)->is_oneway()) {
        out << indent() << _this << "channel_->sendAndRecvMessage("
            << "::std::bind(cob, this), " << _this << "otrans_.get(), " << _this << "itrans_.get());"
            << '\n';
      } else {
        out << indent() << _this << "channel_->sendMessage("
            << "::std::bind(cob, this), " << _this << "otrans_.get());" << '\n';
      }
    }
    scope_down(out);
    out << '\n';
    // if (style != "Cob") // TODO(dreiss): Libify the client and don't generate this for cob-style
    if (true) {
      t_type* send_func_return_type = g_type_void;
      if (style == "Concurrent" && !(*f_iter)->is_oneway()) {
        send_func_return_type = g_type_i32;
      }
      // Function for sending
      t_function send_function(send_func_return_type,
                               string("send_") + (*f_iter)->get_name(),
                               (*f_iter)->get_arglist());
      // Open the send function
      if (gen_templates_) {
        indent(out) << template_header;
      }
      indent(out) << function_signature(&send_function, "", scope) << '\n';
      scope_up(out);
      // Function arguments and results
      string argsname = tservice->get_name() + "_" + (*f_iter)->get_name() + "_pargs";
      string resultname = tservice->get_name() + "_" + (*f_iter)->get_name() + "_presult";
      // Only Concurrent two-way calls allocate a real seqid; all others send 0.
      string cseqidVal = "0";
      if (style == "Concurrent") {
        if (!(*f_iter)->is_oneway()) {
          cseqidVal = "this->sync_->generateSeqId()";
        }
      }
      // Serialize the request
      out <<
        indent() << "int32_t cseqid = " << cseqidVal << ";" << '\n';
      if(style == "Concurrent") {
        out <<
          indent() << "::apache::thrift::async::TConcurrentSendSentry sentry(this->sync_.get());" << '\n';
      }
      if (style == "Cob") {
        out <<
          indent() << _this << "otrans_->resetBuffer();" << '\n';
      }
      out <<
        indent() << _this << "oprot_->writeMessageBegin(\"" <<
        (*f_iter)->get_name() <<
        "\", ::apache::thrift::protocol::" << ((*f_iter)->is_oneway() ? "T_ONEWAY" : "T_CALL") <<
        ", cseqid);" << '\n' << '\n' <<
        indent() << argsname << " args;" << '\n';
      // The _pargs struct holds pointers to the caller's arguments.
      for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
        out << indent() << "args." << (*fld_iter)->get_name() << " = &" << (*fld_iter)->get_name()
            << ";" << '\n';
      }
      out << indent() << "args.write(" << _this << "oprot_);" << '\n' << '\n' << indent() << _this
          << "oprot_->writeMessageEnd();" << '\n' << indent() << _this
          << "oprot_->getTransport()->writeEnd();" << '\n' << indent() << _this
          << "oprot_->getTransport()->flush();" << '\n';
      if (style == "Concurrent") {
        out << '\n' << indent() << "sentry.commit();" << '\n';
        if (!(*f_iter)->is_oneway()) {
          out << indent() << "return cseqid;" << '\n';
        }
      }
      scope_down(out);
      out << '\n';
      // Generate recv function only if not an oneway function
      if (!(*f_iter)->is_oneway()) {
        t_struct noargs(program_);
        t_field seqIdArg(g_type_i32, "seqid");
        t_struct seqIdArgStruct(program_);
        seqIdArgStruct.append(&seqIdArg);
        t_struct* recv_function_args = &noargs;
        if (style == "Concurrent") {
          recv_function_args = &seqIdArgStruct;
        }
        t_function recv_function((*f_iter)->get_returntype(),
                                 string("recv_") + (*f_iter)->get_name(),
                                 recv_function_args);
        // Open the recv function
        if (gen_templates_) {
          indent(out) << template_header;
        }
        indent(out) << function_signature(&recv_function, "", scope) << '\n';
        scope_up(out);
        out << '\n' <<
          indent() << "int32_t rseqid = 0;" << '\n' <<
          indent() << "std::string fname;" << '\n' <<
          indent() << "::apache::thrift::protocol::TMessageType mtype;" << '\n';
        if(style == "Concurrent") {
          out << '\n' <<
            indent() << "// the read mutex gets dropped and reacquired as part of waitForWork()" << '\n' <<
            indent() << "// The destructor of this sentry wakes up other clients" << '\n' <<
            indent() << "::apache::thrift::async::TConcurrentRecvSentry sentry(this->sync_.get(), seqid);" << '\n';
        }
        // Cob clients wrap the body in try/catch so completed__(false) fires
        // on any exception not already reported. This indent_up() is undone
        // just before the catch is emitted.
        if (style == "Cob" && !gen_no_client_completion_) {
          out << indent() << "bool completed = false;" << '\n' << '\n' << indent() << "try {";
          indent_up();
        }
        out << '\n';
        // The Concurrent receive loop works like this:
        //  - The header comes from sync_->getPending() when that returns true;
        //    otherwise it is read from the wire.
        //  - The reply is handled here only when rseqid matches seqid.
        //  - Otherwise the header is handed back via updatePending() and the
        //    thread waits in waitForWork().
        // Indent bookkeeping:
        //  - Two indent_up() calls open the while body and the nested if.
        //  - scope_down() closes that if.
        //  - The seqid test re-adds one level.
        //  - The two indent_down() calls in the loop tail unwind both.
        if (style == "Concurrent") {
          out <<
            indent() << "while(true) {" << '\n' <<
            indent() << " if(!this->sync_->getPending(fname, mtype, rseqid)) {" << '\n';
          indent_up();
          indent_up();
        }
        out <<
          indent() << _this << "iprot_->readMessageBegin(fname, mtype, rseqid);" << '\n';
        if (style == "Concurrent") {
          scope_down(out);
          out << indent() << "if(seqid == rseqid) {" << '\n';
          indent_up();
        }
        // Convention for the Concurrent sentry: paths that leave the
        // connection consistent emit sentry.commit(). Paths marked
        // "in a bad state, don't commit" deliberately omit it.
        out <<
          indent() << "if (mtype == ::apache::thrift::protocol::T_EXCEPTION) {" << '\n' <<
          indent() << " ::apache::thrift::TApplicationException x;" << '\n' <<
          indent() << " x.read(" << _this << "iprot_);" << '\n' <<
          indent() << " " << _this << "iprot_->readMessageEnd();" << '\n' <<
          indent() << " " << _this << "iprot_->getTransport()->readEnd();" << '\n';
        if (style == "Cob" && !gen_no_client_completion_) {
          out << indent() << " completed = true;" << '\n' << indent() << " completed__(true);"
              << '\n';
        }
        if (style == "Concurrent") {
          out << indent() << " sentry.commit();" << '\n';
        }
        // NOTE(review): the generated `mtype != T_REPLY` branch below does
        // not throw or return in any style. Likewise, the fname-mismatch branch
        // only throws for Concurrent. Generated code therefore falls through to
        // result.read() after the message was already skipped. For Cob it may
        // also call completed__ a second time. Behavior is kept as-is; confirm
        // whether a throw is intended.
        out <<
          indent() << " throw x;" << '\n' <<
          indent() << "}" << '\n' <<
          indent() << "if (mtype != ::apache::thrift::protocol::T_REPLY) {" << '\n' <<
          indent() << " " << _this << "iprot_->skip(" << "::apache::thrift::protocol::T_STRUCT);" << '\n' <<
          indent() << " " << _this << "iprot_->readMessageEnd();" << '\n' <<
          indent() << " " << _this << "iprot_->getTransport()->readEnd();" << '\n';
        if (style == "Cob" && !gen_no_client_completion_) {
          out << indent() << " completed = true;" << '\n' << indent() << " completed__(false);"
              << '\n';
        }
        out <<
          indent() << "}" << '\n' <<
          indent() << "if (fname.compare(\"" << (*f_iter)->get_name() << "\") != 0) {" << '\n' <<
          indent() << " " << _this << "iprot_->skip(" << "::apache::thrift::protocol::T_STRUCT);" << '\n' <<
          indent() << " " << _this << "iprot_->readMessageEnd();" << '\n' <<
          indent() << " " << _this << "iprot_->getTransport()->readEnd();" << '\n';
        if (style == "Cob" && !gen_no_client_completion_) {
          out << indent() << " completed = true;" << '\n' << indent() << " completed__(false);"
              << '\n';
        }
        if (style == "Concurrent") {
          out << '\n' <<
            indent() << " // in a bad state, don't commit" << '\n' <<
            indent() << " using ::apache::thrift::protocol::TProtocolException;" << '\n' <<
            indent() << " throw TProtocolException(TProtocolException::INVALID_DATA);" << '\n';
        }
        out << indent() << "}" << '\n';
        // Base-type results need a local _return; complex ones use the
        // caller-supplied _return reference.
        if (!(*f_iter)->get_returntype()->is_void()
            && !is_complex_type((*f_iter)->get_returntype())) {
          t_field returnfield((*f_iter)->get_returntype(), "_return");
          out << indent() << declare_field(&returnfield) << '\n';
        }
        out << indent() << resultname << " result;" << '\n';
        if (!(*f_iter)->get_returntype()->is_void()) {
          out << indent() << "result.success = &_return;" << '\n';
        }
        out << indent() << "result.read(" << _this << "iprot_);" << '\n' << indent() << _this
            << "iprot_->readMessageEnd();" << '\n' << indent() << _this
            << "iprot_->getTransport()->readEnd();" << '\n' << '\n';
        // Careful, only look for _result if not a void function
        if (!(*f_iter)->get_returntype()->is_void()) {
          if (is_complex_type((*f_iter)->get_returntype())) {
            out <<
              indent() << "if (result.__isset.success) {" << '\n';
            out <<
              indent() << " // _return pointer has now been filled" << '\n';
            if (style == "Cob" && !gen_no_client_completion_) {
              out << indent() << " completed = true;" << '\n' << indent() << " completed__(true);"
                  << '\n';
            }
            if (style == "Concurrent") {
              out << indent() << " sentry.commit();" << '\n';
            }
            out <<
              indent() << " return;" << '\n' <<
              indent() << "}" << '\n';
          } else {
            out << indent() << "if (result.__isset.success) {" << '\n';
            if (style == "Cob" && !gen_no_client_completion_) {
              out << indent() << " completed = true;" << '\n' << indent() << " completed__(true);"
                  << '\n';
            }
            if (style == "Concurrent") {
              out << indent() << " sentry.commit();" << '\n';
            }
            out << indent() << " return _return;" << '\n' << indent() << "}" << '\n';
          }
        }
        // Rethrow any declared exception the server reported.
        t_struct* xs = (*f_iter)->get_xceptions();
        const std::vector<t_field*>& xceptions = xs->get_members();
        vector<t_field*>::const_iterator x_iter;
        for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) {
          out << indent() << "if (result.__isset." << (*x_iter)->get_name() << ") {" << '\n';
          if (style == "Cob" && !gen_no_client_completion_) {
            out << indent() << " completed = true;" << '\n' << indent() << " completed__(true);"
                << '\n';
          }
          if (style == "Concurrent") {
            out << indent() << " sentry.commit();" << '\n';
          }
          out << indent() << " throw result." << (*x_iter)->get_name() << ";" << '\n' << indent()
              << "}" << '\n';
        }
        // We only get here if we are a void function
        if ((*f_iter)->get_returntype()->is_void()) {
          if (style == "Cob" && !gen_no_client_completion_) {
            out << indent() << "completed = true;" << '\n' << indent() << "completed__(true);"
                << '\n';
          }
          if (style == "Concurrent") {
            out << indent() << "sentry.commit();" << '\n';
          }
          indent(out) << "return;" << '\n';
        } else {
          // Non-void call with neither success nor a declared exception set.
          if (style == "Cob" && !gen_no_client_completion_) {
            out << indent() << "completed = true;" << '\n' << indent() << "completed__(true);"
                << '\n';
          }
          if (style == "Concurrent") {
            out << indent() << "// in a bad state, don't commit" << '\n';
          }
          out << indent() << "throw "
                             "::apache::thrift::TApplicationException(::apache::thrift::"
                             "TApplicationException::MISSING_RESULT, \"" << (*f_iter)->get_name()
              << " failed: unknown result\");" << '\n';
        }
        if (style == "Concurrent") {
          // Undo the two levels opened for the while loop / seqid test.
          indent_down();
          indent_down();
          out << indent() << " }" << '\n'
              << indent() << " // seqid != rseqid" << '\n'
              << indent() << " this->sync_->updatePending(fname, mtype, rseqid);" << '\n'
              << '\n'
              << indent()
              << " // this will temporarily unlock the readMutex, and let other clients get work done" << '\n'
              << indent() << " this->sync_->waitForWork(seqid);" << '\n'
              << indent() << "} // end while(true)" << '\n';
        }
        if (style == "Cob" && !gen_no_client_completion_) {
          indent_down();
          out << indent() << "} catch (...) {" << '\n' << indent() << " if (!completed) {" << '\n'
              << indent() << " completed__(false);" << '\n' << indent() << " }" << '\n'
              << indent() << " throw;" << '\n' << indent() << "}" << '\n';
        }
        // Close function
        scope_down(out);
        out << '\n';
      }
    }
  }
}