in tensorflow_quantum/core/ops/math_ops/tfq_inner_product.cc [50:153]
void Compute(tensorflow::OpKernelContext* context) override {
  // Kernel entry point. Parses the reference programs (input 0) and the
  // per-batch "other" programs (input 3) into fused qsim circuits, then fills
  // a complex64 output of shape [dim 0 of input 0, dim 1 of input 3] via
  // ComputeLarge or ComputeSmall.
  // NOTE(review): presumably each output entry [b][j] is the inner product of
  // programs[b] (with its symbol values) and other_programs[b][j]; the actual
  // math lives in ComputeLarge/ComputeSmall, which are not shown here.
  //
  // TODO (mbbrough): add more dimension checks for other inputs here.
  const int num_inputs = context->num_inputs();
  OP_REQUIRES(context, num_inputs == 4,
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "Expected 4 inputs, got ", num_inputs, " inputs.")));

  // dim_size(d) is only meaningful for d < dims(). Validate ranks before
  // reading the dimensions used to size the output; otherwise a malformed
  // input would read an out-of-range dimension and feed a bogus value into
  // AddDim/allocate_output.
  const tensorflow::Tensor& programs_input = context->input(0);
  const tensorflow::Tensor& other_programs_input = context->input(3);
  OP_REQUIRES(context, programs_input.dims() >= 1,
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "programs must be rank 1. Got rank ", programs_input.dims(),
                  ".")));
  OP_REQUIRES(context, other_programs_input.dims() >= 2,
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "other_programs must be rank 2. Got rank ",
                  other_programs_input.dims(), ".")));

  // Create the output Tensor: [batch_size, number of other programs].
  const int output_dim_batch_size = programs_input.dim_size(0);
  const int output_dim_internal_size = other_programs_input.dim_size(1);
  tensorflow::TensorShape output_shape;
  output_shape.AddDim(output_dim_batch_size);
  output_shape.AddDim(output_dim_internal_size);
  tensorflow::Tensor* output = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
  auto output_tensor = output->matrix<std::complex<float>>();

  // Parse program protos, per-program qubit counts and other_programs.
  std::vector<Program> programs;
  std::vector<int> num_qubits;
  std::vector<std::vector<Program>> other_programs;
  OP_REQUIRES_OK(context,
                 GetProgramsAndNumQubits(context, &programs, &num_qubits,
                                         &other_programs));

  std::vector<SymbolMap> maps;
  OP_REQUIRES_OK(context, GetSymbolMaps(context, &maps));
  OP_REQUIRES(context, programs.size() == maps.size(),
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "Number of circuits and symbol_values do not match. Got ",
                  programs.size(), " circuits and ", maps.size(),
                  " symbol values.")));

  // Construct qsim circuits for programs, resolving symbols with maps[i].
  // Work is split across the CPU worker pool; per-item failures are merged
  // into parse_status under p_lock by NESTED_FN_STATUS_SYNC.
  std::vector<QsimCircuit> qsim_circuits(programs.size(), QsimCircuit());
  std::vector<QsimFusedCircuit> fused_circuits(programs.size(),
                                               QsimFusedCircuit({}));
  Status parse_status = Status::OK();
  tensorflow::mutex p_lock;
  auto construct_f = [&](int start, int end) {
    for (int i = start; i < end; i++) {
      Status local =
          QsimCircuitFromProgram(programs[i], maps[i], num_qubits[i],
                                 &qsim_circuits[i], &fused_circuits[i]);
      NESTED_FN_STATUS_SYNC(parse_status, local, p_lock);
    }
  };

  // Per-item cost estimate handed to ParallelFor for sharding decisions.
  const int num_cycles = 1000;
  context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
      output_dim_batch_size, num_cycles, construct_f);
  OP_REQUIRES_OK(context, parse_status);

  // Construct qsim circuits for other_programs. The flat index i covers the
  // [batch, internal] grid row-major. These circuits are resolved with an
  // empty symbol map, so they must not contain symbols.
  std::vector<std::vector<QsimCircuit>> other_qsim_circuits(
      output_dim_batch_size,
      std::vector<QsimCircuit>(output_dim_internal_size, QsimCircuit()));
  std::vector<std::vector<QsimFusedCircuit>> other_fused_circuits(
      output_dim_batch_size,
      std::vector<QsimFusedCircuit>(output_dim_internal_size,
                                    QsimFusedCircuit({})));
  auto construct_f2 = [&](int start, int end) {
    for (int i = start; i < end; i++) {
      const int ii = i / output_dim_internal_size;
      const int jj = i % output_dim_internal_size;
      Status status = QsimCircuitFromProgram(
          other_programs[ii][jj], {}, num_qubits[ii],
          &other_qsim_circuits[ii][jj], &other_fused_circuits[ii][jj]);
      NESTED_FN_STATUS_SYNC(parse_status, status, p_lock);
    }
  };
  context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
      output_dim_batch_size * output_dim_internal_size, num_cycles,
      construct_f2);

  // Any conversion failure here is reported as a symbol error (the expected
  // cause, since no symbol map is supplied). The underlying status is appended
  // so that other causes are not hidden behind that message.
  OP_REQUIRES(context, parse_status.ok(),
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "Found symbols in other_programs. ",
                  "No symbols are allowed in these circuits. ",
                  "Details: ", parse_status.ToString())));

  int max_num_qubits = 0;
  for (const int num : num_qubits) {
    max_num_qubits = std::max(max_num_qubits, num);
  }

  // Cross reference with standard google cloud compute instances
  // Memory ~= 2 * num_threads * (2 * 64 * 2 ** num_qubits in circuits)
  // e2s2 = 2 CPU, 8GB -> Can safely do 25 since Memory = 4GB
  // e2s4 = 4 CPU, 16GB -> Can safely do 25 since Memory = 8GB
  // ...
  // Use ComputeLarge for circuits of 26+ qubits or a single-entry batch, and
  // ComputeSmall otherwise.
  if (max_num_qubits >= 26 || output_dim_batch_size == 1) {
    ComputeLarge(num_qubits, fused_circuits, other_fused_circuits, context,
                 &output_tensor);
  } else {
    ComputeSmall(num_qubits, max_num_qubits, fused_circuits,
                 other_fused_circuits, context, &output_tensor);
  }
}