void Compute()

in tensorflow_quantum/core/ops/math_ops/tfq_inner_product_grad.cc [51:202]


  // Kernel entry point: validates the op inputs, parses the circuit protos and
  // fills output 0, a [batch_size, num_symbols] matrix of std::complex<float>.
  //
  // Steps, in order:
  //   1. Require exactly 5 inputs and at least one symbol.
  //   2. Allocate output 0.
  //   3. Parse `programs`, `other_programs` and the symbol maps through the
  //      Get* helpers and cross-check their sizes.
  //   4. Build the qsim circuits (plus gradient circuits for `programs`) in
  //      parallel on the device's CPU worker threads.
  //   5. Fetch the downstream gradients, zero the output and hand off to
  //      ComputeLarge or ComputeSmall.
  //
  // Every failed check records an error on `context` and returns early via the
  // OP_REQUIRES / OP_REQUIRES_OK macros. An empty batch is accepted: all size
  // checks still run, and the (zero element) output is returned as allocated.
  void Compute(tensorflow::OpKernelContext* context) override {
    // TODO (mbbrough): add more dimension checks for other inputs here.
    const int num_inputs = context->num_inputs();
    OP_REQUIRES(context, num_inputs == 5,
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "Expected 5 inputs, got ", num_inputs, " inputs.")));

    // dim_size(d) is only valid for d < dims(). Report a missing dimension as
    // size 0 instead of reading past the shape.
    // NOTE(review): this relies on the Get* helpers below rejecting inputs of
    // the wrong rank with their own error messages -- confirm.
    const auto dim_or_zero = [context](int input_index, int dim) -> int {
      const tensorflow::Tensor& tensor = context->input(input_index);
      return dim < tensor.dims() ? static_cast<int>(tensor.dim_size(dim)) : 0;
    };

    // Create the output Tensor.
    const int output_dim_batch_size = dim_or_zero(0, 0);
    const int output_dim_internal_size = dim_or_zero(3, 1);
    const int output_dim_symbol_size = dim_or_zero(1, 0);
    OP_REQUIRES(context, output_dim_symbol_size > 0,
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "The number of symbols must be a positive integer, got ",
                    output_dim_symbol_size, " symbols.")));
    tensorflow::TensorShape output_shape;
    output_shape.AddDim(output_dim_batch_size);
    output_shape.AddDim(output_dim_symbol_size);

    tensorflow::Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    auto output_tensor = output->matrix<std::complex<float>>();

    // Parse program protos.
    std::vector<Program> programs;
    std::vector<int> num_qubits;
    std::vector<std::vector<Program>> other_programs;
    OP_REQUIRES_OK(context,
                   GetProgramsAndNumQubits(context, &programs, &num_qubits,
                                           &other_programs));

    std::vector<SymbolMap> maps;
    OP_REQUIRES_OK(context, GetSymbolMaps(context, &maps));

    OP_REQUIRES(context, programs.size() == maps.size(),
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "Number of circuits and symbol_values do not match. Got ",
                    programs.size(), " circuits and ", maps.size(),
                    " symbol values.")));
    // maps[0] only exists for a non-empty batch. The message arguments are
    // evaluated only when the condition fails, i.e. when maps is non-empty.
    OP_REQUIRES(context,
                maps.empty() || static_cast<size_t>(output_dim_symbol_size) ==
                                    maps[0].size(),
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "Number of symbols and symbol maps do not match. Got ",
                    output_dim_symbol_size, " symbols and ", maps[0].size(),
                    " symbol values.")));

    // Construct qsim circuits for programs.
    std::vector<QsimCircuit> qsim_circuits(programs.size(), QsimCircuit());
    std::vector<QsimFusedCircuit> fused_circuits(programs.size(),
                                                 QsimFusedCircuit({}));

    // track metadata.
    std::vector<std::vector<tfq::GateMetaData>> gate_meta(
        programs.size(), std::vector<tfq::GateMetaData>({}));

    // Per-program fused circuits filled in by CreateGradientCircuit.
    std::vector<std::vector<std::vector<qsim::GateFused<QsimGate>>>>
        partial_fused_circuits(
            programs.size(),
            std::vector<std::vector<qsim::GateFused<QsimGate>>>({}));

    // track gradients
    std::vector<std::vector<GradientOfGate>> gradient_gates(
        programs.size(), std::vector<GradientOfGate>({}));

    // Shared between worker shards; written through NESTED_FN_STATUS_SYNC
    // together with p_lock.
    Status parse_status = Status::OK();
    tensorflow::mutex p_lock;
    auto construct_f = [&](int start, int end) {
      for (int i = start; i < end; i++) {
        Status local = QsimCircuitFromProgram(
            programs[i], maps[i], num_qubits[i], &qsim_circuits[i],
            &fused_circuits[i], &gate_meta[i]);
        NESTED_FN_STATUS_SYNC(parse_status, local, p_lock);

        CreateGradientCircuit(qsim_circuits[i], gate_meta[i],
                              &partial_fused_circuits[i], &gradient_gates[i]);
      }
    };

    // Cost-per-unit hint handed to ParallelFor for both construction passes.
    const int num_cycles = 1000;
    context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
        output_dim_batch_size, num_cycles, construct_f);
    OP_REQUIRES_OK(context, parse_status);

    // Construct qsim circuits for other_programs.
    std::vector<std::vector<QsimCircuit>> other_qsim_circuits(
        output_dim_batch_size,
        std::vector<QsimCircuit>(output_dim_internal_size, QsimCircuit()));
    std::vector<std::vector<QsimFusedCircuit>> other_fused_circuits(
        output_dim_batch_size,
        std::vector<QsimFusedCircuit>(output_dim_internal_size,
                                      QsimFusedCircuit({})));

    auto construct_f2 = [&](int start, int end) {
      for (int i = start; i < end; i++) {
        // Flat index -> (batch row, position inside the row).
        int ii = i / output_dim_internal_size;
        int jj = i % output_dim_internal_size;
        // An empty symbol map is passed, so no symbol can be resolved here.
        Status status = QsimCircuitFromProgram(
            other_programs[ii][jj], {}, num_qubits[ii],
            &other_qsim_circuits[ii][jj], &other_fused_circuits[ii][jj]);
        NESTED_FN_STATUS_SYNC(parse_status, status, p_lock);
      }
    };

    context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
        output_dim_batch_size * output_dim_internal_size, num_cycles,
        construct_f2);
    // Any failure in the pass above is reported with this fixed message; the
    // status returned by QsimCircuitFromProgram is not forwarded.
    if (!parse_status.ok()) {
      OP_REQUIRES_OK(context,
                     tensorflow::errors::InvalidArgument(absl::StrCat(
                         "Found symbols in other_programs.",
                         "No symbols are allowed in these circuits.")));
    }

    int max_num_qubits = 0;
    for (const int num : num_qubits) {
      max_num_qubits = std::max(max_num_qubits, num);
    }

    // Get downstream gradients.
    std::vector<std::vector<float>> downstream_grads;
    OP_REQUIRES_OK(context, GetPrevGrads(context, &downstream_grads));

    OP_REQUIRES(context, downstream_grads.size() == programs.size(),
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "Number of gradients and circuits do not match. Got ",
                    downstream_grads.size(), " gradients and ", programs.size(),
                    " circuits.")));

    // downstream_grads[0] only exists for a non-empty batch.
    OP_REQUIRES(context,
                downstream_grads.empty() ||
                    downstream_grads[0].size() ==
                        static_cast<size_t>(output_dim_internal_size),
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "Number of gradients and other_programs do not match. Got ",
                    downstream_grads[0].size(), " gradient entries and ",
                    output_dim_internal_size, " other programs.")));

    // An empty batch means the output has zero elements: nothing to simulate.
    if (programs.empty()) {
      return;
    }

    output_tensor.setZero();

    // Cross reference with standard google cloud compute instances
    // Memory ~= 2 * num_threads * (2 * 64 * 2 ** num_qubits in circuits)
    // e2s2 = 2 CPU, 8GB -> Can safely do 23 since Memory = 4GB
    // e2s4 = 4 CPU, 16GB -> Can safely do 23 since Memory = 8GB
    // ...
    if (max_num_qubits >= 24 || output_dim_batch_size == 1) {
      ComputeLarge(num_qubits, maps, qsim_circuits, fused_circuits,
                   partial_fused_circuits, gradient_gates, other_fused_circuits,
                   downstream_grads, context, &output_tensor);
    } else {
      ComputeSmall(num_qubits, max_num_qubits, maps, qsim_circuits,
                   fused_circuits, partial_fused_circuits, gradient_gates,
                   other_fused_circuits, downstream_grads, context,
                   &output_tensor);
    }
  }