void Compute()

in tensorflow_quantum/core/ops/tfq_ps_weights_from_symbols_op.cc [43:163]


  // For every program and every requested symbol, collects the
  // "exponent_scalar" values of the gates whose "exponent" arg is that symbol.
  //
  // Inputs:
  //   programs: circuits parsed via ParsePrograms.
  //   symbols:  rank-1 string tensor of symbol names.
  // Output 0:
  //   float tensor of shape [n_programs, n_symbols, largest_single_symbol],
  //   where entry (i, j, k) is the k-th scalar found for symbols(j) in
  //   programs(i) (in moment/operation order). Unused trailing slots are 0.
  //
  // Gates whose id is in the ignore list below are skipped entirely.
  // Errors (wrong arity, bad symbols rank, a circuit symbol missing from
  // `symbols`, or a gate lacking its exponent args) are reported through
  // `context`, and no output is allocated in that case.
  void Compute(tensorflow::OpKernelContext *context) override {
    std::vector<Program> programs;

    const int num_inputs = context->num_inputs();
    OP_REQUIRES(context, num_inputs == 2,
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "Expected 2 inputs, got ", num_inputs, " inputs.")));

    OP_REQUIRES_OK(context, ParsePrograms(context, "programs", &programs));

    // Parse the input string here.
    const Tensor *symbols_tensor;
    OP_REQUIRES_OK(context, context->input("symbols", &symbols_tensor));
    OP_REQUIRES(
        context, symbols_tensor->dims() == 1,
        tensorflow::errors::InvalidArgument(absl::StrCat(
            "symbols must be rank 1. Got rank ", symbols_tensor->dims(), ".")));

    const auto symbols = symbols_tensor->vec<tensorflow::tstring>();
    const int n_symbols = symbols.size();
    const int n_programs = programs.size();

    // (i,j,k) = the kth scalar value found for symbols(j) in programs(i).
    std::vector<std::vector<std::vector<float>>> output_results(
        n_programs, std::vector<std::vector<float>>(n_symbols));

    // Map from symbol name -> index in second dimension of output_results.
    // If a name is repeated in `symbols`, its last index wins.
    absl::flat_hash_map<std::string, int> symbols_map;
    for (int i = 0; i < n_symbols; i++) {
      symbols_map[symbols(i)] = i;
    }

    // Gate ids that are never inspected for a symbolic "exponent" arg.
    const absl::flat_hash_set<std::string> ignored_gate_ids = {
        "I",  "ISP", "PXP", "FSIM", "PISP", "AD",  "ADP",
        "DP", "GAD", "BF",  "PF",   "PD",   "RST"};

    // Per-program bookkeeping. Each worker writes only the slots of the
    // programs in its own [start, end) range, so no locking is needed.
    // Failures are recorded here and reported from this thread after the
    // parallel section. Calling OP_REQUIRES inside a worker would only
    // return from the lambda, could update the shared context status from
    // several threads at once, and would let Compute go on to allocate and
    // fill the output anyway.
    std::vector<int> n_single_symbol(n_programs, 0);
    std::vector<tensorflow::Status> program_status(n_programs);

    // Appends the scalars of programs[i] into output_results[i].
    auto ExtractProgram = [&](int i) -> tensorflow::Status {
      for (const Moment &cur_moment : programs[i].circuit().moments()) {
        for (const Operation &cur_op : cur_moment.operations()) {
          if (ignored_gate_ids.contains(cur_op.gate().id())) continue;

          // Look keys up with find(): Map::at() on a missing key aborts the
          // process instead of returning an error.
          const auto &cur_op_map = cur_op.args();
          const auto exponent_it = cur_op_map.find("exponent");
          if (exponent_it == cur_op_map.end()) {
            return tensorflow::errors::InvalidArgument(
                absl::StrCat("Gate ", cur_op.gate().id(),
                             " is missing the 'exponent' arg."));
          }
          const Arg &exponent = exponent_it->second;
          // Only gates with a parameterized (symbolic) exponent contribute.
          if (exponent.arg_case() != Arg::ArgCase::kSymbol) continue;

          const auto symbol_it = symbols_map.find(exponent.symbol());
          if (symbol_it == symbols_map.end()) {
            // Reachable when the caller's `symbols` omits a circuit symbol.
            return tensorflow::errors::InvalidArgument(
                "A circuit contains a sympy.Symbol not found "
                "in symbols!");
          }
          const auto scalar_it = cur_op_map.find("exponent_scalar");
          if (scalar_it == cur_op_map.end()) {
            return tensorflow::errors::InvalidArgument(
                absl::StrCat("Gate ", cur_op.gate().id(),
                             " is missing the 'exponent_scalar' arg."));
          }
          output_results[i][symbol_it->second].push_back(
              scalar_it->second.arg_value().float_value());
        }
      }
      return tensorflow::Status();
    };

    auto DoWork = [&](int start, int end) {
      for (int i = start; i < end; i++) {
        program_status[i] = ExtractProgram(i);
        if (!program_status[i].ok()) continue;
        // Longest per-symbol list for this program.
        for (const auto &values : output_results[i]) {
          n_single_symbol[i] =
              std::max(n_single_symbol[i], static_cast<int>(values.size()));
        }
      }
    };

    const int block_size = GetBlockSize(context, n_programs);
    context->device()
        ->tensorflow_cpu_worker_threads()
        ->workers->TransformRangeConcurrently(block_size, n_programs, DoWork);

    // Report the first failure (in program order) before touching outputs.
    for (const tensorflow::Status &status : program_status) {
      OP_REQUIRES_OK(context, status);
    }

    int largest_single_symbol = 0;
    for (const int count : n_single_symbol) {
      largest_single_symbol = std::max(largest_single_symbol, count);
    }

    tensorflow::Tensor *output = nullptr;
    tensorflow::TensorShape output_shape;
    // batch size.
    output_shape.AddDim(n_programs);
    // entry size.
    output_shape.AddDim(n_symbols);
    output_shape.AddDim(largest_single_symbol);

    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    auto output_tensor = output->tensor<float, 3>();

    // Copy each per-symbol list and zero-pad it to largest_single_symbol.
    auto DoWork2 = [&](int start, int end) {
      for (int i = start; i < end; i++) {
        for (int j = 0; j < n_symbols; j++) {
          const std::vector<float> &values = output_results[i][j];
          const int n_values = values.size();
          for (int k = 0; k < largest_single_symbol; k++) {
            output_tensor(i, j, k) = k < n_values ? values[k] : 0.0f;
          }
        }
      }
    };
    context->device()
        ->tensorflow_cpu_worker_threads()
        ->workers->TransformRangeConcurrently(block_size, n_programs, DoWork2);
  }