void t_py_generator::generate_py_union()

in thrift/compiler/generate/t_py_generator.cc [1243:1447]


/**
 * Emits the Python class for the Thrift union `tstruct` to `out`.
 *
 * The generated class keeps its state in two instance attributes: `field`
 * (the field id of the member that is currently set, 0 when the union is
 * empty) and `value` (that member's value). The generated class body
 * contains, in order:
 *   - `thrift_spec = None`, plus `__init__ = None` when the union has members
 *     (NOTE(review): presumably assigned after the class body elsewhere in
 *     the generated module -- confirm);
 *   - `__EMPTY__ = 0` and one UPPERCASE constant per member holding its id;
 *   - `isUnion()`, `get_<name>()` / `set_<name>(value)` and `getType()`;
 *   - `__repr__`, `read(iprot)` and `write(oprot)`;
 *   - a JSON reader when `gen_json_` is set;
 *   - `__eq__` / `__ne__` (by field+value when `compare_t_fields_only_`,
 *     otherwise by `__dict__`).
 *
 * @param out      stream the generated Python source is written to.
 * @param tstruct  the union definition; dereferenced unconditionally, so it
 *                 must not be null.
 */
void t_py_generator::generate_py_union(ofstream& out, const t_struct* tstruct) {
  const vector<t_field*>& members = tstruct->get_members();
  const vector<t_field*>& sorted_members = tstruct->get_sorted_members();

  out << "class " << rename_reserved_keywords(tstruct->get_name())
      << "(object):" << endl;

  indent_up();
  generate_python_docstring(out, tstruct);

  out << endl;

  indent(out) << "thrift_spec = None" << endl;
  if (!members.empty()) {
    indent(out) << "__init__ = None" << endl << endl;
  }

  // Generate some class level identifiers (similar to enum): one constant per
  // member whose value is that member's field id, plus __EMPTY__ for "unset".
  indent(out) << "__EMPTY__ = 0" << endl;
  for (auto& member : sorted_members) {
    indent(out) << uppercase(member->get_name()) << " = " << member->get_key()
                << endl;
  }
  // Plain newline: indent(out) here would leave trailing whitespace on the
  // blank line of the generated Python.
  out << endl;

  // Generate `isUnion` method
  indent(out) << "@staticmethod" << endl;
  indent(out) << "def isUnion():" << endl;
  indent(out) << "  return True" << endl << endl;

  // Generate `get_` methods; each asserts that the requested member is the
  // one currently stored.
  for (auto& member : members) {
    indent(out) << "def get_" << member->get_name() << "(self):" << endl;
    indent(out) << "  assert self.field == " << member->get_key() << endl;
    indent(out) << "  return self.value" << endl << endl;
  }

  // Generate `set_` methods; setting any member replaces whatever was stored.
  for (auto& member : members) {
    indent(out) << "def set_" << member->get_name() << "(self, value):" << endl;
    indent(out) << "  self.field = " << member->get_key() << endl;
    indent(out) << "  self.value = value" << endl << endl;
  }

  // Method to get the stored type (the field id, 0 when empty)
  indent(out) << "def getType(self):" << endl;
  indent(out) << "  return self.field" << endl << endl;

  // According to Python doc, __repr__() "should" return a valid expression
  // such that `object == eval(repr(object))` is true.
  out << indent() << "def __repr__(self):" << endl
      << indent() << "  value = pprint.pformat(self.value)" << endl
      << indent() << "  member = ''" << endl;
  for (auto& member : sorted_members) {
    const auto key = rename_reserved_keywords(member->get_name());
    // NOTE(review): continuation lines of a multi-line pformat() result are
    // prefixed with len(key) + 1 spaces, while the first line of the value
    // starts after "    key=" (column len(key) + 5), so they do not line up
    // under the value. Left unchanged to keep generated output stable --
    // confirm intended layout.
    out << indent() << "  if self.field == " << member->get_key() << ":" << endl
        << indent() << "    padding = ' ' * " << key.size() + 1 << endl
        << indent() << "    value = padding.join(value.splitlines(True))"
        << endl
        << indent() << "    member = '\\n    %s=%s' % ('" << key << "', value)"
        << endl;
  }
  // This will generate
  //   UnionClass()  or
  //   UnionClass(
  //       key=value)
  out << indent() << "  return \"%s(%s)\" % (self.__class__.__name__, member)"
      << endl
      << endl;

  // Generate `read` method
  indent(out) << "def read(self, iprot):" << endl;
  indent_up();

  indent(out) << "self.field = 0" << endl;
  indent(out) << "self.value = None" << endl;

  generate_fastproto_read(out, tstruct);

  indent(out) << "iprot.readStructBegin()" << endl;
  indent(out) << "while True:" << endl;
  indent_up();
  indent(out) << "(fname, ftype, fid) = iprot.readFieldBegin()" << endl;
  indent(out) << "if ftype == TType.STOP:" << endl;
  indent_up();
  indent(out) << "break" << endl << endl;
  indent_down();

  // One `if`/`elif` branch per member, keyed on field id. A field whose wire
  // type does not match the declared type is skipped.
  bool first = true;
  for (auto& member : sorted_members) {
    const auto t = type_to_enum(member->get_type());
    const auto& n = member->get_name();
    const auto k = member->get_key();
    indent(out) << (first ? "" : "el") << "if fid == " << k << ":" << endl;
    indent_up();
    indent(out) << "if ftype == " << t << ":" << endl;
    indent_up();
    generate_deserialize_field(out, member, "");
    // A union carries at most one field; a second one on the wire trips this.
    indent(out) << "assert self.field == 0 and self.value is None" << endl;
    indent(out) << "self.set_" << n << "(" << rename_reserved_keywords(n) << ")"
                << endl;
    indent_down();
    indent(out) << "else:" << endl;
    indent(out) << "  iprot.skip(ftype)" << endl;
    indent_down();

    first = false;
  }

  // Unknown field ids are skipped. When the union has no members this `else`
  // attaches to the `if ftype == TType.STOP` above, which is still valid
  // Python with the same meaning.
  indent(out) << "else:" << endl;
  indent(out) << "  iprot.skip(ftype)" << endl;
  indent(out) << "iprot.readFieldEnd()" << endl;
  indent_down();

  indent(out) << "iprot.readStructEnd()" << endl << endl;
  indent_down();

  // Generate `write` method; only the member selected by `self.field` is
  // serialized (nothing but the stop marker when the union is empty).
  indent(out) << "def write(self, oprot):" << endl;
  indent_up();

  generate_fastproto_write(out, tstruct);

  indent(out) << "oprot.writeUnionBegin('" << tstruct->get_name() << "')"
              << endl;

  first = true;
  for (auto& member : sorted_members) {
    const auto t = type_to_enum(member->get_type());
    const auto& n = member->get_name();
    const auto k = member->get_key();

    indent(out) << (first ? "" : "el") << "if self.field == " << k << ":"
                << endl;
    indent_up();
    indent(out) << "oprot.writeFieldBegin('" << n << "', " << t << ", " << k
                << ")" << endl;

    // generate_serialize_field reads from a local named after the member.
    indent(out) << rename_reserved_keywords(n) << " = self.value" << endl;
    generate_serialize_field(out, member, "");
    indent(out) << "oprot.writeFieldEnd()" << endl;
    indent_down();
  }
  indent(out) << "oprot.writeFieldStop()" << endl;
  indent(out) << "oprot.writeUnionEnd()" << endl;
  indent_down();
  // Plain newline: avoid a whitespace-only line in the generated Python.
  out << endl;

  // Generate json reader
  if (gen_json_) {
    // NOTE(review): the signature helper is expected to leave one indent
    // level open; the indent_down() after the member loop closes it.
    generate_json_reader_fn_signature(out);
    indent(out) << "self.field = 0" << endl;
    indent(out) << "self.value = None" << endl;
    indent(out) << "obj = json" << endl;
    indent(out) << "if is_text:" << endl;
    indent_up();
    indent(out) << "obj = loads(json)" << endl;
    indent_down();

    // A union's JSON form is an object with at most one key.
    indent(out) << "if not isinstance(obj, dict) or len(obj) > 1:" << endl;
    indent(out) << "  raise TProtocolException("
                << "TProtocolException.INVALID_DATA, 'Can not parse')" << endl;
    // Plain newline: avoid a whitespace-only line in the generated Python.
    out << endl;

    for (auto& member : members) {
      const auto& n = member->get_name();
      indent(out) << "if '" << n << "' in obj:" << endl;
      indent_up();
      generate_json_field(out, member, "", "", "obj['" + n + "']");
      indent(out) << "self.set_" << n << "(" << rename_reserved_keywords(n)
                  << ")" << endl;
      indent_down();
    }
    indent_down();
    out << endl;
  }

  // Equality and inequality methods that compare by value
  out << indent() << "def __eq__(self, other):" << endl;
  indent_up();
  out << indent() << "if not isinstance(other, self.__class__):" << endl;
  indent_up();
  out << indent() << "return False" << endl;
  indent_down();
  out << endl;
  if (compare_t_fields_only_) {
    out << indent() << "return "
        << "self.field == other.field and "
        << "self.value == other.value" << endl;
  } else {
    out << indent() << "return "
        << "self.__dict__ == other.__dict__" << endl;
  }

  indent_down();
  out << endl;

  out << indent() << "def __ne__(self, other):" << endl;
  indent_up();
  out << indent() << "return not (self == other)" << endl;
  indent_down();
  out << endl;

  // Close the class body opened after the `class` line.
  indent_down();
}