in tensorflow_quantum/core/ops/tfq_ps_symbol_replace_op.cc [42:180]
void Compute(tensorflow::OpKernelContext *context) override {
  // For every (program i, symbol j) pair, emit one serialized Program per
  // occurrence of symbols(j) inside programs(i), with that single occurrence
  // rewritten to replacement_symbols(j). Output is a rank-3 string tensor of
  // shape [batch, n_symbols, max_occurrences], right-padded with a serialized
  // empty program.
  std::vector<Program> programs;
  const int num_inputs = context->num_inputs();
  OP_REQUIRES(context, num_inputs == 3,
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "Expected 3 inputs, got ", num_inputs, " inputs.")));
  OP_REQUIRES_OK(context, ParsePrograms(context, "programs", &programs));

  // Parse the input string here.
  const Tensor *symbols_tensor;
  OP_REQUIRES_OK(context, context->input("symbols", &symbols_tensor));
  OP_REQUIRES(
      context, symbols_tensor->dims() == 1,
      tensorflow::errors::InvalidArgument(absl::StrCat(
          "symbols must be rank 1. Got rank ", symbols_tensor->dims(), ".")));
  const auto symbols = symbols_tensor->vec<tensorflow::tstring>();
  const size_t n_symbols = symbols.size();

  // Parse the replacement string here.
  const Tensor *replacement_symbols_tensor;
  OP_REQUIRES_OK(context, context->input("replacement_symbols",
                                         &replacement_symbols_tensor));
  OP_REQUIRES(context, replacement_symbols_tensor->dims() == 1,
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "replacement_symbols must be rank 1. Got rank ",
                  replacement_symbols_tensor->dims(), ".")));
  const auto replacement_symbols =
      replacement_symbols_tensor->vec<tensorflow::tstring>();
  OP_REQUIRES(context, symbols.size() == replacement_symbols.size(),
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "symbols.shape is not equal to replacement_symbols.shape: ",
                  symbols.size(), " != ", replacement_symbols.size())));

  // (i,j,k) = the kth replaced program for symbols(j) in programs(i).
  // Work items below are partitioned on flat index i = pidx * n_symbols +
  // sidx, so no two threads ever write the same inner vector.
  std::vector<std::vector<std::vector<std::string>>> output_programs(
      programs.size(), std::vector<std::vector<std::string>>(
                           n_symbols, std::vector<std::string>()));

  auto DoWork = [&](int start, int end) {
    for (int i = start; i < end; i++) {
      const int sidx = i % n_symbols;
      const int pidx = i / n_symbols;
      const std::string symbol_to_replace = symbols(sidx);
      std::string temp_symbol_holder;
      // Read-only view of the source program; the full proto is copied only
      // when an occurrence of the symbol is actually found. (Copying it here
      // unconditionally, as before, duplicated the whole Program per work
      // item.)
      const Program &cur_program = programs.at(pidx);
      for (int j = 0; j < cur_program.circuit().moments().size(); j++) {
        const Moment &cur_moment = cur_program.circuit().moments().at(j);
        for (int k = 0; k < cur_moment.operations().size(); k++) {
          const Operation &cur_op = cur_moment.operations().at(k);
          for (const auto &arg_entry : cur_op.args()) {
            const std::string &key = arg_entry.first;
            const Arg &arg = arg_entry.second;
            if (arg.symbol() == symbol_to_replace) {
              // Copy the proto, modify the symbol and append to output.
              Program temp(cur_program);
              // temp_symbol_holder is needed to avoid call ambiguity for
              // set_symbol below.
              temp_symbol_holder = replacement_symbols(sidx);
              temp.mutable_circuit()
                  ->mutable_moments()
                  ->at(j)
                  .mutable_operations()
                  ->at(k)
                  .mutable_args()
                  ->at(key)
                  .set_symbol(temp_symbol_holder);
              std::string res;
              temp.SerializeToString(&res);
              // Move the serialized bytes; `temp` itself dies at scope end,
              // so no explicit Clear() is needed.
              output_programs.at(pidx).at(sidx).push_back(std::move(res));
            }
          }
        }
      }
    }
  };
  const int block_size = GetBlockSize(context, programs.size() * n_symbols);
  context->device()
      ->tensorflow_cpu_worker_threads()
      ->workers->TransformRangeConcurrently(
          block_size, programs.size() * n_symbols, DoWork);

  // Pad width = largest occurrence count over all (program, symbol) cells.
  size_t biggest_pad = 0;
  for (const auto &per_program : output_programs) {
    for (const auto &per_symbol : per_program) {
      biggest_pad = std::max(biggest_pad, per_symbol.size());
    }
  }

  // Serialized empty program used to right-pad ragged cells.
  Program empty = Program();
  empty.mutable_language()->set_gate_set("tfq_gate_set");
  empty.mutable_circuit();  // create empty circuits entry.
  std::string empty_program;
  empty.SerializeToString(&empty_program);

  tensorflow::Tensor *output = nullptr;
  tensorflow::TensorShape output_shape;
  // batch size.
  output_shape.AddDim(programs.size());
  // entry size.
  output_shape.AddDim(n_symbols);
  output_shape.AddDim(biggest_pad);
  OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
  auto output_tensor = output->tensor<tensorflow::tstring, 3>();

  // TODO: investigate whether or not it is worth this parallelization at the
  // end.
  // spinning up and down parallelization for string copying might not be
  // worth it.
  auto DoWork2 = [&](int start, int end) {
    for (int i = start; i < end; i++) {
      const int sidx = i % n_symbols;
      const int pidx = i / n_symbols;
      const auto &cell = output_programs.at(pidx).at(sidx);
      for (size_t j = 0; j < cell.size(); j++) {
        output_tensor(pidx, sidx, j) = cell.at(j);
      }
      // Right-pad so every cell has exactly biggest_pad entries.
      for (size_t j = cell.size(); j < biggest_pad; j++) {
        output_tensor(pidx, sidx, j) = empty_program;
      }
    }
  };
  context->device()
      ->tensorflow_cpu_worker_threads()
      ->workers->TransformRangeConcurrently(
          block_size, programs.size() * n_symbols, DoWork2);
}