// Compute()
// Excerpt from tensorflow_quantum/core/ops/tfq_ps_symbol_replace_op.cc [42:180]

  // Expands each input circuit into one output circuit per occurrence of each
  // requested symbol, with that single occurrence renamed.
  //
  // Inputs (exactly 3):
  //   programs:            serialized Program protos (via ParsePrograms).
  //   symbols:             rank-1 string tensor of symbol names to search for.
  //   replacement_symbols: rank-1 string tensor, same length as `symbols`,
  //                        giving the new name for each corresponding symbol.
  //
  // Output 0: string tensor of shape [n_programs, n_symbols, biggest_pad]
  //   where entry (i, j, k) is the k-th serialized copy of programs(i) in
  //   which one occurrence of symbols(j) was renamed to
  //   replacement_symbols(j). Cells with fewer than `biggest_pad` matches are
  //   padded with a serialized empty Program.
  void Compute(tensorflow::OpKernelContext *context) override {
    std::vector<Program> programs;

    // Validate the op was wired with exactly the three expected inputs.
    const int num_inputs = context->num_inputs();
    OP_REQUIRES(context, num_inputs == 3,
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "Expected 3 inputs, got ", num_inputs, " inputs.")));

    OP_REQUIRES_OK(context, ParsePrograms(context, "programs", &programs));

    // Parse the input string here.
    const Tensor *symbols_tensor;
    OP_REQUIRES_OK(context, context->input("symbols", &symbols_tensor));
    OP_REQUIRES(
        context, symbols_tensor->dims() == 1,
        tensorflow::errors::InvalidArgument(absl::StrCat(
            "symbols must be rank 1. Got rank ", symbols_tensor->dims(), ".")));

    const auto symbols = symbols_tensor->vec<tensorflow::tstring>();
    const size_t n_symbols = symbols.size();

    // Parse the replacement string here.
    const Tensor *replacement_symbols_tensor;
    OP_REQUIRES_OK(context, context->input("replacement_symbols",
                                           &replacement_symbols_tensor));
    OP_REQUIRES(context, replacement_symbols_tensor->dims() == 1,
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "replacement_symbols must be rank 1. Got rank ",
                    replacement_symbols_tensor->dims(), ".")));

    const auto replacement_symbols =
        replacement_symbols_tensor->vec<tensorflow::tstring>();

    // symbols(j) is renamed to replacement_symbols(j), so the two tensors
    // must pair up element-wise.
    OP_REQUIRES(context, symbols.size() == replacement_symbols.size(),
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "symbols.shape is not equal to replacement_symbols.shape: ",
                    symbols.size(), " != ", replacement_symbols.size())));

    // (i,j,k) = the kth replaced program for symbols(j) in programs(i).
    // Pre-sized so the parallel workers below can write into disjoint
    // (i, j) cells without synchronization.
    std::vector<std::vector<std::vector<std::string>>> output_programs(
        programs.size(), std::vector<std::vector<std::string>>(
                             n_symbols, std::vector<std::string>()));

    // Worker over the flattened (program, symbol) index space: each i maps
    // to a unique (pidx, sidx) pair, so no two shards touch the same cell.
    auto DoWork = [&](int start, int end) {
      for (int i = start; i < end; i++) {
        int sidx = i % n_symbols;
        int pidx = i / n_symbols;
        std::string symbol_to_replace = symbols(sidx);
        std::string temp_symbol_holder;
        Program cur_program = programs.at(pidx);
        // Walk every arg of every operation; each individual match yields
        // its own full copy of the program with only that one occurrence
        // renamed (the other occurrences are left untouched in that copy).
        for (int j = 0; j < cur_program.circuit().moments().size(); j++) {
          Moment cur_moment = cur_program.circuit().moments().at(j);
          for (int k = 0; k < cur_moment.operations().size(); k++) {
            Operation cur_op = cur_moment.operations().at(k);
            for (auto l = cur_op.args().begin(); l != cur_op.args().end();
                 l++) {
              const std::string key = (*l).first;
              const Arg &arg = (*l).second;
              if (arg.symbol() == symbol_to_replace) {
                // Copy the proto, modify the symbol and append to output.
                Program temp(cur_program);

                // temp_symbol_holder is needed to avoid call ambiguity for
                // set_symbol below.
                temp_symbol_holder = replacement_symbols(sidx);
                // Navigate the mutable copy to the exact (moment j, op k,
                // arg key) location found above and rename the symbol there.
                temp.mutable_circuit()
                    ->mutable_moments()
                    ->at(j)
                    .mutable_operations()
                    ->at(k)
                    .mutable_args()
                    ->at(key)
                    .set_symbol(temp_symbol_holder);

                std::string res;
                temp.SerializeToString(&res);
                output_programs.at(pidx).at(sidx).push_back(res);
                // NOTE(review): temp goes out of scope right here, so this
                // Clear() looks redundant (harmless) — confirm before removing.
                temp.Clear();
              }
            }
          }
        }
      }
    };

    const int block_size = GetBlockSize(context, programs.size() * n_symbols);
    context->device()
        ->tensorflow_cpu_worker_threads()
        ->workers->TransformRangeConcurrently(
            block_size, programs.size() * n_symbols, DoWork);

    // The output is rectangular, so pad every (program, symbol) cell up to
    // the longest replacement list with a serialized empty Program.
    size_t biggest_pad = 0;
    Program empty = Program();
    empty.mutable_language()->set_gate_set("tfq_gate_set");
    empty.mutable_circuit();  // create empty circuits entry.

    std::string empty_program;
    empty.SerializeToString(&empty_program);

    for (size_t i = 0; i < output_programs.size(); i++) {
      for (size_t j = 0; j < n_symbols; j++) {
        biggest_pad = std::max(biggest_pad, output_programs.at(i).at(j).size());
      }
    }

    tensorflow::Tensor *output = nullptr;
    tensorflow::TensorShape output_shape;
    // batch size.
    output_shape.AddDim(programs.size());
    // entry size.
    output_shape.AddDim(n_symbols);
    output_shape.AddDim(biggest_pad);
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    auto output_tensor = output->tensor<tensorflow::tstring, 3>();

    // TODO: investigate whether or not it is worth this parallelization at the
    // end.
    //  spinning up and down parallelization for string copying might not be
    //  worth it.
    // Second worker over the same flattened (program, symbol) index space:
    // copy each cell's results into the output tensor, then pad the tail.
    auto DoWork2 = [&](int start, int end) {
      for (int i = start; i < end; i++) {
        int sidx = i % n_symbols;
        int pidx = i / n_symbols;
        for (int j = 0; j < output_programs.at(pidx).at(sidx).size(); j++) {
          output_tensor(pidx, sidx, j) =
              output_programs.at(pidx).at(sidx).at(j);
        }
        for (int j = output_programs.at(pidx).at(sidx).size(); j < biggest_pad;
             j++) {
          output_tensor(pidx, sidx, j) = empty_program;
        }
      }
    };
    context->device()
        ->tensorflow_cpu_worker_threads()
        ->workers->TransformRangeConcurrently(
            block_size, programs.size() * n_symbols, DoWork2);
  }