Status ResolveQubitIds()

in tensorflow_quantum/core/src/program_resolution.cc [182:303]


// Looks up the comma-separated "control_qubits" string arg of `operation` and
// points `*control_qubits` at it (the pointer aliases the proto's storage).
// Returns INVALID_ARGUMENT when the arg is absent, rather than letting
// protobuf's Map::at treat the missing key as a fatal error.
static Status FindControlQubitsArg(const Operation& operation,
                                   const std::string** control_qubits) {
  const auto it = operation.args().find("control_qubits");
  if (it == operation.args().end()) {
    return Status(tensorflow::error::INVALID_ARGUMENT,
                  "Operation is missing the control_qubits arg.");
  }
  *control_qubits = &it->second.arg_value().string_value();
  return Status::OK();
}

// Rewrites every qubit id of `operation` (its `qubits` field and the ids in
// its "control_qubits" arg) to the index string stored in `id_to_index`.
// When `unseen_ids` is non-null, each id encountered is erased from it so the
// caller can detect ids that were never used. Returns INVALID_ARGUMENT with
// `not_found_message` on the first id missing from `id_to_index`; the
// operation may be partially rewritten in that case.
static Status RemapOperationQubitIds(
    const absl::flat_hash_map<std::string, std::string>& id_to_index,
    absl::string_view not_found_message,
    absl::flat_hash_set<std::string>* unseen_ids, Operation* operation) {
  for (Qubit& qubit : *operation->mutable_qubits()) {
    if (unseen_ids != nullptr) {
      unseen_ids->erase(qubit.id());
    }
    const auto result = id_to_index.find(qubit.id());
    if (result == id_to_index.end()) {
      return Status(tensorflow::error::INVALID_ARGUMENT, not_found_message);
    }
    qubit.set_id(result->second);
  }

  const std::string* control_qubits = nullptr;
  Status s = FindControlQubitsArg(*operation, &control_qubits);
  if (!s.ok()) {
    return s;
  }
  // explicit empty value set in serializer.py.
  if (control_qubits->empty()) {
    return Status::OK();
  }

  // All replacement indices are copied out before the arg is overwritten,
  // because the split pieces are views into the string being replaced.
  std::vector<std::string> control_indices;
  for (absl::string_view id : absl::StrSplit(*control_qubits, ',')) {
    if (unseen_ids != nullptr) {
      unseen_ids->erase(id);
    }
    const auto result = id_to_index.find(id);
    if (result == id_to_index.end()) {
      return Status(tensorflow::error::INVALID_ARGUMENT, not_found_message);
    }
    control_indices.push_back(result->second);
  }
  operation->mutable_args()
      ->at("control_qubits")
      .mutable_arg_value()
      ->set_string_value(absl::StrJoin(control_indices, ","));
  return Status::OK();
}

// Replaces the qubit ids used in `program` (operation qubits and the
// comma-separated "control_qubits" arg) with their index strings "0", "1", ...
// in sorted order of the entries produced by RegisterQubits, and stores the
// number of distinct qubits in `*num_qubits`.
//
// Every program in `other_programs` (may be null, meaning none) is rewritten
// with the same id -> index mapping. Each paired program must use exactly the
// reference program's qubits: INVALID_ARGUMENT is returned if it uses a qubit
// the reference does not, or omits one the reference uses.
//
// (#679) An empty reference program is valid and yields `*num_qubits` == 0.
// It goes through the general path below, so paired programs are still
// checked against it instead of being returned with unresolved ids.
Status ResolveQubitIds(Program* program, unsigned int* num_qubits,
                       std::vector<Program>* other_programs) {
  // Collect every qubit mentioned by the reference program.
  absl::flat_hash_set<std::pair<std::pair<int, int>, std::string>> id_set;
  for (const Moment& moment : program->circuit().moments()) {
    for (const Operation& operation : moment.operations()) {
      Status s;
      for (const Qubit& qubit : operation.qubits()) {
        s = RegisterQubits(qubit.id(), &id_set);
        if (!s.ok()) {
          return s;
        }
      }
      const std::string* control_qubits = nullptr;
      s = FindControlQubitsArg(operation, &control_qubits);
      if (!s.ok()) {
        return s;
      }
      s = RegisterQubits(*control_qubits, &id_set);
      if (!s.ok()) {
        return s;
      }
    }
  }
  *num_qubits = static_cast<unsigned int>(id_set.size());

  // std::pair compares lexicographically: the integer pair from
  // RegisterQubits (r1 < r2) || ((r1 == r2) && c1 < c2), then the id string.
  std::vector<std::pair<std::pair<int, int>, std::string>> ids(id_set.begin(),
                                                               id_set.end());
  std::sort(ids.begin(), ids.end());

  absl::flat_hash_map<std::string, std::string> id_to_index;
  absl::flat_hash_set<std::string> id_ref;
  id_to_index.reserve(ids.size());
  for (size_t i = 0; i < ids.size(); ++i) {
    id_to_index[ids[i].second] = absl::StrCat(i);
    id_ref.insert(ids[i].second);
  }

  // Replace the reference Program qubit ids with the indices.
  for (Moment& moment : *program->mutable_circuit()->mutable_moments()) {
    for (Operation& operation : *moment.mutable_operations()) {
      Status s = RemapOperationQubitIds(
          id_to_index, "Unable to resolve a qubit id in the circuit.",
          /*unseen_ids=*/nullptr, &operation);
      if (!s.ok()) {
        return s;
      }
    }
  }

  if (other_programs == nullptr) {
    return Status::OK();
  }

  // Replace each paired Program's qubit ids with the same indices, checking
  // that it uses exactly the reference program's qubits.
  for (Program& other_program : *other_programs) {
    absl::flat_hash_set<std::string> unseen_ids(id_ref);
    for (Moment& moment :
         *other_program.mutable_circuit()->mutable_moments()) {
      for (Operation& operation : *moment.mutable_operations()) {
        Status s = RemapOperationQubitIds(
            id_to_index,
            "A paired circuit contains qubits not found in "
            "reference circuit.",
            &unseen_ids, &operation);
        if (!s.ok()) {
          return s;
        }
      }
    }
    if (!unseen_ids.empty()) {
      return Status(
          tensorflow::error::INVALID_ARGUMENT,
          "A reference circuit contains qubits not found in paired circuit.");
    }
  }

  return Status::OK();
}