in tensorflow_quantum/core/ops/noise/tfq_noisy_expectation.cc [252:385]
void ComputeSmall(const std::vector<int>& num_qubits,
const int max_num_qubits,
const std::vector<NoisyQsimCircuit>& ncircuits,
const std::vector<std::vector<PauliSum>>& pauli_sums,
const std::vector<std::vector<int>>& num_samples,
tensorflow::OpKernelContext* context,
tensorflow::TTypes<float, 1>::Matrix* output_tensor) {
using Simulator = qsim::Simulator<const qsim::SequentialFor&>;
using StateSpace = Simulator::StateSpace;
using QTSimulator =
qsim::QuantumTrajectorySimulator<qsim::IO, QsimGate,
qsim::MultiQubitGateFuser, Simulator>;
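// Trajectory runs are parallelized across the TF thread pool below, so each
// qsim simulator instance is single-threaded (SequentialFor).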
const int output_dim_batch_size = output_tensor->dimension(0);
std::vector<tensorflow::mutex> batch_locks(output_dim_batch_size,
tensorflow::mutex());
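// One mutex per batch entry: several worker threads may accumulate partial
// results into the same row of output_tensor.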
const int num_threads = context->device()
->tensorflow_cpu_worker_threads()
->workers->NumThreads();
// [num_threads, batch_size].
std::vector<std::vector<int>> rep_offsets(
num_threads, std::vector<int>(output_dim_batch_size, 0));
BalanceTrajectory(num_samples, num_threads, &rep_offsets);
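// Per-thread rep offsets around the baseline ceil(num_samples / num_threads)
// share; presumably BalanceTrajectory sizes these so the trajectory totals
// across threads match the requested sample counts. Threads accumulate
// partial averages with +=, hence the zero initialization below.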
output_tensor->setZero();
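// A guarded Philox generator hands disjoint random streams to the worker
// threads.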
tensorflow::GuardedPhiloxRandom random_gen;
int max_n_shots = 1;
for (int i = 0; i < num_samples.size(); i++) {
for (int j = 0; j < num_samples[i].size(); j++) {
max_n_shots = std::max(max_n_shots, num_samples[i][j]);
}
}
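// The largest shot count bounds how many trajectory seeds any one thread can
// consume per circuit; it sizes the per-thread stream reservation below.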
random_gen.Init(tensorflow::random::New64(), tensorflow::random::New64());
Status compute_status = Status::OK();
auto c_lock = tensorflow::mutex();
auto DoWork = [&](int start, int end) {
// Begin simulation.
const auto tfq_for = qsim::SequentialFor(1);
int largest_nq = 1;
Simulator sim = Simulator(tfq_for);
StateSpace ss = StateSpace(tfq_for);
auto sv = ss.Create(largest_nq);
auto scratch = ss.Create(largest_nq);
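// State vectors start at one qubit and are grown lazily to the largest
// circuit seen so far, avoiding a worst-case allocation up front.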
// Reserve enough of the guarded stream for the worst case of one 64-bit
// trajectory seed per circuit per shot; over-reserving is harmless since
// each ReserveSamples128 call receives a disjoint range.
auto local_gen =
random_gen.ReserveSamples128(ncircuits.size() * max_n_shots + 1);
tensorflow::random::SimplePhilox rand_source(&local_gen);
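// Each trajectory below draws one fresh 64-bit seed from this thread-local
// stream.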
for (int i = 0; i < ncircuits.size(); i++) {
int nq = num_qubits[i];
int rep_offset = rep_offsets[start][i];
// (#679) Just ignore empty programs, flagging their entries with -2.0.
if (ncircuits[i].channels.empty()) {
for (int j = 0; j < pauli_sums[i].size(); j++) {
(*output_tensor)(i, j) = -2.0;
}
continue;
}
if (nq > largest_nq) {
largest_nq = nq;
sv = ss.Create(largest_nq);
scratch = ss.Create(largest_nq);
}
QTSimulator::Parameter param;
param.collect_kop_stat = false;
param.collect_mea_stat = false;
param.normalize_before_mea_gates = true;
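// Per-gate noise statistics are not needed here; the state is renormalized
// before measurement gates because channel application can leave it
// unnormalized.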
std::vector<uint64_t> unused_stats;
// Per-op counts of completed samples and running sums of expectation values.
std::vector<int> run_samples(num_samples[i].size(), 0);
std::vector<double> rolling_sums(num_samples[i].size(), 0.0);
while (true) {
ss.SetStateZero(sv);
QTSimulator::RunOnce(param, ncircuits[i], rand_source.Rand64(), ss,
sim, scratch, sv, unused_stats);
// Compute expectations across all ops using this trajectory.
for (int j = 0; j < pauli_sums[i].size(); j++) {
int p_reps = (num_samples[i][j] + num_threads - 1) / num_threads;
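// Baseline per-thread share via ceiling division, e.g. 100 samples across
// 8 threads -> 13 reps, adjusted by this thread's offset.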
if (run_samples[j] >= p_reps + rep_offset) {
continue;
}
float exp_v = 0.0;
NESTED_FN_STATUS_SYNC(
compute_status,
ComputeExpectationQsim(pauli_sums[i][j], sim, ss, sv, scratch,
&exp_v),
c_lock);
rolling_sums[j] += static_cast<double>(exp_v);
run_samples[j]++;
}
// Check if we have run enough trajectories for all ops.
bool break_loop = true;
for (int j = 0; j < num_samples[i].size(); j++) {
int p_reps = (num_samples[i][j] + num_threads - 1) / num_threads;
if (run_samples[j] < p_reps + rep_offset) {
break_loop = false;
break;
}
}
if (break_loop) {
// Lock writing to this batch index in output_tensor.
batch_locks[i].lock();
for (int j = 0; j < num_samples[i].size(); j++) {
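// Divide the thread-local sum by the *total* requested samples; summing
// these partial means across threads yields the overall trajectory average.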
rolling_sums[j] /= num_samples[i][j];
(*output_tensor)(i, j) += static_cast<float>(rolling_sums[j]);
}
batch_locks[i].unlock();
break;
}
}
}
};
// Fixed block size of 1: each shard's `start` is then a distinct index in
// [0, num_threads), which is how DoWork indexes into rep_offsets.
tensorflow::thread::ThreadPool::SchedulingParams scheduling_params(
tensorflow::thread::ThreadPool::SchedulingStrategy::kFixedBlockSize,
absl::nullopt, 1);
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
num_threads, scheduling_params, DoWork);
OP_REQUIRES_OK(context, compute_status);
}