void Compute()

in tensorflow_quantum/core/ops/tfq_calculate_unitary_op.cc [47:157]


  // Computes the unitary matrix of every circuit in the input batch.
  //
  // The kernel reads three inputs (see the DCHECK below):
  // - GetProgramsAndNumQubits supplies the programs and their qubit counts.
  // - GetSymbolMaps supplies one symbol map per program.
  //
  // Processing happens in two phases:
  // 1. Each program is resolved against its map and converted to a fused
  //    qsim circuit. Programs are converted in parallel.
  // 2. The circuits are simulated one at a time with qsim's unitary
  //    calculator.
  //
  // Output 0 is a std::complex<float> tensor of shape
  // [num_circuits, 2^max_num_qubits, 2^max_num_qubits], where
  // max_num_qubits is the largest qubit count in the batch. Entries outside
  // a circuit's own 2^nq x 2^nq block are filled with the padding value -2.
  //
  // NOTE(review): no upper bound on max_num_qubits is enforced here. The
  // output holds 4^max_num_qubits entries per circuit.
  void Compute(tensorflow::OpKernelContext *context) override {
    // TODO (mbbrough): add more dimension checks for other inputs here.
    DCHECK_EQ(3, context->num_inputs());

    // Parse to Program Proto and num_qubits.
    std::vector<Program> programs;
    std::vector<int> num_qubits;
    OP_REQUIRES_OK(context,
                   GetProgramsAndNumQubits(context, &programs, &num_qubits));

    // Parse symbol maps for parameter resolution in the circuits.
    std::vector<SymbolMap> maps;
    OP_REQUIRES_OK(context, GetSymbolMaps(context, &maps));
    OP_REQUIRES(
        context, maps.size() == programs.size(),
        tensorflow::errors::InvalidArgument(absl::StrCat(
            "Number of circuits and values do not match. Got ", programs.size(),
            " circuits and ", maps.size(), " values.")));

    // Construct qsim circuits, one per program, in parallel.
    std::vector<QsimCircuit> qsim_circuits(programs.size(), QsimCircuit());
    std::vector<std::vector<qsim::GateFused<QsimGate>>> fused_circuits(
        programs.size(), std::vector<qsim::GateFused<QsimGate>>({}));

    // NESTED_FN_STATUS_SYNC is handed parse_status and p_lock to record
    // per-program failures from the worker threads. parse_status is then
    // checked once the parallel loop has finished.
    Status parse_status = Status::OK();
    tensorflow::mutex p_lock;
    auto construct_f = [&](int start, int end) {
      for (int i = start; i < end; i++) {
        Status local =
            QsimCircuitFromProgram(programs[i], maps[i], num_qubits[i],
                                   &qsim_circuits[i], &fused_circuits[i]);
        NESTED_FN_STATUS_SYNC(parse_status, local, p_lock);
      }
    };

    const int num_cycles = 1000;
    context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
        programs.size(), num_cycles, construct_f);
    OP_REQUIRES_OK(context, parse_status);

    // Find largest circuit for tensor size padding and allocate
    // the output tensor.
    int max_num_qubits = 0;
    for (const int num : num_qubits) {
      max_num_qubits = std::max(max_num_qubits, num);
    }

    // Side length of every padded unitary. Use a 64-bit unsigned shift, as
    // the per-circuit crossover below does, rather than a signed int shift.
    // The int shift is undefined for max_num_qubits >= 31.
    const uint64_t dim = uint64_t(1) << max_num_qubits;

    // TODO(pmassey): Investigate creating a matrix that isn't just the maximum
    // required size.
    const int output_dim_size = maps.size();
    tensorflow::TensorShape output_shape;
    output_shape.AddDim(output_dim_size);
    output_shape.AddDim(dim);
    output_shape.AddDim(dim);

    tensorflow::Tensor *output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    auto output_tensor = output->tensor<std::complex<float>, 3>();

    // Instantiate qsim objects. The calculator and the unitary space are both
    // built from tfq_for alone, so a single instance of each is reused for
    // every circuit.
    const auto tfq_for = tfq::QsimFor(context);
    using UCalculator = qsim::unitary::UnitaryCalculator<const tfq::QsimFor &>;
    using UnitarySpace = UCalculator::UnitarySpace;
    using Unitary = UnitarySpace::Unitary;
    UCalculator sim = UCalculator(tfq_for);
    UnitarySpace us = UnitarySpace(tfq_for);

    // Begin simulation with a one-qubit unitary; it is regrown on demand.
    int largest_nq = 1;
    Unitary u = us.CreateUnitary(largest_nq);

    // Simulate programs one by one. Parallelism is over the unitary itself,
    // not over circuits. Each time we encounter a larger circuit we grow the
    // unitary as necessary. Smaller circuits reuse the larger unitary; only
    // its top-left 2^nq x 2^nq block is copied out.
    for (size_t i = 0; i < fused_circuits.size(); i++) {
      const int nq = num_qubits[i];
      if (nq > largest_nq) {
        // Need to switch to a larger unitary.
        largest_nq = nq;
        u = us.CreateUnitary(nq);
      }
      us.SetIdentity(u);
      for (const auto &fused_gate : fused_circuits[i]) {
        qsim::ApplyFusedGate(sim, fused_gate, u);
      }

      // Parallel copy of the unitary from qsim into the tensorflow tensor.
      // Row-major flat index l over the padded dim x dim matrix. Entries past
      // this circuit's 2^nq boundary get the padding value -2.
      auto copy_f = [i, nq, dim, &output_tensor, &us, &u](uint64_t start,
                                                          uint64_t end) {
        const uint64_t crossover = uint64_t(1) << nq;

        for (uint64_t l = start; l < end; l++) {
          const uint64_t j = l / dim;
          const uint64_t k = l % dim;
          auto &entry = output_tensor(static_cast<ptrdiff_t>(i),
                                      static_cast<ptrdiff_t>(j),
                                      static_cast<ptrdiff_t>(k));
          if (k < crossover && j < crossover) {
            entry = us.GetEntry(u, k, j);
          } else {
            entry = std::complex<float>(-2, 0);
          }
        }
      };
      const uint64_t num_cycles_copy = 10 * dim;
      context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
          dim * dim, num_cycles_copy, copy_f);
    }
  }