void t_cpp_generator::generate_service_client()

in compiler/cpp/src/thrift/generate/t_cpp_generator.cc [2519:3102]


void t_cpp_generator::generate_service_client(t_service* tservice, string style) {
  // Emits the client class for `tservice` in the requested client style.
  //
  // The class declaration goes to f_header_. It contains the constructors,
  // the protocol/channel accessors, one method per service function plus its
  // send_/recv_ halves, and the protected state when this is a root service.
  // The out-of-line method definitions go to f_service_tcc_ when
  // gen_templates_ is set and to f_service_ otherwise.
  //
  // @param tservice  service whose client class is generated
  // @param style     "Cob" for the callback client built on a TAsyncChannel,
  //                  "Concurrent" for the thread-safe client that handles
  //                  out-of-order responses; any other value produces the
  //                  plain protocol-based client.

  // Cob clients implement the "CobCl" flavour of the service interface; the
  // other styles implement the plain interface.
  string ifstyle;
  if (style == "Cob") {
    ifstyle = "CobCl";
  }

  // Method definitions go to the .tcc file when templates are generated.
  std::ostream& out = (gen_templates_ ? f_service_tcc_ : f_service_);
  string template_header, template_suffix, short_suffix, protocol_type, _this;
  if (gen_templates_) {
    template_header = "template <class Protocol_>\n";
    short_suffix = "T";
    template_suffix = "T<Protocol_>";
    protocol_type = "Protocol_";
    // Templated code qualifies member accesses with this-> (required when the
    // member lives in a dependent base class).
    _this = "this->";
  } else {
    protocol_type = "::apache::thrift::protocol::TProtocol";
  }
  string prot_ptr = "std::shared_ptr< " + protocol_type + ">";
  string client_suffix = "Client" + template_suffix;
  string if_suffix = "If";
  if (style == "Cob") {
    if_suffix += template_suffix;
  }

  // For a derived service the client also inherits from the parent service's
  // client of the same style.
  string extends = "";
  string extends_client = "";
  if (tservice->get_extends() != nullptr) {
    // TODO(simpkins): If gen_templates_ is enabled, we currently assume all
    // parent services were also generated with templates enabled.
    extends = type_name(tservice->get_extends());
    extends_client = ", public " + extends + style + client_suffix;
  }

  // Generate the header portion
  if (style == "Concurrent") {
    f_header_ << "// The \'concurrent\' client is a thread safe client that correctly handles\n"
                 "// out of order responses.  It is slower than the regular client, so should\n"
                 "// only be used when you need to share a connection among multiple threads\n";
  }
  f_header_ << template_header << "class " << service_name_ << style << "Client" << short_suffix
            << " : "
            << "virtual public " << service_name_ << ifstyle << if_suffix << extends_client << " {"
            << '\n' << " public:" << '\n';

  indent_up();
  if (style != "Cob") {
    // Constructor taking one protocol used for both input and output.
    f_header_ << indent() << service_name_ << style << "Client" << short_suffix << "(" << prot_ptr
        << " prot";
    if (style == "Concurrent") {
        f_header_ << ", std::shared_ptr< ::apache::thrift::async::TConcurrentClientSyncInfo> sync";
    }
    f_header_ << ") ";

    if (extends.empty()) {
      if (style == "Concurrent") {
        f_header_ << ": sync_(sync)" << '\n';
      }
      f_header_ << "{" << '\n';
      f_header_ << indent() << "  setProtocol" << short_suffix << "(prot);" << '\n' << indent()
                << "}" << '\n';
    } else {
      // Derived service: delegate to the parent client, passing `prot` as both
      // the input and the output protocol.
      f_header_ << ":" << '\n';
      f_header_ << indent() << "  " << extends << style << client_suffix << "(prot, prot";
      if (style == "Concurrent") {
          f_header_ << ", sync";
      }
      f_header_ << ") {}" << '\n';
    }

    // Constructor taking separate input and output protocols.
    f_header_ << indent() << service_name_ << style << "Client" << short_suffix << "(" << prot_ptr
        << " iprot, " << prot_ptr << " oprot";
    if (style == "Concurrent") {
        f_header_ << ", std::shared_ptr< ::apache::thrift::async::TConcurrentClientSyncInfo> sync";
    }
    f_header_ << ") ";

    if (extends.empty()) {
      if (style == "Concurrent") {
        f_header_ << ": sync_(sync)" << '\n';
      }
      f_header_ << "{" << '\n';
      f_header_ << indent() << "  setProtocol" << short_suffix << "(iprot,oprot);" << '\n'
                << indent() << "}" << '\n';
    } else {
      // Break the line after ':' exactly like the single-protocol constructor
      // above, so the base-class initializer is not glued onto the signature.
      f_header_ << ":" << '\n';
      f_header_ << indent() << "  " << extends << style << client_suffix << "(iprot, oprot";
      if (style == "Concurrent") {
          f_header_ << ", sync";
      }
      f_header_ << ") {}" << '\n';
    }

    // create the setProtocol methods (only the root client owns the protocol
    // members; derived clients inherit them)
    if (extends.empty()) {
      f_header_ << " private:" << '\n';
      // 1: one parameter -- forwards to the two-parameter overload.
      f_header_ << indent() << "void setProtocol" << short_suffix << "(" << prot_ptr << " prot) {"
                << '\n';
      f_header_ << indent() << "  setProtocol" << short_suffix << "(prot,prot);" << '\n';
      f_header_ << indent() << "}" << '\n';
      // 2: two parameter -- stores the shared_ptrs and caches the raw pointers
      // used by the generated send_/recv_ methods.
      f_header_ << indent() << "void setProtocol" << short_suffix << "(" << prot_ptr << " iprot, "
                << prot_ptr << " oprot) {" << '\n';

      f_header_ << indent() << "  piprot_=iprot;" << '\n' << indent() << "  poprot_=oprot;" << '\n'
                << indent() << "  iprot_ = iprot.get();" << '\n' << indent()
                << "  oprot_ = oprot.get();" << '\n';

      f_header_ << indent() << "}" << '\n';
      f_header_ << " public:" << '\n';
    }

    // Generate getters for the protocols.
    // Note that these are not currently templated for simplicity.
    // TODO(simpkins): should they be templated?
    f_header_ << indent()
              << "std::shared_ptr< ::apache::thrift::protocol::TProtocol> getInputProtocol() {"
              << '\n' << indent() << "  return " << _this << "piprot_;" << '\n' << indent() << "}"
              << '\n';

    f_header_ << indent()
              << "std::shared_ptr< ::apache::thrift::protocol::TProtocol> getOutputProtocol() {"
              << '\n' << indent() << "  return " << _this << "poprot_;" << '\n' << indent() << "}"
              << '\n';

  } else /* if (style == "Cob") */ {
    // Cob clients are built from an async channel plus a protocol factory. The
    // protocols wrap in-memory buffers (itrans_/otrans_) that the generated
    // methods hand to the channel.
    f_header_ << indent() << service_name_ << style << "Client" << short_suffix << "("
              << "std::shared_ptr< ::apache::thrift::async::TAsyncChannel> channel, "
              << "::apache::thrift::protocol::TProtocolFactory* protocolFactory) :" << '\n';
    if (extends.empty()) {
      f_header_ << indent() << "  channel_(channel)," << '\n' << indent()
                << "  itrans_(new ::apache::thrift::transport::TMemoryBuffer())," << '\n'
                << indent() << "  otrans_(new ::apache::thrift::transport::TMemoryBuffer()),"
                << '\n';
      if (gen_templates_) {
        // TProtocolFactory classes return generic TProtocol pointers.
        // We have to dynamic cast to the Protocol_ type we are expecting.
        f_header_ << indent() << "  piprot_(::std::dynamic_pointer_cast<Protocol_>("
                  << "protocolFactory->getProtocol(itrans_)))," << '\n' << indent()
                  << "  poprot_(::std::dynamic_pointer_cast<Protocol_>("
                  << "protocolFactory->getProtocol(otrans_))) {" << '\n';
        // Throw a TException if either dynamic cast failed.
        f_header_ << indent() << "  if (!piprot_ || !poprot_) {" << '\n' << indent()
                  << "    throw ::apache::thrift::TException(\""
                  << "TProtocolFactory returned unexpected protocol type in " << service_name_
                  << style << "Client" << short_suffix << " constructor\");" << '\n' << indent()
                  << "  }" << '\n';
      } else {
        f_header_ << indent() << "  piprot_(protocolFactory->getProtocol(itrans_))," << '\n'
                  << indent() << "  poprot_(protocolFactory->getProtocol(otrans_)) {" << '\n';
      }
      f_header_ << indent() << "  iprot_ = piprot_.get();" << '\n' << indent()
                << "  oprot_ = poprot_.get();" << '\n' << indent() << "}" << '\n';
    } else {
      f_header_ << indent() << "  " << extends << style << client_suffix
                << "(channel, protocolFactory) {}" << '\n';
    }
  }

  if (style == "Cob") {
    generate_java_doc(f_header_, tservice);

    f_header_ << indent()
              << "::std::shared_ptr< ::apache::thrift::async::TAsyncChannel> getChannel() {" << '\n'
              << indent() << "  return " << _this << "channel_;" << '\n' << indent() << "}" << '\n';
    if (!gen_no_client_completion_) {
      f_header_ << indent() << "virtual void completed__(bool /* success */) {}" << '\n';
    }
  }

  // Declare, for every service function, the public override plus its
  // send_ and (for non-oneway functions) recv_ halves.
  const vector<t_function*>& functions = tservice->get_functions();
  vector<t_function*>::const_iterator f_iter;
  for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
    generate_java_doc(f_header_, *f_iter);
    indent(f_header_) << function_signature(*f_iter, ifstyle)
                      << " override;" << '\n';
    // TODO(dreiss): Use private inheritance to avoid generating these in cob-style.
    if (style == "Concurrent" && !(*f_iter)->is_oneway()) {
      // concurrent clients need to move the seqid from the send function to the
      // recv function.  Oneway methods don't have a recv function, so we don't need to
      // move the seqid for them.  Attempting to do so would result in a seqid leak.
      t_function send_function(g_type_i32, /*returning seqid*/
                               string("send_") + (*f_iter)->get_name(),
                               (*f_iter)->get_arglist());
      indent(f_header_) << function_signature(&send_function, "") << ";" << '\n';
    } else {
      t_function send_function(g_type_void,
                               string("send_") + (*f_iter)->get_name(),
                               (*f_iter)->get_arglist());
      indent(f_header_) << function_signature(&send_function, "") << ";" << '\n';
    }
    if (!(*f_iter)->is_oneway()) {
      if (style == "Concurrent") {
        // Concurrent recv_ takes the seqid returned by send_.
        t_field seqIdArg(g_type_i32, "seqid");
        t_struct seqIdArgStruct(program_);
        seqIdArgStruct.append(&seqIdArg);
        t_function recv_function((*f_iter)->get_returntype(),
                                 string("recv_") + (*f_iter)->get_name(),
                                 &seqIdArgStruct);
        indent(f_header_) << function_signature(&recv_function, "") << ";" << '\n';
      } else {
        t_struct noargs(program_);
        t_function recv_function((*f_iter)->get_returntype(),
                                 string("recv_") + (*f_iter)->get_name(),
                                 &noargs);
        indent(f_header_) << function_signature(&recv_function, "") << ";" << '\n';
      }
    }
  }
  indent_down();

  // Only the root client declares the shared state; derived clients inherit it.
  if (extends.empty()) {
    f_header_ << " protected:" << '\n';
    indent_up();

    if (style == "Cob") {
      f_header_ << indent()
                << "::std::shared_ptr< ::apache::thrift::async::TAsyncChannel> channel_;" << '\n'
                << indent()
                << "::std::shared_ptr< ::apache::thrift::transport::TMemoryBuffer> itrans_;" << '\n'
                << indent()
                << "::std::shared_ptr< ::apache::thrift::transport::TMemoryBuffer> otrans_;"
                << '\n';
    }
    f_header_ <<
      indent() << prot_ptr << " piprot_;" << '\n' <<
      indent() << prot_ptr << " poprot_;" << '\n' <<
      indent() << protocol_type << "* iprot_;" << '\n' <<
      indent() << protocol_type << "* oprot_;" << '\n';

    if (style == "Concurrent") {
      f_header_ <<
        indent() << "std::shared_ptr< ::apache::thrift::async::TConcurrentClientSyncInfo> sync_;" << '\n';
    }
    indent_down();
  }

  f_header_ << "};" << '\n' << '\n';

  if (gen_templates_) {
    // Output a backwards compatibility typedef using
    // TProtocol as the template parameter.
    f_header_ << "typedef " << service_name_ << style
              << "ClientT< ::apache::thrift::protocol::TProtocol> " << service_name_ << style
              << "Client;" << '\n' << '\n';
  }

  string scope = service_name_ + style + client_suffix + "::";

  // Generate client method implementations
  for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
    // Concurrent non-oneway calls thread the seqid returned by send_ into recv_.
    string seqIdCapture;
    string seqIdUse;
    string seqIdCommaUse;
    if (style == "Concurrent" && !(*f_iter)->is_oneway()) {
      seqIdCapture = "int32_t seqid = ";
      seqIdUse = "seqid";
      seqIdCommaUse = ", seqid";
    }

    string funname = (*f_iter)->get_name();

    // Open function
    if (gen_templates_) {
      indent(out) << template_header;
    }
    indent(out) << function_signature(*f_iter, ifstyle, scope) << '\n';
    scope_up(out);
    indent(out) << seqIdCapture << "send_" << funname << "(";

    // Get the struct of function call params
    t_struct* arg_struct = (*f_iter)->get_arglist();

    // Declare the function arguments
    const vector<t_field*>& fields = arg_struct->get_members();
    vector<t_field*>::const_iterator fld_iter;
    bool first = true;
    for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
      if (first) {
        first = false;
      } else {
        out << ", ";
      }
      out << (*fld_iter)->get_name();
    }
    out << ");" << '\n';

    if (style != "Cob") {
      // Synchronous styles call recv_ right away. Complex return types are
      // filled in through the `_return` out-parameter instead of returned.
      if (!(*f_iter)->is_oneway()) {
        out << indent();
        if (!(*f_iter)->get_returntype()->is_void()) {
          if (is_complex_type((*f_iter)->get_returntype())) {
            out << "recv_" << funname << "(_return" << seqIdCommaUse << ");" << '\n';
          } else {
            out << "return recv_" << funname << "(" << seqIdUse << ");" << '\n';
          }
        } else {
          out << "recv_" << funname << "(" << seqIdUse << ");" << '\n';
        }
      }
    } else {
      // Cob style hands the serialized request to the channel, with `cob`
      // bound as the completion callback.
      if (!(*f_iter)->is_oneway()) {
        out << indent() << _this << "channel_->sendAndRecvMessage("
            << "::std::bind(cob, this), " << _this << "otrans_.get(), " << _this << "itrans_.get());"
            << '\n';
      } else {
        out << indent() << _this << "channel_->sendMessage("
            << "::std::bind(cob, this), " << _this << "otrans_.get());" << '\n';
      }
    }
    scope_down(out);
    out << '\n';

    // if (style != "Cob") // TODO(dreiss): Libify the client and don't generate this for cob-style
    // Always true for now: send_/recv_ are generated for every style, Cob included.
    if (true) {
      t_type* send_func_return_type = g_type_void;
      if (style == "Concurrent" && !(*f_iter)->is_oneway()) {
        send_func_return_type = g_type_i32;
      }
      // Function for sending
      t_function send_function(send_func_return_type,
                               string("send_") + (*f_iter)->get_name(),
                               (*f_iter)->get_arglist());

      // Open the send function
      if (gen_templates_) {
        indent(out) << template_header;
      }
      indent(out) << function_signature(&send_function, "", scope) << '\n';
      scope_up(out);

      // Function arguments and results
      string argsname = tservice->get_name() + "_" + (*f_iter)->get_name() + "_pargs";
      string resultname = tservice->get_name() + "_" + (*f_iter)->get_name() + "_presult";

      // Only concurrent non-oneway calls need a distinct seqid; every other
      // call is sent with seqid 0.
      string cseqidVal = "0";
      if (style == "Concurrent") {
        if (!(*f_iter)->is_oneway()) {
          cseqidVal = "this->sync_->generateSeqId()";
        }
      }
      // Serialize the request
      out <<
        indent() << "int32_t cseqid = " << cseqidVal << ";" << '\n';
      if(style == "Concurrent") {
        out <<
          indent() << "::apache::thrift::async::TConcurrentSendSentry sentry(this->sync_.get());" << '\n';
      }
      if (style == "Cob") {
        out <<
          indent() << _this << "otrans_->resetBuffer();" << '\n';
      }
      out <<
        indent() << _this << "oprot_->writeMessageBegin(\"" <<
        (*f_iter)->get_name() <<
        "\", ::apache::thrift::protocol::" << ((*f_iter)->is_oneway() ? "T_ONEWAY" : "T_CALL") <<
        ", cseqid);" << '\n' << '\n' <<
        indent() << argsname << " args;" << '\n';

      // The _pargs struct holds pointers to the caller's arguments.
      for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
        out << indent() << "args." << (*fld_iter)->get_name() << " = &" << (*fld_iter)->get_name()
            << ";" << '\n';
      }

      out << indent() << "args.write(" << _this << "oprot_);" << '\n' << '\n' << indent() << _this
          << "oprot_->writeMessageEnd();" << '\n' << indent() << _this
          << "oprot_->getTransport()->writeEnd();" << '\n' << indent() << _this
          << "oprot_->getTransport()->flush();" << '\n';

      if (style == "Concurrent") {
        out << '\n' << indent() << "sentry.commit();" << '\n';

        if (!(*f_iter)->is_oneway()) {
          out << indent() << "return cseqid;" << '\n';
        }
      }
      scope_down(out);
      out << '\n';

      // Generate recv function only if not an oneway function
      if (!(*f_iter)->is_oneway()) {
        t_struct noargs(program_);

        t_field seqIdArg(g_type_i32, "seqid");
        t_struct seqIdArgStruct(program_);
        seqIdArgStruct.append(&seqIdArg);

        t_struct* recv_function_args = &noargs;
        if (style == "Concurrent") {
          recv_function_args = &seqIdArgStruct;
        }

        t_function recv_function((*f_iter)->get_returntype(),
                                 string("recv_") + (*f_iter)->get_name(),
                                 recv_function_args);
        // Open the recv function
        if (gen_templates_) {
          indent(out) << template_header;
        }
        indent(out) << function_signature(&recv_function, "", scope) << '\n';
        scope_up(out);

        out << '\n' <<
          indent() << "int32_t rseqid = 0;" << '\n' <<
          indent() << "std::string fname;" << '\n' <<
          indent() << "::apache::thrift::protocol::TMessageType mtype;" << '\n';
        if(style == "Concurrent") {
          out << '\n' <<
            indent() << "// the read mutex gets dropped and reacquired as part of waitForWork()" << '\n' <<
            indent() << "// The destructor of this sentry wakes up other clients" << '\n' <<
            indent() << "::apache::thrift::async::TConcurrentRecvSentry sentry(this->sync_.get(), seqid);" << '\n';
        }
        // Cob style wraps the body in try/catch so completed__(false) is
        // reported when an exception escapes before completion was signalled.
        if (style == "Cob" && !gen_no_client_completion_) {
          out << indent() << "bool completed = false;" << '\n' << '\n' << indent() << "try {";
          indent_up();
        }
        out << '\n';
        // Concurrent style loops until the reply matching `seqid` arrives: it
        // first checks for a message already stashed as pending (getPending),
        // otherwise reads the next message header from the wire. The two
        // indent_up()s here are balanced by scope_down() + indent_up() below
        // and by the two indent_down()s after the result handling.
        if (style == "Concurrent") {
          out <<
            indent() << "while(true) {" << '\n' <<
            indent() << "  if(!this->sync_->getPending(fname, mtype, rseqid)) {" << '\n';
          indent_up();
          indent_up();
        }
        out <<
          indent() << _this << "iprot_->readMessageBegin(fname, mtype, rseqid);" << '\n';
        if (style == "Concurrent") {
          scope_down(out);
          out << indent() << "if(seqid == rseqid) {" << '\n';
          indent_up();
        }
        out <<
          indent() << "if (mtype == ::apache::thrift::protocol::T_EXCEPTION) {" << '\n' <<
          indent() << "  ::apache::thrift::TApplicationException x;" << '\n' <<
          indent() << "  x.read(" << _this << "iprot_);" << '\n' <<
          indent() << "  " << _this << "iprot_->readMessageEnd();" << '\n' <<
          indent() << "  " << _this << "iprot_->getTransport()->readEnd();" << '\n';
        if (style == "Cob" && !gen_no_client_completion_) {
          out << indent() << "  completed = true;" << '\n' << indent() << "  completed__(true);"
              << '\n';
        }
        if (style == "Concurrent") {
          out << indent() << "  sentry.commit();" << '\n';
        }
        out <<
          indent() << "  throw x;" << '\n' <<
          indent() << "}" << '\n' <<
          indent() << "if (mtype != ::apache::thrift::protocol::T_REPLY) {" << '\n' <<
          indent() << "  " << _this << "iprot_->skip(" << "::apache::thrift::protocol::T_STRUCT);" << '\n' <<
          indent() << "  " << _this << "iprot_->readMessageEnd();" << '\n' <<
          indent() << "  " << _this << "iprot_->getTransport()->readEnd();" << '\n';
        if (style == "Cob" && !gen_no_client_completion_) {
          out << indent() << "  completed = true;" << '\n' << indent() << "  completed__(false);"
              << '\n';
        }
        out <<
          indent() << "}" << '\n' <<
          indent() << "if (fname.compare(\"" << (*f_iter)->get_name() << "\") != 0) {" << '\n' <<
          indent() << "  " << _this << "iprot_->skip(" << "::apache::thrift::protocol::T_STRUCT);" << '\n' <<
          indent() << "  " << _this << "iprot_->readMessageEnd();" << '\n' <<
          indent() << "  " << _this << "iprot_->getTransport()->readEnd();" << '\n';
        if (style == "Cob" && !gen_no_client_completion_) {
          out << indent() << "  completed = true;" << '\n' << indent() << "  completed__(false);"
              << '\n';
        }
        if (style == "Concurrent") {
          out << '\n' <<
            indent() << "  // in a bad state, don't commit" << '\n' <<
            indent() << "  using ::apache::thrift::protocol::TProtocolException;" << '\n' <<
            indent() << "  throw TProtocolException(TProtocolException::INVALID_DATA);" << '\n';
        }
        out << indent() << "}" << '\n';

        // Simple return types need a local to receive the result; complex
        // ones are written straight into the caller's `_return` parameter.
        if (!(*f_iter)->get_returntype()->is_void()
            && !is_complex_type((*f_iter)->get_returntype())) {
          t_field returnfield((*f_iter)->get_returntype(), "_return");
          out << indent() << declare_field(&returnfield) << '\n';
        }

        out << indent() << resultname << " result;" << '\n';

        if (!(*f_iter)->get_returntype()->is_void()) {
          out << indent() << "result.success = &_return;" << '\n';
        }

        out << indent() << "result.read(" << _this << "iprot_);" << '\n' << indent() << _this
            << "iprot_->readMessageEnd();" << '\n' << indent() << _this
            << "iprot_->getTransport()->readEnd();" << '\n' << '\n';

        // Careful, only look for _result if not a void function
        if (!(*f_iter)->get_returntype()->is_void()) {
          if (is_complex_type((*f_iter)->get_returntype())) {
            out <<
              indent() << "if (result.__isset.success) {" << '\n';
            out <<
              indent() << "  // _return pointer has now been filled" << '\n';
            if (style == "Cob" && !gen_no_client_completion_) {
              out << indent() << "  completed = true;" << '\n' << indent() << "  completed__(true);"
                  << '\n';
            }
            if (style == "Concurrent") {
              out << indent() << "  sentry.commit();" << '\n';
            }
            out <<
              indent() << "  return;" << '\n' <<
              indent() << "}" << '\n';
          } else {
            out << indent() << "if (result.__isset.success) {" << '\n';
            if (style == "Cob" && !gen_no_client_completion_) {
              out << indent() << "  completed = true;" << '\n' << indent() << "  completed__(true);"
                  << '\n';
            }
            if (style == "Concurrent") {
              out << indent() << "  sentry.commit();" << '\n';
            }
            out << indent() << "  return _return;" << '\n' << indent() << "}" << '\n';
          }
        }

        // Rethrow any declared exception the server sent back.
        t_struct* xs = (*f_iter)->get_xceptions();
        const std::vector<t_field*>& xceptions = xs->get_members();
        vector<t_field*>::const_iterator x_iter;
        for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) {
          out << indent() << "if (result.__isset." << (*x_iter)->get_name() << ") {" << '\n';
          if (style == "Cob" && !gen_no_client_completion_) {
            out << indent() << "  completed = true;" << '\n' << indent() << "  completed__(true);"
                << '\n';
          }
          if (style == "Concurrent") {
            out << indent() << "  sentry.commit();" << '\n';
          }
          out << indent() << "  throw result." << (*x_iter)->get_name() << ";" << '\n' << indent()
              << "}" << '\n';
        }

        // We only get here if we are a void function
        if ((*f_iter)->get_returntype()->is_void()) {
          if (style == "Cob" && !gen_no_client_completion_) {
            out << indent() << "completed = true;" << '\n' << indent() << "completed__(true);"
                << '\n';
          }
          if (style == "Concurrent") {
            out << indent() << "sentry.commit();" << '\n';
          }
          indent(out) << "return;" << '\n';
        } else {
          // Non-void function whose reply set neither success nor a declared
          // exception: report MISSING_RESULT.
          if (style == "Cob" && !gen_no_client_completion_) {
            out << indent() << "completed = true;" << '\n' << indent() << "completed__(true);"
                << '\n';
          }
          if (style == "Concurrent") {
            out << indent() << "// in a bad state, don't commit" << '\n';
          }
          out << indent() << "throw "
                             "::apache::thrift::TApplicationException(::apache::thrift::"
                             "TApplicationException::MISSING_RESULT, \"" << (*f_iter)->get_name()
              << " failed: unknown result\");" << '\n';
        }
        // Close the `if(seqid == rseqid)` block; replies for other seqids are
        // handed back via updatePending() and the loop waits for more work.
        if (style == "Concurrent") {
          indent_down();
          indent_down();
          out << indent() << "  }" << '\n'
              << indent() << "  // seqid != rseqid" << '\n'
              << indent() << "  this->sync_->updatePending(fname, mtype, rseqid);" << '\n'
              << '\n'
              << indent()
              << "  // this will temporarily unlock the readMutex, and let other clients get work done" << '\n'
              << indent() << "  this->sync_->waitForWork(seqid);" << '\n'
              << indent() << "} // end while(true)" << '\n';
        }
        if (style == "Cob" && !gen_no_client_completion_) {
          indent_down();
          out << indent() << "} catch (...) {" << '\n' << indent() << "  if (!completed) {" << '\n'
              << indent() << "    completed__(false);" << '\n' << indent() << "  }" << '\n'
              << indent() << "  throw;" << '\n' << indent() << "}" << '\n';
        }
        // Close function
        scope_down(out);
        out << '\n';
      }
    }
  }
}