in tensorflow_quantum/core/ops/math_ops/tfq_inner_product_grad.cc [51:202]
// Kernel entry point. Validates the five op inputs, allocates output 0 as a
// [batch, symbols] complex<float> matrix, parses the circuits / symbol maps /
// upstream gradients out of `context`, builds the qsim circuits in parallel
// and finally dispatches to ComputeLarge or ComputeSmall to fill the output.
//
// Any validation or parsing failure is reported on `context` as an
// InvalidArgument (or the status returned by the parsing helper) and the
// method returns without running the simulation.
void Compute(tensorflow::OpKernelContext* context) override {
  // TODO (mbbrough): add more dimension checks for other inputs here.
  const int num_inputs = context->num_inputs();
  OP_REQUIRES(context, num_inputs == 5,
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "Expected 5 inputs, got ", num_inputs, " inputs.")));

  // Output dimensions are read straight from the input shapes:
  //   input(0) dim 0 -> batch size,
  //   input(3) dim 1 -> number of "other" programs per batch entry,
  //   input(1) dim 0 -> number of symbols.
  // NOTE(review): the ranks of these inputs are not checked here; presumably
  // the parsing helpers below reject malformed ranks - confirm.
  const int output_dim_batch_size = context->input(0).dim_size(0);
  const int output_dim_internal_size = context->input(3).dim_size(1);
  const int output_dim_symbol_size = context->input(1).dim_size(0);
  OP_REQUIRES(context, output_dim_symbol_size > 0,
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "The number of symbols must be a positive integer, got ",
                  output_dim_symbol_size, " symbols.")));

  // Create the output Tensor.
  tensorflow::TensorShape output_shape;
  output_shape.AddDim(output_dim_batch_size);
  output_shape.AddDim(output_dim_symbol_size);

  tensorflow::Tensor* output = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
  auto output_tensor = output->matrix<std::complex<float>>();

  // Parse program protos.
  std::vector<Program> programs;
  std::vector<int> num_qubits;
  std::vector<std::vector<Program>> other_programs;
  OP_REQUIRES_OK(context,
                 GetProgramsAndNumQubits(context, &programs, &num_qubits,
                                         &other_programs));

  std::vector<SymbolMap> maps;
  OP_REQUIRES_OK(context, GetSymbolMaps(context, &maps));

  OP_REQUIRES(context, programs.size() == maps.size(),
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "Number of circuits and symbol_values do not match. Got ",
                  programs.size(), " circuits and ", maps.size(),
                  " symbol values.")));

  // Only inspect maps[0] when it exists: an empty batch yields an empty
  // `maps`, and indexing it would be undefined behaviour. The status
  // expression below is only evaluated when the condition fails, i.e. when
  // `maps` is non-empty. output_dim_symbol_size is known to be > 0 here, so
  // the cast to size_t is value-preserving.
  OP_REQUIRES(context,
              maps.empty() ||
                  static_cast<size_t>(output_dim_symbol_size) == maps[0].size(),
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "Number of symbols and symbol maps do not match. Got ",
                  output_dim_symbol_size, " symbols and ", maps[0].size(),
                  " symbol values.")));

  // Construct qsim circuits for programs.
  std::vector<QsimCircuit> qsim_circuits(programs.size(), QsimCircuit());
  std::vector<QsimFusedCircuit> fused_circuits(programs.size(),
                                               QsimFusedCircuit({}));

  // track metadata.
  std::vector<std::vector<tfq::GateMetaData>> gate_meta(
      programs.size(), std::vector<tfq::GateMetaData>({}));

  // Construct qsim circuits.
  std::vector<std::vector<std::vector<qsim::GateFused<QsimGate>>>>
      partial_fused_circuits(
          programs.size(),
          std::vector<std::vector<qsim::GateFused<QsimGate>>>({}));

  // track gradients
  std::vector<std::vector<GradientOfGate>> gradient_gates(
      programs.size(), std::vector<GradientOfGate>({}));

  // `parse_status` is shared by all worker shards; NESTED_FN_STATUS_SYNC is
  // handed `p_lock` to synchronise writes to it.
  Status parse_status = Status::OK();
  tensorflow::mutex p_lock;
  auto construct_f = [&](int start, int end) {
    for (int i = start; i < end; i++) {
      Status local = QsimCircuitFromProgram(
          programs[i], maps[i], num_qubits[i], &qsim_circuits[i],
          &fused_circuits[i], &gate_meta[i]);
      NESTED_FN_STATUS_SYNC(parse_status, local, p_lock);
      CreateGradientCircuit(qsim_circuits[i], gate_meta[i],
                            &partial_fused_circuits[i], &gradient_gates[i]);
    }
  };

  // Per-item cost estimate handed to ParallelFor.
  const int num_cycles = 1000;
  context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
      output_dim_batch_size, num_cycles, construct_f);
  OP_REQUIRES_OK(context, parse_status);

  // Construct qsim circuits for other_programs.
  std::vector<std::vector<QsimCircuit>> other_qsim_circuits(
      output_dim_batch_size,
      std::vector<QsimCircuit>(output_dim_internal_size, QsimCircuit()));
  std::vector<std::vector<QsimFusedCircuit>> other_fused_circuits(
      output_dim_batch_size,
      std::vector<QsimFusedCircuit>(output_dim_internal_size,
                                    QsimFusedCircuit({})));

  // The [batch, internal] grid is flattened into a single index range so it
  // can be sharded by ParallelFor; ii/jj recover the 2-D position. The lambda
  // body never runs when the flattened range is empty, so the division by
  // output_dim_internal_size is only reached when it is non-zero.
  auto construct_f2 = [&](int start, int end) {
    for (int i = start; i < end; i++) {
      int ii = i / output_dim_internal_size;
      int jj = i % output_dim_internal_size;
      // An empty symbol map is passed for these circuits.
      Status status = QsimCircuitFromProgram(
          other_programs[ii][jj], {}, num_qubits[ii],
          &other_qsim_circuits[ii][jj], &other_fused_circuits[ii][jj]);
      NESTED_FN_STATUS_SYNC(parse_status, status, p_lock);
    }
  };

  context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
      output_dim_batch_size * output_dim_internal_size, num_cycles,
      construct_f2);
  // Any failure while building the other_programs circuits is reported with
  // this fixed message rather than the underlying status.
  if (!parse_status.ok()) {
    OP_REQUIRES_OK(context,
                   tensorflow::errors::InvalidArgument(absl::StrCat(
                       "Found symbols in other_programs. ",
                       "No symbols are allowed in these circuits.")));
  }

  int max_num_qubits = 0;
  for (const int num : num_qubits) {
    max_num_qubits = std::max(max_num_qubits, num);
  }

  // Get downstream gradients.
  std::vector<std::vector<float>> downstream_grads;
  OP_REQUIRES_OK(context, GetPrevGrads(context, &downstream_grads));

  OP_REQUIRES(context, downstream_grads.size() == programs.size(),
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "Number of gradients and circuits do not match. Got ",
                  downstream_grads.size(), " gradients and ", programs.size(),
                  " circuits.")));

  // As with maps[0] above, only look at downstream_grads[0] when the batch is
  // non-empty.
  OP_REQUIRES(context,
              downstream_grads.empty() ||
                  downstream_grads[0].size() ==
                      static_cast<size_t>(output_dim_internal_size),
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "Number of gradients and other_programs do not match. Got ",
                  downstream_grads[0].size(), " gradient entries and ",
                  output_dim_internal_size, " other programs.")));

  // With an empty batch the output has no elements, so there is nothing to
  // simulate or write.
  if (programs.empty()) {
    return;
  }

  output_tensor.setZero();

  // Cross reference with standard google cloud compute instances
  // Memory ~= 2 * num_threads * (2 * 64 * 2 ** num_qubits in circuits)
  // e2s2 = 2 CPU, 8GB -> Can safely do 23 since Memory = 4GB
  // e2s4 = 4 CPU, 16GB -> Can safely do 23 since Memory = 8GB
  // ...
  if (max_num_qubits >= 24 || output_dim_batch_size == 1) {
    ComputeLarge(num_qubits, maps, qsim_circuits, fused_circuits,
                 partial_fused_circuits, gradient_gates, other_fused_circuits,
                 downstream_grads, context, &output_tensor);
  } else {
    ComputeSmall(num_qubits, max_num_qubits, maps, qsim_circuits,
                 fused_circuits, partial_fused_circuits, gradient_gates,
                 other_fused_circuits, downstream_grads, context,
                 &output_tensor);
  }
}