in src/beanmachine/graph/cavi.cpp [17:175]
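// Coordinate-Ascent Variational Inference (CAVI) with a mean-field
// Bernoulli approximation: every unobserved stochastic (boolean) node
// gets an independent variational factor q(z) = Bernoulli(p), and the
// parameters are updated one node at a time for num_iters sweeps, with
// expectations estimated from steps_per_iter Monte Carlo samples. If
// elbo_samples > 0, an ELBO estimate is appended to elbo_vals after
// each sweep.
//
// Hypothetical usage sketch (assumes a Graph g whose unobserved
// stochastic nodes are all boolean and whose queries are registered):
//   std::mt19937 gen(31);
//   g.cavi(/*num_iters=*/100, /*steps_per_iter=*/10, gen, /*elbo_samples=*/0);
//   // g.variational_params now holds one {probability} per queried node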
void Graph::cavi(
uint num_iters,
uint steps_per_iter,
std::mt19937& gen,
uint elbo_samples) {
// convert the smart pointers in nodes to raw pointers in node_ptrs
// for faster access
std::vector<Node*> node_ptrs;
// store all the sampled values for each node
std::vector<std::vector<NodeValue>> var_samples;
for (uint node_id = 0; node_id < static_cast<uint>(nodes.size()); node_id++) {
node_ptrs.push_back(nodes[node_id].get());
var_samples.push_back(std::vector<NodeValue>());
}
assert(node_ptrs.size() > 0); // keep linter happy
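// the support: the node ids that inference needs to consider
// (see compute_support)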
std::set<uint> supp = compute_support();
// the variational parameter q(node = true) for each node, initialized to 0.5
std::vector<double> param_probability =
std::vector<double>(nodes.size(), 0.5);
assert(param_probability.size() > 0); // keep linter happy
// compute the pool: the nodes that we will infer over, each mapped to
// -> (nodes to sample, nodes to eval, nodes to log_prob)
// NOTE: we want the list of nodes in the pool to be sorted to ensure
// that we update the nodes in topological order. This helps in some models
// where some of the ancestor nodes have deterministic probabilities.
std::map<
uint,
std::tuple<std::vector<uint>, std::vector<uint>, std::vector<uint>>>
pool;
for (uint node_id : supp) {
Node* node = node_ptrs[node_id];
if (not node->is_observed) {
node->eval(gen); // evaluate the value of non-observed operator nodes
}
if (node->is_stochastic() and not node->is_observed) {
// sample some values for this node
auto& samples = var_samples[node_id];
std::bernoulli_distribution distrib(param_probability[node_id]);
for (uint step = 0; step < steps_per_iter; step++) {
samples.push_back(NodeValue(bool(distrib(gen))));
}
// For each node in the pool we need its stochastic descendants
// because those are the nodes for which we will compute the expected
// log_prob. We will call these nodes the logprob_nodes.
std::vector<uint> det_desc;
std::vector<uint> logprob_nodes;
std::tie(det_desc, logprob_nodes) = compute_affected_nodes(node_id, supp);
// In order to compute the log_prob of these nodes we need to
// materialize their ancestors both deterministic and stochastic.
// The unobserved stochastic ancestors are to be sampled while the
// deterministic ancestors will be "eval"ed hence we call these nodes the
// sample_nodes and the eval_nodes respectively. Note: the unobserved
// logprob_nodes themselves must also be sampled, excluding the target node.
// To avoid duplicates and to sort the nodes we will first create sets.
std::unordered_set<uint> sample_set;
// the deterministic nodes have to be evaluated in topological order
std::set<uint> det_set;
for (auto id : logprob_nodes) {
for (auto id2 : node_ptrs[id]->det_anc) {
det_set.insert(id2);
}
for (auto id2 : node_ptrs[id]->sto_anc) {
if (id2 != node_id and observed.find(id2) == observed.end()) {
sample_set.insert(id2);
}
}
if (id != node_id and observed.find(id) == observed.end()) {
sample_set.insert(id);
}
}
std::vector<uint> eval_nodes;
eval_nodes.insert(eval_nodes.end(), det_set.begin(), det_set.end());
std::vector<uint> sample_nodes;
sample_nodes.insert(
sample_nodes.end(), sample_set.begin(), sample_set.end());
pool[node_id] = std::make_tuple(sample_nodes, eval_nodes, logprob_nodes);
}
}
// optimization outer loop: each iteration is one coordinate-ascent
// sweep over the pool
for (uint inum = 0; inum < num_iters; inum++) {
for (auto it = pool.begin(); it != pool.end(); ++it) {
uint tgt_node_id = it->first;
Node* tgt_node = node_ptrs[tgt_node_id];
// manually unpack the tuple; this version of C++ doesn't have
// structured bindings
std::tuple<
const std::vector<uint>&,
const std::vector<uint>&,
const std::vector<uint>&>
tmp_tuple = it->second;
const std::vector<uint>& sample_nodes = std::get<0>(tmp_tuple);
const std::vector<uint>& eval_nodes = std::get<1>(tmp_tuple);
const std::vector<uint>& logprob_nodes = std::get<2>(tmp_tuple);
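// expec[v] will hold the Monte Carlo estimate of the expected
// log_prob of the logprob_nodes given target node value v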
std::vector<double> expec(2, 0.0);
for (uint step = 0; step < steps_per_iter; step++) {
for (uint node_id : sample_nodes) {
node_ptrs[node_id]->value = var_samples[node_id][step];
}
for (uint val = 0; val < 2; val++) {
tgt_node->value = NodeValue(bool(val));
for (uint node_id : eval_nodes) {
node_ptrs[node_id]->eval(gen);
}
double log_prob = 0;
for (uint node_id : logprob_nodes) {
log_prob += node_ptrs[node_id]->log_prob();
}
// accumulate the expectation estimate for the current value of the
// target node
expec[val] += log_prob / steps_per_iter;
}
}
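// Mean-field coordinate update: q*(z) is proportional to exp(E[log p]),
// so for a boolean target q(z = true) = logistic(expec[1] - expec[0]).
// If neither expectation is finite, fall back to an uninformative 0.5.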
if (std::isfinite(expec[0]) or std::isfinite(expec[1])) {
param_probability[tgt_node_id] = util::logistic(expec[1] - expec[0]);
} else {
param_probability[tgt_node_id] = 0.5;
}
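// resample this node's cached values from the updated variational
// distribution so later targets in this sweep condition on the new q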
auto& samples = var_samples[tgt_node_id];
std::bernoulli_distribution distrib(param_probability[tgt_node_id]);
for (uint step = 0; step < steps_per_iter; step++) {
samples[step] = NodeValue(bool(distrib(gen)));
}
}
if (elbo_samples > 0) {
// For a model p(X, Z) assume we are trying to estimate p(Z | X=x)
// using a variational approximation Q(Z).
// Now the KL-Divergence of Q(Z) || p(Z|x) is >= 0
// => E[log Q(Z) - log p(Z|x) | Z ~ Q] >= 0
// and, since log p(Z|x) = log p(Z, x) - log p(x),
// => log p(x) >= E[log p(Z, x) - log Q(Z) | Z ~ Q]
// The RHS is the ELBO, or evidence lower bound.
// We estimate this expectation by drawing fresh samples of the pool
// nodes from Q: log p(Z, x) is the log_prob of all the stochastic
// nodes and log Q(Z) is the log of the variational distribution.
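// Concretely, with S = elbo_samples and Z_s ~ Q:
//   ELBO ≈ (1/S) * sum_s [ log p(Z_s, x) - log Q(Z_s) ]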
double elbo = 0;
for (uint step = 0; step < elbo_samples; step++) {
for (uint node_id : supp) {
Node* node = node_ptrs[node_id];
if (node->is_stochastic()) {
if (not node->is_observed) {
double prob = param_probability[node_id];
std::bernoulli_distribution distrib(prob);
node->value = NodeValue(bool(distrib(gen)));
// subtract the log_prob of the variational distribution
elbo -= node->value._bool ? log(prob) : log(1 - prob);
}
// add the log_prob of the joint distribution
elbo += node->log_prob();
} else if (node->node_type == NodeType::OPERATOR) {
node->eval(gen);
}
}
}
elbo_vals.push_back(elbo / elbo_samples);
}
}
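// report the final variational parameter q(node = true) for each query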
variational_params.clear();
for (uint node_id : queries) {
variational_params.push_back({param_probability[node_id]});
}
}