void t_py_generator::generate_service_client()

in thrift/compiler/generate/t_py_generator.cc [2232:2592]


void t_py_generator::generate_service_client(const t_service* tservice) {
  string extends = "";
  string extends_client = "";
  if (tservice->get_extends() != nullptr) {
    extends = type_name(tservice->get_extends());
    extends_client = extends + ".Client, ";
  }

  f_service_ << "class Client(" << extends_client << "Iface):" << endl;
  indent_up();
  generate_python_docstring(f_service_, tservice);

  f_service_ << indent() << "_fbthrift_force_cpp_transport = "
             << (gen_cpp_transport_ ? "True" : "False") << endl
             << endl;

  // Context Handlers
  if (!gen_asyncio_) {
    f_service_ << indent() << "def __enter__(self):" << endl
               << indent() << "  if self._fbthrift_cpp_transport:" << endl
               << indent() << "    self._fbthrift_cpp_transport.__enter__()"
               << endl
               << indent() << "  return self" << endl
               << endl;
    f_service_ << indent() << "def __exit__(self, type, value, tb):" << endl
               << indent() << "  if self._fbthrift_cpp_transport:" << endl
               << indent()
               << "    self._fbthrift_cpp_transport.__exit__(type, value, tb)"
               << endl
               << indent() << "  if self._iprot:" << endl
               << indent() << "    self._iprot.trans.close()" << endl
               << indent()
               << "  if self._oprot and self._iprot is not self._oprot:" << endl
               << indent() << "    self._oprot.trans.close()" << endl
               << endl;
  }

  // Constructor function
  if (gen_asyncio_) {
    f_service_
        << indent()
        << "def __init__(self, oprot=None, loop=None, cpp_transport=None):"
        << endl;
  } else {
    f_service_
        << indent()
        << "def __init__(self, iprot=None, oprot=None, cpp_transport=None):"
        << endl;
  }
  if (extends.empty()) {
    if (gen_asyncio_) {
      f_service_ << indent() << "  self._oprot = oprot" << endl
                 << indent()
                 << "  self._loop = loop or asyncio.get_event_loop()" << endl
                 << indent() << "  self._seqid = 0" << endl
                 << indent() << "  self._futures = {}" << endl
                 << indent() << "  self._fbthrift_cpp_transport = None" << endl
                 << endl;
    } else {
      f_service_ << indent() << "  self._iprot = self._oprot = iprot" << endl
                 << indent() << "  if oprot != None:" << endl
                 << indent() << "    self._oprot = oprot" << endl
                 << indent() << "  self._seqid = 0" << endl
                 << indent() << "  self._fbthrift_cpp_transport = cpp_transport"
                 << endl
                 << endl;
    }
  } else {
    if (gen_asyncio_) {
      f_service_ << indent() << "  " << extends
                 << ".Client.__init__(self, oprot, loop)" << endl
                 << endl;
    } else {
      f_service_ << indent() << "  " << extends
                 << ".Client.__init__(self, iprot, oprot, cpp_transport)"
                 << endl
                 << endl;
    }
  }

  // Helpers
  f_service_
      << indent() << "def set_persistent_header(self, key, value):" << endl
      << indent() << "  if self._fbthrift_cpp_transport:" << endl
      << indent()
      << "    self._fbthrift_cpp_transport.set_persistent_header(key, value)"
      << endl
      << indent() << "  else:" << endl
      << indent() << "    try:" << endl
      << indent() << "      self._oprot.trans.set_persistent_header(key, value)"
      << endl
      << indent() << "    except AttributeError:" << endl
      << indent() << "      pass" << endl
      << endl;

  f_service_
      << indent() << "def get_persistent_headers(self):" << endl
      << indent() << "  if self._fbthrift_cpp_transport:" << endl
      << indent()
      << "    return self._fbthrift_cpp_transport.get_persistent_headers()"
      << endl
      << indent() << "  try:" << endl
      << indent()
      << "    return self._oprot.trans.get_write_persistent_headers()" << endl
      << indent() << "  except AttributeError:" << endl
      << indent() << "    return {}" << endl
      << endl;

  f_service_ << indent() << "def clear_persistent_headers(self):" << endl
             << indent() << "  if self._fbthrift_cpp_transport:" << endl
             << indent()
             << "    self._fbthrift_cpp_transport.clear_persistent_headers()"
             << endl
             << indent() << "  else:" << endl
             << indent() << "    try:" << endl
             << indent() << "      self._oprot.trans.clear_persistent_headers()"
             << endl
             << indent() << "    except AttributeError:" << endl
             << indent() << "      pass" << endl
             << endl;

  f_service_
      << indent() << "def set_onetime_header(self, key, value):" << endl
      << indent() << "  if self._fbthrift_cpp_transport:" << endl
      << indent()
      << "    self._fbthrift_cpp_transport.set_onetime_header(key, value)"
      << endl
      << indent() << "  else:" << endl
      << indent() << "    try:" << endl
      << indent() << "      self._oprot.trans.set_header(key, value)" << endl
      << indent() << "    except AttributeError:" << endl
      << indent() << "      pass" << endl
      << endl;

  f_service_ << indent() << "def set_max_frame_size(self, size):" << endl
             << indent() << "  if self._fbthrift_cpp_transport:" << endl
             << indent() << "    pass" << endl
             << indent() << "  else:" << endl
             << indent() << "    try:" << endl
             << indent() << "      self._oprot.trans.set_max_frame_size(size)"
             << endl
             << indent() << "    except AttributeError:" << endl
             << indent() << "      pass" << endl
             << endl;

  // Generate client method implementations
  const auto& functions = get_functions(tservice);
  vector<t_function*>::const_iterator f_iter;
  for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
    const t_struct* arg_struct = (*f_iter)->get_paramlist();
    const vector<t_field*>& fields = arg_struct->get_members();
    vector<t_field*>::const_iterator fld_iter;
    string funname = rename_reserved_keywords((*f_iter)->get_name());
    string argsname = (*f_iter)->get_name() + "_args";

    // Open function
    indent(f_service_) << "def " << function_signature(*f_iter) << ":" << endl;
    indent_up();
    generate_python_docstring(f_service_, (*f_iter));

    // CPP transport
    if (!gen_asyncio_) {
      indent(f_service_) << "if (self._fbthrift_cpp_transport):" << endl;
      indent_up();
      f_service_ << indent() << "args = " << argsname << "()" << endl;
      for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
        f_service_ << indent() << "args."
                   << rename_reserved_keywords((*fld_iter)->get_name()) << " = "
                   << rename_reserved_keywords((*fld_iter)->get_name()) << endl;
      }
      f_service_ << indent()
                 << "return self._fbthrift_cpp_transport._send_request(\""
                 << tservice->get_name() << "\", \"" << (*f_iter)->get_name()
                 << "\", args, " << (*f_iter)->get_name() << "_result).success"
                 << endl;
      indent_down();
    }

    if (gen_asyncio_) {
      indent(f_service_) << "self._seqid += 1" << endl;
      indent(f_service_)
          << "fut = self._futures[self._seqid] = asyncio.Future(loop=self._loop)"
          << endl;
    }

    indent(f_service_) << "self.send_" << funname << "(";

    bool first = true;
    for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
      if (first) {
        first = false;
      } else {
        f_service_ << ", ";
      }
      f_service_ << rename_reserved_keywords((*fld_iter)->get_name());
    }
    f_service_ << ")" << endl;

    if ((*f_iter)->qualifier() != t_function_qualifier::one_way) {
      f_service_ << indent();
      if (gen_asyncio_) {
        f_service_ << "return fut" << endl;
      } else {
        if (!(*f_iter)->get_returntype()->is_void()) {
          f_service_ << "return ";
        }
        f_service_ << "self.recv_" << funname << "()" << endl;
      }
    } else {
      if (gen_asyncio_) {
        f_service_ << indent() << "fut.set_result(None)" << endl
                   << indent() << "return fut" << endl;
      }
    }
    indent_down();
    f_service_ << endl;

    indent(f_service_) << "def send_" << function_signature(*f_iter) << ":"
                       << endl;

    indent_up();

    // Serialize the request header
    f_service_ << indent() << "self._oprot.writeMessageBegin('"
               << (*f_iter)->get_name() << "', TMessageType.CALL, self._seqid)"
               << endl;

    f_service_ << indent() << "args = " << argsname << "()" << endl;

    for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
      f_service_ << indent() << "args."
                 << rename_reserved_keywords((*fld_iter)->get_name()) << " = "
                 << rename_reserved_keywords((*fld_iter)->get_name()) << endl;
    }

    std::string flush = (*f_iter)->qualifier() == t_function_qualifier::one_way
        ? "onewayFlush"
        : "flush";
    // Write to the stream
    f_service_ << indent() << "args.write(self._oprot)" << endl
               << indent() << "self._oprot.writeMessageEnd()" << endl
               << indent() << "self._oprot.trans." << flush << "()" << endl;

    indent_down();

    if ((*f_iter)->qualifier() != t_function_qualifier::one_way) {
      std::string resultname = (*f_iter)->get_name() + "_result";
      // Open function
      f_service_ << endl;
      if (gen_asyncio_) {
        f_service_ << indent() << "def recv_" << (*f_iter)->get_name()
                   << "(self, iprot, mtype, rseqid):" << endl;
      } else {
        t_function recv_function(
            (*f_iter)->get_returntype(),
            string("recv_") + (*f_iter)->get_name(),
            std::make_unique<t_paramlist>(program_));
        f_service_ << indent() << "def " << function_signature(&recv_function)
                   << ":" << endl;
      }
      indent_up();

      // TODO(mcslee): Validate message reply here, seq ids etc.

      if (gen_asyncio_) {
        f_service_ << indent() << "try:" << endl;
        f_service_ << indent() << indent() << "fut = self._futures.pop(rseqid)"
                   << endl;
        f_service_ << indent() << "except KeyError:" << endl;
        f_service_ << indent() << indent() << "return   # request timed out"
                   << endl;
      } else {
        f_service_ << indent() << "(fname, mtype, rseqid) = "
                   << "self._iprot.readMessageBegin()" << endl;
      }

      f_service_ << indent() << "if mtype == TMessageType.EXCEPTION:" << endl
                 << indent() << "  x = TApplicationException()" << endl;

      if (gen_asyncio_) {
        f_service_ << indent() << "  x.read(iprot)" << endl
                   << indent() << "  iprot.readMessageEnd()" << endl
                   << indent() << "  fut.set_exception(x)" << endl
                   << indent() << "  return" << endl
                   << indent() << "result = " << resultname << "()" << endl
                   << indent() << "try:" << endl
                   << indent() << "  result.read(iprot)" << endl
                   << indent() << "except Exception as e:" << endl
                   << indent() << "  fut.set_exception(e)" << endl
                   << indent() << "  return" << endl
                   << indent() << "iprot.readMessageEnd()" << endl;
      } else {
        f_service_ << indent() << "  x.read(self._iprot)" << endl
                   << indent() << "  self._iprot.readMessageEnd()" << endl
                   << indent() << "  raise x" << endl
                   << indent() << "result = " << resultname << "()" << endl
                   << indent() << "result.read(self._iprot)" << endl
                   << indent() << "self._iprot.readMessageEnd()" << endl;
      }

      // Careful, only return _result if not a void function
      if (!(*f_iter)->get_returntype()->is_void()) {
        f_service_ << indent() << "if result.success != None:" << endl;
        if (gen_asyncio_) {
          f_service_ << indent() << "  fut.set_result(result.success)" << endl
                     << indent() << "  return" << endl;
        } else {
          f_service_ << indent() << "  return result.success" << endl;
        }
      }

      const t_struct* xs = (*f_iter)->get_xceptions();
      const std::vector<t_field*>& xceptions = xs->get_members();
      vector<t_field*>::const_iterator x_iter;
      for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) {
        f_service_ << indent() << "if result."
                   << rename_reserved_keywords((*x_iter)->get_name())
                   << " != None:" << endl;
        if (gen_asyncio_) {
          f_service_ << indent() << "  fut.set_exception(result."
                     << rename_reserved_keywords((*x_iter)->get_name()) << ")"
                     << endl
                     << indent() << "  return" << endl;
        } else {
          f_service_ << indent() << "  raise result."
                     << rename_reserved_keywords((*x_iter)->get_name()) << endl;
        }
      }

      // Careful, only return _result if not a void function
      if ((*f_iter)->get_returntype()->is_void()) {
        if (gen_asyncio_) {
          f_service_ << indent() << "fut.set_result(None)" << endl
                     << indent() << "return" << endl;
        } else {
          indent(f_service_) << "return" << endl;
        }
      } else {
        if (gen_asyncio_) {
          f_service_ << indent() << "fut.set_exception(TApplicationException("
                     << "TApplicationException.MISSING_RESULT, \""
                     << (*f_iter)->get_name() << " failed: unknown result\"))"
                     << endl
                     << indent() << "return" << endl;
        } else {
          f_service_ << indent() << "raise TApplicationException("
                     << "TApplicationException.MISSING_RESULT, \""
                     << (*f_iter)->get_name() << " failed: unknown result\");"
                     << endl;
        }
      }

      // Close function
      indent_down();
      f_service_ << endl;
    }
  }

  indent_down();
  f_service_ << endl;
}