void Compute()

in tensorflow_quantum/core/ops/tfq_ps_decompose_op.cc [43:159]


  // Rewrites each input program so that parameterized PISP, ISP, PXP and FSIM
  // operations are replaced by sequences of simpler operations built by the
  // getOpFor* helpers. Operations whose relevant arguments are not symbols,
  // and all other gates, are copied through unchanged.
  //
  // Input 0 ("programs") is parsed into Program protos by ParsePrograms.
  // Output 0 is a string tensor with the same shape as input 0; element i
  // holds the serialized rewritten program i.
  //
  // The kernel fails with InvalidArgument when the number of inputs is not 1,
  // and propagates any error from ParsePrograms or output allocation.
  void Compute(tensorflow::OpKernelContext *context) override {
    std::vector<Program> programs;

    const int num_inputs = context->num_inputs();
    OP_REQUIRES(context, num_inputs == 1,
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "Expected 1 inputs, got ", num_inputs, " inputs.")));

    OP_REQUIRES_OK(context, ParsePrograms(context, "programs", &programs));

    tensorflow::Tensor *output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, context->input(0).shape(), &output));
    auto output_tensor = output->flat<tensorflow::tstring>();

    // Largest number of follow-up moments any single decomposition below
    // produces (the PISP decomposition uses all five).
    const int max_buffer_moments = 5;

    // Processes programs [start, end). Each index writes only its own
    // output_tensor(i) slot and only reads from `programs`.
    auto DoWork = [&](int start, int end) {
      for (int i = start; i < end; i++) {
        // The input program is only read, so bind it by reference instead of
        // deep-copying the whole proto.
        const Program &cur_program = programs.at(i);
        Program new_program;
        new_program.mutable_language()->set_gate_set("tfq_gate_set");
        new_program.mutable_circuit()->set_scheduling_strategy(
            Circuit::MOMENT_BY_MOMENT);
        for (int j = 0; j < cur_program.circuit().moments().size(); j++) {
          // Working copy: decomposed operations are replaced in place by the
          // first operation of their decomposition.
          Moment cur_moment(cur_program.circuit().moments().at(j));
          // Buffer for the remaining operations of each decomposition;
          // entry n is emitted as the (n + 1)-th moment after cur_moment.
          std::vector<Moment> temp_moment_list(max_buffer_moments, Moment());
          int num_extra_moments = 0;
          // Raises the number of buffered moments to emit and never lowers
          // it. A moment can hold several decomposed operations on disjoint
          // qubits; a shorter decomposition (e.g. PXP, 2 extra moments) seen
          // after a longer one (PISP, 5 extra moments) must not truncate the
          // longer one's buffered operations.
          auto require_extra_moments = [&num_extra_moments](int needed) {
            if (num_extra_moments < needed) num_extra_moments = needed;
          };
          for (int k = 0; k < cur_moment.operations().size(); k++) {
            // Copy on purpose: slot k of cur_moment is overwritten below
            // while cur_op is still handed to the getOpFor* helpers.
            Operation cur_op = cur_moment.operations().at(k);
            auto &cur_op_map = *cur_op.mutable_args();
            // NOTE(review): the `at` lookups below assume the named argument
            // is present for the gate; confirm this is validated upstream.
            if (cur_op.gate().id() == "PISP") {
              auto exponent = cur_op_map.at("exponent");
              auto phase_exponent = cur_op_map.at("phase_exponent");
              if (exponent.arg_case() == Arg::ArgCase::kSymbol ||
                  phase_exponent.arg_case() == Arg::ArgCase::kSymbol) {
                // Decompose cirq.PhasedISwapPowGate only if it is
                // parameterized.
                // NOTE(review): when only phase_exponent is a symbol,
                // exponent.symbol() is presumably empty here; confirm that
                // getOpForISP handles that case as intended.
                require_extra_moments(5);
                Operation new_op;

                new_op = getOpForPISP(cur_op, 0, 0);
                cur_moment.mutable_operations()->at(k) = new_op;
                new_op = getOpForPISP(cur_op, 1, 1);
                *temp_moment_list[0].add_operations() = new_op;
                new_op = getOpForISP(cur_op, "XXP", exponent.symbol());
                *temp_moment_list[1].add_operations() = new_op;
                new_op = getOpForISP(cur_op, "YYP", exponent.symbol());
                *temp_moment_list[2].add_operations() = new_op;
                new_op = getOpForPISP(cur_op, 1, 0);
                *temp_moment_list[3].add_operations() = new_op;
                new_op = getOpForPISP(cur_op, 0, 1);
                *temp_moment_list[4].add_operations() = new_op;
              }
            } else if (cur_op.gate().id() == "ISP") {
              auto exponent = cur_op_map.at("exponent");
              if (exponent.arg_case() == Arg::ArgCase::kSymbol) {
                // Decompose cirq.ISwapPowGate only if it is parameterized.
                require_extra_moments(1);
                Operation new_op;
                new_op = getOpForISP(cur_op, "XXP", exponent.symbol());
                cur_moment.mutable_operations()->at(k) = new_op;
                new_op = getOpForISP(cur_op, "YYP", exponent.symbol());
                *temp_moment_list[0].add_operations() = new_op;
              }
            } else if (cur_op.gate().id() == "PXP") {
              auto exponent = cur_op_map.at("exponent");
              auto phase_exponent = cur_op_map.at("phase_exponent");
              if (exponent.arg_case() == Arg::ArgCase::kSymbol ||
                  phase_exponent.arg_case() == Arg::ArgCase::kSymbol) {
                // Decompose cirq.PhasedXPowGate only if it is parameterized.
                require_extra_moments(2);
                Operation new_op;
                new_op = getOpForPXP(cur_op, "ZP", "phase_exponent", true);
                cur_moment.mutable_operations()->at(k) = new_op;
                new_op = getOpForPXP(cur_op, "XP", "exponent", false);
                *temp_moment_list[0].add_operations() = new_op;
                new_op = getOpForPXP(cur_op, "ZP", "phase_exponent", false);
                *temp_moment_list[1].add_operations() = new_op;
              }
            } else if (cur_op.gate().id() == "FSIM") {
              auto theta = cur_op_map.at("theta");
              auto phi = cur_op_map.at("phi");
              if (theta.arg_case() == Arg::ArgCase::kSymbol ||
                  phi.arg_case() == Arg::ArgCase::kSymbol) {
                // Decompose cirq.FSimGate only if it is parameterized.
                require_extra_moments(2);
                Operation new_op;
                new_op = getOpForFSIM(cur_op, "XXP", "theta", true);
                cur_moment.mutable_operations()->at(k) = new_op;
                new_op = getOpForFSIM(cur_op, "YYP", "theta", true);
                *temp_moment_list[0].add_operations() = new_op;
                new_op = getOpForFSIM(cur_op, "CZP", "phi", false);
                *temp_moment_list[1].add_operations() = new_op;
              }
            }
          }
          // Emit the (possibly rewritten) moment, then the buffered moments
          // that complete the decompositions started in it.
          *new_program.mutable_circuit()->add_moments() = cur_moment;
          for (int l = 0; l < num_extra_moments; l++) {
            *new_program.mutable_circuit()->add_moments() =
                temp_moment_list[l];
          }
        }
        std::string temp;
        new_program.SerializeToString(&temp);
        output_tensor(i) = temp;
      }
    };

    // Split the programs across the CPU worker threads in blocks of
    // block_size.
    const int block_size = GetBlockSize(context, programs.size());
    context->device()
        ->tensorflow_cpu_worker_threads()
        ->workers->TransformRangeConcurrently(block_size, programs.size(),
                                              DoWork);
  }