in tensorflow_quantum/core/ops/tfq_adj_grad_op.cc [51:160]
// Validates the op inputs, builds one qsim circuit per program (plus its
// fused gates, gate metadata and gradient circuits), and writes a float
// tensor of shape [num_programs, input(2).dim_size(1)] to output 0 via
// ComputeLarge / ComputeSmall.
//
// NOTE(review): the inputs are presumably (programs, symbol_names,
// symbol_values, pauli_sums, downstream_grads), in that order -- inferred
// from the error messages below; confirm against the op registration.
void Compute(tensorflow::OpKernelContext* context) override {
  // TODO(mbbrough): add more dimension checks for other inputs here.
  const int num_inputs = context->num_inputs();
  OP_REQUIRES(context, num_inputs == 5,
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "Expected 5 inputs, got ", num_inputs, " inputs.")));

  // Parse program protos (and their Pauli sums).
  std::vector<Program> programs;
  std::vector<int> num_qubits;
  std::vector<std::vector<PauliSum>> pauli_sums;
  OP_REQUIRES_OK(context, GetProgramsAndNumQubits(context, &programs,
                                                  &num_qubits, &pauli_sums));

  std::vector<SymbolMap> maps;
  OP_REQUIRES_OK(context, GetSymbolMaps(context, &maps));

  OP_REQUIRES(context, programs.size() == maps.size(),
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "Number of circuits and symbol_values do not match. Got ",
                  programs.size(), " circuits and ", maps.size(),
                  " symbol values.")));

  // Per-program outputs of circuit construction, indexed like `programs`.
  std::vector<QsimCircuit> qsim_circuits(programs.size(), QsimCircuit());
  std::vector<std::vector<qsim::GateFused<QsimGate>>> full_fuse(
      programs.size(), std::vector<qsim::GateFused<QsimGate>>({}));
  std::vector<std::vector<std::vector<qsim::GateFused<QsimGate>>>>
      partial_fused_circuits(
          programs.size(),
          std::vector<std::vector<qsim::GateFused<QsimGate>>>({}));
  // Track gate metadata.
  std::vector<std::vector<tfq::GateMetaData>> gate_meta(
      programs.size(), std::vector<tfq::GateMetaData>({}));
  // Track gradient gates.
  std::vector<std::vector<GradientOfGate>> gradient_gates(
      programs.size(), std::vector<GradientOfGate>({}));

  // Shared parse status written from worker threads; guarded by p_lock.
  Status parse_status = Status::OK();
  tensorflow::mutex p_lock;
  auto construct_f = [&](int start, int end) {
    for (int i = start; i < end; i++) {
      Status local = QsimCircuitFromProgram(programs[i], maps[i],
                                            num_qubits[i], &qsim_circuits[i],
                                            &full_fuse[i], &gate_meta[i]);
      NESTED_FN_STATUS_SYNC(parse_status, local, p_lock);
      CreateGradientCircuit(qsim_circuits[i], gate_meta[i],
                            &partial_fused_circuits[i], &gradient_gates[i]);
    }
  };

  const int num_cycles = 1000;
  context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
      programs.size(), num_cycles, construct_f);
  OP_REQUIRES_OK(context, parse_status);

  // Get downstream gradients.
  std::vector<std::vector<float>> downstream_grads;
  OP_REQUIRES_OK(context, GetPrevGrads(context, &downstream_grads));

  OP_REQUIRES(context, downstream_grads.size() == programs.size(),
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "Number of gradients and circuits do not match. Got ",
                  downstream_grads.size(), " gradients and ", programs.size(),
                  " circuits.")));

  OP_REQUIRES(
      context, context->input(4).dim_size(1) == context->input(3).dim_size(1),
      tensorflow::errors::InvalidArgument(absl::StrCat(
          "Number of gradients and pauli sum dimension do not match. Got ",
          context->input(4).dim_size(1), " gradient entries and ",
          context->input(3).dim_size(1), " paulis per circuit.")));

  // Create the output tensor only after every input has been parsed and
  // validated. Tensor::dim_size(d) is only DCHECK-ed against the tensor's
  // rank, so reading input(2).dim_size(1) before GetSymbolMaps has run is not
  // safe for malformed inputs. NOTE(review): assumes GetSymbolMaps rejects a
  // symbol_values tensor that is not rank 2 -- confirm in parse_context.
  // The batch dimension is the number of parsed programs, which the helpers
  // below index over.
  const int output_dim_batch_size = static_cast<int>(programs.size());
  const int output_dim_param_size = context->input(2).dim_size(1);
  tensorflow::TensorShape output_shape;
  output_shape.AddDim(output_dim_batch_size);
  output_shape.AddDim(output_dim_param_size);
  tensorflow::Tensor* output = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
  auto output_tensor = output->matrix<float>();
  output_tensor.setZero();

  // An empty batch or zero symbols leaves no gradient entries to fill, so
  // skip the (potentially expensive) simulation entirely.
  if (output_shape.num_elements() == 0) {
    return;
  }

  int max_num_qubits = 0;
  for (const int num : num_qubits) {
    max_num_qubits = std::max(max_num_qubits, num);
  }

  // Cross reference with standard google cloud compute instances
  // Memory ~= 2 * num_threads * (2 * 64 * 2 ** num_qubits in circuits)
  // e2s2 = 2 CPU, 8GB -> Can safely do 25 since Memory = 4GB
  // e2s4 = 4 CPU, 16GB -> Can safely do 25 since Memory = 8GB
  // ...
  // The per-thread path creates 3 big state vectors per thread, so the
  // qubit threshold for switching to ComputeLarge is reduced slightly.
  if (max_num_qubits >= 25 || programs.size() == 1) {
    ComputeLarge(num_qubits, qsim_circuits, maps, full_fuse,
                 partial_fused_circuits, pauli_sums, gradient_gates,
                 downstream_grads, context, &output_tensor);
  } else {
    ComputeSmall(num_qubits, max_num_qubits, qsim_circuits, maps, full_fuse,
                 partial_fused_circuits, pauli_sums, gradient_gates,
                 downstream_grads, context, &output_tensor);
  }
}