in tensorflow_quantum/core/ops/tfq_ps_weights_from_symbols_op.cc [43:163]
  void Compute(tensorflow::OpKernelContext *context) override {
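    // Gathers, for every program in the batch and every requested symbol, the
    // value of the `exponent_scalar` arg of each operation whose `exponent`
    // is that symbol, and packs the results into a zero-padded
    // [n_programs, n_symbols, largest_single_symbol] float tensor.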
    std::vector<Program> programs;

    const int num_inputs = context->num_inputs();
    OP_REQUIRES(context, num_inputs == 2,
                tensorflow::errors::InvalidArgument(absl::StrCat(
                    "Expected 2 inputs, got ", num_inputs, " inputs.")));

    OP_REQUIRES_OK(context, ParsePrograms(context, "programs", &programs));
    // Parse the symbol names.
    const Tensor *symbols_tensor;
    OP_REQUIRES_OK(context, context->input("symbols", &symbols_tensor));
    OP_REQUIRES(
        context, symbols_tensor->dims() == 1,
        tensorflow::errors::InvalidArgument(absl::StrCat(
            "symbols must be rank 1. Got rank ", symbols_tensor->dims(), ".")));

    const auto symbols = symbols_tensor->vec<tensorflow::tstring>();
    const int n_symbols = symbols.size();
    // (i,j,k) = the kth scalar value found for symbols(j) in programs(i).
    std::vector<std::vector<std::vector<float>>> output_results(
        programs.size(),
        std::vector<std::vector<float>>(n_symbols, std::vector<float>()));

    // map from symbols -> index in second dimension of output_results.
    absl::flat_hash_map<std::string, int> symbols_map;
    for (int i = 0; i < n_symbols; i++) {
      symbols_map[symbols(i)] = i;
    }
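    // Gate ids whose operations are skipped when scanning for symbolized
    // exponents.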
    std::vector<std::string> ignore_list = {"I", "ISP", "PXP", "FSIM", "PISP",
                                            "AD", "ADP", "DP", "GAD", "BF",
                                            "PF", "PD", "RST"};
    absl::flat_hash_set<std::string> ignored_symbol_set(ignore_list.begin(),
                                                        ignore_list.end());

    std::vector<int> n_single_symbol(programs.size(), 0);
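    // For each program in [start, end), walk every operation not in the
    // ignore set and, whenever its exponent is a free symbol, record the
    // value of its `exponent_scalar` arg under that symbol's slot. Also
    // track the largest number of values recorded for any single symbol.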
    auto DoWork = [&](int start, int end) {
      for (int i = start; i < end; i++) {
        Program cur_program = programs.at(i);
        for (int j = 0; j < cur_program.circuit().moments().size(); j++) {
          Moment cur_moment = cur_program.circuit().moments().at(j);
          for (int k = 0; k < cur_moment.operations().size(); k++) {
            Operation cur_op = cur_moment.operations().at(k);
            if (ignored_symbol_set.contains(cur_op.gate().id())) continue;
            const auto &cur_op_map = *cur_op.mutable_args();
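            // The exponent arg is either a literal value or a free symbol;
            // only symbolized exponents contribute entries to the output.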
            const auto exponent = cur_op_map.at("exponent");
            if (exponent.arg_case() == Arg::ArgCase::kSymbol) {
              // This gate has a parameterized exponent.
              const absl::string_view symbol_name = exponent.symbol();
              if (!symbols_map.contains(symbol_name)) {
                // Should never happen. Raise an error.
                OP_REQUIRES(context, false,
                            tensorflow::errors::InvalidArgument(
                                "A circuit contains a sympy.Symbol not found "
                                "in symbols!"));
              }
              output_results.at(i)
                  .at(symbols_map.at(symbol_name))
                  .push_back(cur_op.args()
                                 .at("exponent_scalar")
                                 .arg_value()
                                 .float_value());
            }
          }
        }

        // Loop over all symbol indices and record the largest number of
        // values collected for any single symbol in this program.
        for (int j = 0; j < n_symbols; j++) {
          n_single_symbol.at(i) =
              std::max(n_single_symbol.at(i),
                       static_cast<int>(output_results.at(i).at(j).size()));
        }
      }
    };
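    // Shard the scan over programs across the CPU worker threadpool.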
    const int block_size = GetBlockSize(context, programs.size());
    context->device()
        ->tensorflow_cpu_worker_threads()
        ->workers->TransformRangeConcurrently(block_size, programs.size(),
                                              DoWork);
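    // The third output dimension is the largest number of values recorded
    // for any single symbol across the whole batch.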
    int largest_single_symbol = 0;
    for (size_t i = 0; i < n_single_symbol.size(); i++) {
      largest_single_symbol =
          std::max(n_single_symbol.at(i), largest_single_symbol);
    }
    tensorflow::Tensor *output = nullptr;
    tensorflow::TensorShape output_shape;
    // Batch size.
    output_shape.AddDim(programs.size());
    // Number of symbols.
    output_shape.AddDim(n_symbols);
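    // Maximum occurrences of any single symbol; shorter rows are zero-padded.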
    output_shape.AddDim(largest_single_symbol);

    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    auto output_tensor = output->tensor<float, 3>();
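    // Copy the collected values into the output tensor, zero-padding each
    // (program, symbol) row out to largest_single_symbol entries.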
    auto DoWork2 = [&](int start, int end) {
      for (int i = start; i < end; i++) {
        for (int j = 0; j < n_symbols; j++) {
          for (int k = 0; k < output_results.at(i).at(j).size(); k++) {
            output_tensor(i, j, k) = output_results.at(i).at(j).at(k);
          }
          for (int k = output_results.at(i).at(j).size();
               k < largest_single_symbol; k++) {
            output_tensor(i, j, k) = 0.0f;
          }
        }
      }
    };
    context->device()
        ->tensorflow_cpu_worker_threads()
        ->workers->TransformRangeConcurrently(block_size, programs.size(),
                                              DoWork2);
  }