in tensorflow_quantum/core/ops/tfq_ps_decompose_op.cc [43:159]
void Compute(tensorflow::OpKernelContext *context) override {
  // Rewrites each input program so that every symbol-parameterized
  // PhasedISwapPowGate ("PISP"), ISwapPowGate ("ISP"), PhasedXPowGate ("PXP")
  // and FSimGate ("FSIM") is replaced by a sequence of simpler gates. The
  // first piece of a decomposition replaces the gate in its own moment; the
  // remaining pieces go into extra moments inserted right after that moment.
  // Gates whose parameters are all non-symbolic are left untouched.
  //
  // Output: a tstring tensor with the same shape as the input, holding one
  // serialized Program (gate set "tfq_gate_set", MOMENT_BY_MOMENT scheduling)
  // per input program.
  std::vector<Program> programs;
  const int num_inputs = context->num_inputs();
  OP_REQUIRES(context, num_inputs == 1,
              tensorflow::errors::InvalidArgument(absl::StrCat(
                  "Expected 1 inputs, got ", num_inputs, " inputs.")));
  OP_REQUIRES_OK(context, ParsePrograms(context, "programs", &programs));

  tensorflow::Tensor *output = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(
                              0, context->input(0).shape(), &output));
  auto output_tensor = output->flat<tensorflow::tstring>();

  // Largest number of extra moments a single decomposition needs (PISP).
  constexpr int max_buffer_moments = 5;

  auto DoWork = [&](int start, int end) {
    for (int i = start; i < end; i++) {
      // Read-only access to the input; only copies of moments/ops are edited,
      // so there is no need to copy the whole program.
      const Program &cur_program = programs.at(i);
      Program new_program;
      std::string temp;
      new_program.mutable_language()->set_gate_set("tfq_gate_set");
      new_program.mutable_circuit()->set_scheduling_strategy(
          Circuit::MOMENT_BY_MOMENT);

      for (int j = 0; j < cur_program.circuit().moments().size(); j++) {
        Moment cur_moment(cur_program.circuit().moments().at(j));
        std::vector<Moment> temp_moment_list(max_buffer_moments, Moment());

        // Number of buffer moments to emit after cur_moment. Several gates in
        // the same moment may be decomposed, so this value must only grow:
        // it has to cover the longest decomposition seen in the moment,
        // otherwise the trailing pieces of that decomposition (already
        // written into temp_moment_list) would be silently dropped.
        int num_extra_moments = 0;
        auto require_extra_moments = [&num_extra_moments](int n) {
          if (num_extra_moments < n) num_extra_moments = n;
        };

        for (int k = 0; k < cur_moment.operations().size(); k++) {
          // Deliberate copy: cur_moment's k-th op is overwritten below while
          // cur_op is still needed to build the remaining pieces.
          Operation cur_op = cur_moment.operations().at(k);
          // NOTE(review): protobuf Map::at aborts (CHECK failure) when a key
          // is missing instead of returning an error; this assumes the parsed
          // programs always carry these args for these gate ids — confirm.
          const auto &cur_op_map = cur_op.args();

          if (cur_op.gate().id() == "PISP") {
            const auto &exponent = cur_op_map.at("exponent");
            const auto &phase_exponent = cur_op_map.at("phase_exponent");
            if (exponent.arg_case() == Arg::ArgCase::kSymbol ||
                phase_exponent.arg_case() == Arg::ArgCase::kSymbol) {
              // Decompose cirq.PhasedISwapPowGate only if it is
              // parameterized.
              require_extra_moments(5);
              cur_moment.mutable_operations()->at(k) =
                  getOpForPISP(cur_op, 0, 0);
              *temp_moment_list[0].add_operations() =
                  getOpForPISP(cur_op, 1, 1);
              *temp_moment_list[1].add_operations() =
                  getOpForISP(cur_op, "XXP", exponent.symbol());
              *temp_moment_list[2].add_operations() =
                  getOpForISP(cur_op, "YYP", exponent.symbol());
              *temp_moment_list[3].add_operations() =
                  getOpForPISP(cur_op, 1, 0);
              *temp_moment_list[4].add_operations() =
                  getOpForPISP(cur_op, 0, 1);
            }
          } else if (cur_op.gate().id() == "ISP") {
            const auto &exponent = cur_op_map.at("exponent");
            if (exponent.arg_case() == Arg::ArgCase::kSymbol) {
              // Decompose cirq.ISwapPowGate only if it is parameterized.
              require_extra_moments(1);
              cur_moment.mutable_operations()->at(k) =
                  getOpForISP(cur_op, "XXP", exponent.symbol());
              *temp_moment_list[0].add_operations() =
                  getOpForISP(cur_op, "YYP", exponent.symbol());
            }
          } else if (cur_op.gate().id() == "PXP") {
            const auto &exponent = cur_op_map.at("exponent");
            const auto &phase_exponent = cur_op_map.at("phase_exponent");
            if (exponent.arg_case() == Arg::ArgCase::kSymbol ||
                phase_exponent.arg_case() == Arg::ArgCase::kSymbol) {
              // Decompose cirq.PhasedXPowGate only if it is parameterized.
              require_extra_moments(2);
              cur_moment.mutable_operations()->at(k) =
                  getOpForPXP(cur_op, "ZP", "phase_exponent", true);
              *temp_moment_list[0].add_operations() =
                  getOpForPXP(cur_op, "XP", "exponent", false);
              *temp_moment_list[1].add_operations() =
                  getOpForPXP(cur_op, "ZP", "phase_exponent", false);
            }
          } else if (cur_op.gate().id() == "FSIM") {
            const auto &theta = cur_op_map.at("theta");
            const auto &phi = cur_op_map.at("phi");
            if (theta.arg_case() == Arg::ArgCase::kSymbol ||
                phi.arg_case() == Arg::ArgCase::kSymbol) {
              // Decompose cirq.FSimGate only if it is parameterized.
              require_extra_moments(2);
              cur_moment.mutable_operations()->at(k) =
                  getOpForFSIM(cur_op, "XXP", "theta", true);
              *temp_moment_list[0].add_operations() =
                  getOpForFSIM(cur_op, "YYP", "theta", true);
              *temp_moment_list[1].add_operations() =
                  getOpForFSIM(cur_op, "CZP", "phi", false);
            }
          }
        }

        *new_program.mutable_circuit()->add_moments() = cur_moment;
        for (int l = 0; l < num_extra_moments; l++) {
          *new_program.mutable_circuit()->add_moments() = temp_moment_list[l];
        }
      }

      new_program.SerializeToString(&temp);
      output_tensor(i) = temp;
    }
  };

  const int block_size = GetBlockSize(context, programs.size());
  context->device()
      ->tensorflow_cpu_worker_threads()
      ->workers->TransformRangeConcurrently(block_size, programs.size(),
                                            DoWork);
}