bool AssertSameStructureHelper()

in tensorflow/tensorflow/python/util/util.cc [655:838]


// Recursively checks that `o1` and `o2` have the same nested structure.
//
// Args:
//   o1, o2: the two (possibly nested) Python objects being compared.
//   check_types: if true, corresponding sequences must also have matching
//     types (with the list / mapping / namedtuple / CompositeTensor
//     exemptions below), and dicts / mappings must have the same keys.
//   error_msg: out-param; set to a description of the first structural
//     mismatch found. It is only written on mismatch paths.
//   is_type_error: out-param; set together with `*error_msg` to say whether
//     the mismatch is a type mismatch (true) or not (false).
//   is_sequence_helper: decides whether an object is a nested structure to
//     recurse into. Python errors raised by it are detected through
//     PyErr_Occurred().
//   value_iterator_getter: returns an iterator over the direct children of a
//     nested structure.
//   check_composite_tensor_type_spec: if true, CompositeTensors (and TypeSpecs)
//     are compared via `most_specific_compatible_type` on their type specs.
//
// Returns:
//   false on an internal error (the Python error indicator is left set for the
//   caller to propagate); true otherwise. A return of true with a non-empty
//   `*error_msg` means the structures differ.
bool AssertSameStructureHelper(
    PyObject* o1, PyObject* o2, bool check_types, string* error_msg,
    bool* is_type_error,
    const std::function<int(PyObject*)>& is_sequence_helper,
    const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter,
    bool check_composite_tensor_type_spec) {
  DCHECK(error_msg);
  DCHECK(is_type_error);
  // Check for a pending exception after each call so the helper is never
  // invoked while the Python error indicator is already set.
  const bool is_seq1 = is_sequence_helper(o1);
  if (PyErr_Occurred()) return false;
  const bool is_seq2 = is_sequence_helper(o2);
  if (PyErr_Occurred()) return false;
  if (is_seq1 != is_seq2) {
    string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
    string non_seq_str = is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1);
    *is_type_error = false;
    *error_msg = tensorflow::strings::StrCat(
        "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
        non_seq_str, "\" is not");
    return true;
  }

  // Got to objects that are considered non-sequences. Note that in tf.data
  // use case lists and sparse_tensors are not considered sequences. So finished
  // checking, structures are the same.
  if (!is_seq1) return true;

  if (check_types) {
    const PyTypeObject* type1 = o1->ob_type;
    const PyTypeObject* type2 = o2->ob_type;

    // Shared by both type-mismatch branches below.
    // NOTE(review): the message says "namedtuples" even when the operands are
    // not namedtuples; the text is kept unchanged in case callers or tests
    // match on it.
    const auto set_type_mismatch_error = [&]() {
      *is_type_error = true;
      *error_msg = tensorflow::strings::StrCat(
          "The two namedtuples don't have the same sequence type. "
          "First structure ",
          PyObjectToString(o1), " has type ", type1->tp_name,
          ", while second structure ", PyObjectToString(o2), " has type ",
          type2->tp_name);
    };

    // We treat two different namedtuples with identical name and fields
    // as having the same type. The returned references are released by RAII.
    const Safe_PyObjectPtr o1_tuple(IsNamedtuple(o1, true));
    if (!o1_tuple) return false;
    const Safe_PyObjectPtr o2_tuple(IsNamedtuple(o2, true));
    if (!o2_tuple) return false;
    const bool both_tuples =
        o1_tuple.get() == Py_True && o2_tuple.get() == Py_True;

    if (both_tuples) {
      const Safe_PyObjectPtr same_tuples(SameNamedtuples(o1, o2));
      if (!same_tuples) return false;
      if (same_tuples.get() != Py_True) {
        set_type_mismatch_error();
        return true;
      }
    } else if (type1 != type2
               /* If both sequences are list types, don't complain. This allows
                  one to be a list subclass (e.g. _ListWrapper used for
                  automatic dependency tracking.) */
               && !(PyList_Check(o1) && PyList_Check(o2))
               /* Two mapping types will also compare equal, making _DictWrapper
                  and dict compare equal. */
               && !(IsMappingHelper(o1) && IsMappingHelper(o2))
               /* For CompositeTensor & TypeSpec, we check below. */
               && !(check_composite_tensor_type_spec &&
                    (IsCompositeTensor(o1) || IsCompositeTensor(o2)) &&
                    (IsTypeSpec(o1) || IsTypeSpec(o2)))) {
      set_type_mismatch_error();
      return true;
    }

    if (PyDict_Check(o1) && PyDict_Check(o2)) {
      if (PyDict_Size(o1) != PyDict_Size(o2)) {
        SetDifferentKeysError(o1, o2, error_msg, is_type_error);
        return true;
      }

      PyObject* key;
      Py_ssize_t pos = 0;
      while (PyDict_Next(o1, &pos, &key, nullptr)) {
        if (PyDict_GetItem(o2, key) == nullptr) {
          SetDifferentKeysError(o1, o2, error_msg, is_type_error);
          return true;
        }
      }
    } else if (IsMappingHelper(o1)) {
      // Fallback for custom mapping types. Instead of using PyDict methods
      // which stay in C, we call iter(o1).
      // PyMapping_Size returns -1 with an exception set on failure.
      const Py_ssize_t size1 = PyMapping_Size(o1);
      if (size1 < 0) return false;
      const Py_ssize_t size2 = PyMapping_Size(o2);
      if (size2 < 0) return false;
      if (size1 != size2) {
        SetDifferentKeysError(o1, o2, error_msg, is_type_error);
        return true;
      }

      // PyObject_GetIter may fail; PyIter_Next must never receive nullptr.
      const Safe_PyObjectPtr iter(PyObject_GetIter(o1));
      if (!iter) return false;
      while (true) {
        const Safe_PyObjectPtr key(PyIter_Next(iter.get()));
        if (!key) break;
        if (!PyMapping_HasKey(o2, key.get())) {
          SetDifferentKeysError(o1, o2, error_msg, is_type_error);
          return true;
        }
      }
      // PyIter_Next returns nullptr both on exhaustion and on error.
      if (PyErr_Occurred()) return false;
    }
  }

  if (check_composite_tensor_type_spec &&
      (IsCompositeTensor(o1) || IsCompositeTensor(o2))) {
    Safe_PyObjectPtr owned_type_spec_1;
    PyObject* type_spec_1 = o1;
    if (IsCompositeTensor(o1)) {
      owned_type_spec_1.reset(PyObject_GetAttrString(o1, "_type_spec"));
      // Propagate the lookup error; otherwise the PyErr_Clear() below would
      // replace it with a misleading "Incompatible ... TypeSpecs" message.
      if (!owned_type_spec_1) return false;
      type_spec_1 = owned_type_spec_1.get();
    }

    Safe_PyObjectPtr owned_type_spec_2;
    PyObject* type_spec_2 = o2;
    if (IsCompositeTensor(o2)) {
      owned_type_spec_2.reset(PyObject_GetAttrString(o2, "_type_spec"));
      if (!owned_type_spec_2) return false;
      type_spec_2 = owned_type_spec_2.get();
    }

    // Two composite tensors are considered to have the same structure if
    // there is some type spec that is compatible with both of them.  Thus,
    // we use most_specific_compatible_type(), and check if it raises an
    // exception.  We do *not* use is_compatible_with, since that would
    // prevent us from e.g. using a cond statement where the two sides have
    // different shapes.
    static char compatible_type[] = "most_specific_compatible_type";
    static char argspec[] = "(O)";
    Safe_PyObjectPtr struct_compatible(PyObject_CallMethod(
        type_spec_1, compatible_type, argspec, type_spec_2));
    if (PyErr_Occurred() || struct_compatible == nullptr) {
      // Incompatibility is reported through the exception raised above, which
      // is deliberately converted into a structural mismatch message.
      PyErr_Clear();
      *is_type_error = false;
      *error_msg = tensorflow::strings::StrCat(
          "Incompatible CompositeTensor TypeSpecs: ",
          PyObjectToString(type_spec_1), " vs. ",
          PyObjectToString(type_spec_2));
      return true;
    }
  }

  ValueIteratorPtr iter1 = value_iterator_getter(o1);
  ValueIteratorPtr iter2 = value_iterator_getter(o2);

  if (!iter1->valid() || !iter2->valid()) return false;

  while (true) {
    // A null element may mean exhaustion or a Python error; stop on errors so
    // recursion never continues with a pending exception.
    Safe_PyObjectPtr v1 = iter1->next();
    if (!v1 && PyErr_Occurred()) return false;
    Safe_PyObjectPtr v2 = iter2->next();
    if (!v2 && PyErr_Occurred()) return false;
    if (v1 && v2) {
      if (Py_EnterRecursiveCall(" in assert_same_structure")) {
        return false;
      }
      bool no_internal_errors = AssertSameStructureHelper(
          v1.get(), v2.get(), check_types, error_msg, is_type_error,
          is_sequence_helper, value_iterator_getter,
          check_composite_tensor_type_spec);
      Py_LeaveRecursiveCall();
      if (!no_internal_errors) return false;
      // The first mismatch found in a child ends the whole check.
      if (!error_msg->empty()) return true;
    } else if (!v1 && !v2) {
      // Done with all recursive calls. Structure matched.
      return true;
    } else {
      *is_type_error = false;
      *error_msg = tensorflow::strings::StrCat(
          "The two structures don't have the same number of elements. ",
          "First structure: ", PyObjectToString(o1),
          ". Second structure: ", PyObjectToString(o2));
      return true;
    }
  }
}